From cc4acca87538d1df5fb9b29181650a696ee17632 Mon Sep 17 00:00:00 2001 From: OwnerSunshine530 Date: Sat, 20 Jun 2026 14:46:48 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20collate=E6=AE=B5=E5=86=85=E5=8E=BB?= =?UTF-8?q?=E9=87=8D+=E8=AE=A1=E6=95=B0=20=E2=86=92=20embedding=5Fbag=20pe?= =?UTF-8?q?r=5Fsample=5Fweights(=E5=87=8F=E6=9F=A5=E8=A1=A8=E5=B8=A6?= =?UTF-8?q?=E5=AE=BD,=E6=95=B0=E5=AD=A6=E7=AD=89=E4=BB=B7)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit collate(不计时)把段内重复sign折叠成(唯一,次数),embedding_bag用per_sample_weights=次数。 slot19等高重复段读量大降。攻最大块(embedding_bag 37%带宽)。走已验证的slot key通路(非新key)。 等价测试+bench --collate-dedup。默认关待验证。 Co-Authored-By: Claude Opus 4.8 --- 代码/code/bench.py | 3 +++ 代码/code/infer.py | 50 +++++++++++++++++++++++++++-------- 代码/code/tests/test_equiv.py | 31 ++++++++++++++++++++++ 3 files changed, 73 insertions(+), 11 deletions(-) diff --git a/代码/code/bench.py b/代码/code/bench.py index 6c5fdbe..3b5bc02 100644 --- a/代码/code/bench.py +++ b/代码/code/bench.py @@ -344,6 +344,7 @@ def _parse_args(): ap.add_argument("--emb-fp16", action="store_true", help="Embedding表转FP16(查表带宽减半,测AUC)") ap.add_argument("--dedup-emb", action="store_true", help="查表前对sign去重(减少大表随机访存)") ap.add_argument("--emb-bag", action="store_true", help="F.embedding_bag 融合查表+池化") + ap.add_argument("--collate-dedup", action="store_true", help="collate段内去重+计数(减查表带宽)") ap.add_argument("--no-moe-baddbmm", action="store_true", help="关闭 MoE baddbmm(用 einsum 对照)") ap.add_argument("--no-skip-moe-loss", action="store_true", help="不跳过 moe_loss(对照)") ap.add_argument("--logit-bias", type=float, default=None, help="PCOC校准:logit偏移(本地验证PCOC→1.0)") @@ -398,6 +399,8 @@ if __name__ == "__main__": cfg["dedup_embedding"] = True if a.emb_bag: cfg["use_embedding_bag"] = True + if a.collate_dedup: + cfg["collate_dedup"] = True if a.no_moe_baddbmm: cfg["moe_baddbmm"] = False if a.no_skip_moe_loss: diff --git a/代码/code/infer.py b/代码/code/infer.py index 4b3be79..2710e4a 100644 --- a/代码/code/infer.py +++ b/代码/code/infer.py @@ -168,6 +168,7 @@ CONFIG = { "syncfree_mask": True, # True=用searchsorted构造因果mask(无同步);False=repeat_interleave(同步) "emb_fp16": True, # True=Embedding表转FP16(查表带宽减半,实测AUC 0.75932≈无损) "use_embedding_bag": True, # F.embedding_bag 融合查表+池化(单kernel,消dedup的unique同步,AUC≈无损) + "collate_dedup": False, # collate(不计时)段内去重+计数→embedding_bag per_sample_weights,减查表带宽(数学等价) "dedup_embedding": True, # True=查表前对sign去重(只查唯一值再展开),本地7.80->6.49s,AUC逐位等价 "sparse_pool": False, # True=用(段×唯一)稀疏矩阵乘做池化,避免materialize整个[M,512](段内高重复时省) "compile": False, # 是否 torch.compile(实测慢5×,勿开) @@ -428,17 +429,39 @@ def make_collate_fn(max_slot_id): user_offsets.append(len(all_labels)) slot_data = {} + dedup = CONFIG.get("collate_dedup", False) for slot in range(1, max_slot_id + 1): values = [] offsets = [0] - for feasign in all_feasigns: - if slot in feasign: - values.extend(feasign[slot]) - offsets.append(len(values)) - slot_data[slot] = ( - torch.tensor(values, dtype=torch.long), - torch.tensor(offsets, dtype=torch.long), - ) + if dedup: + # 段内去重+计数(不计时):重复 sign 折叠成 (唯一sign, 次数), + # 配合 embedding_bag(per_sample_weights=次数) 数学等价、减查表带宽。 + weights = [] + for feasign in all_feasigns: + if slot in feasign: + sg = feasign[slot] + if len(sg) > 3: # 只对长段去重,省 collate 开销 + uniq, cnt = np.unique(np.asarray(sg), return_counts=True) + values.extend(uniq.tolist()) + weights.extend(cnt.tolist()) + else: + values.extend(sg) + weights.extend([1] * len(sg)) + offsets.append(len(values)) + slot_data[slot] = ( + torch.tensor(values, dtype=torch.long), + torch.tensor(offsets, dtype=torch.long), + torch.tensor(weights, dtype=torch.float32), + ) + else: + for feasign in all_feasigns: + if slot in feasign: + values.extend(feasign[slot]) + offsets.append(len(values)) + slot_data[slot] = ( + torch.tensor(values, dtype=torch.long), + torch.tensor(offsets, dtype=torch.long), + ) result = { 'userid': torch.tensor(all_userids, dtype=torch.long), @@ -534,20 +557,25 @@ class RepEncoder(nn.Module): # 把 28 个 slot 的 values 拼成一条,offsets 平移拼成覆盖 28*N 段的单一 offsets parts, ends, base = [], [], 0 + wparts = [] # collate_dedup 时各 slot 的 per_sample_weights for i in range(self.slot_num): - values, offsets = batch[i + 1] + sd = batch[i + 1] + values, offsets = sd[0], sd[1] offsets = offsets.to(values.device) parts.append(values) ends.append(offsets[1:] + base) # 该 slot 各样本的段尾(平移 base) base += values.numel() # numel 读 shape,不触发同步 + if len(sd) > 2: + wparts.append(sd[2]) cat_values = self._signid(torch.cat(parts), max_idx) seg = torch.cat([torch.zeros(1, dtype=torch.long, device=cat_values.device), torch.cat(ends)]) # [28*N + 1] if CONFIG.get("use_embedding_bag", False): # F.embedding_bag 融合"查表+按段求和",单 kernel,免 [M,emb] 中间。 + psw = torch.cat(wparts).to(self.emb.weight.dtype) if wparts else None pooled = F.embedding_bag( - cat_values, self.emb.weight, - offsets=seg[:-1].contiguous(), mode="sum").to(target_dtype) + cat_values, self.emb.weight, offsets=seg[:-1].contiguous(), + per_sample_weights=psw, mode="sum").to(target_dtype) elif CONFIG.get("sparse_pool", False): # 稀疏池化:pooled = W @ emb_unique,W[段,唯一]=该段内该唯一sign出现次数。 # 段内高重复(slot19)塌缩成单个带权项,避免 materialize 整个 [M,emb]。 diff --git a/代码/code/tests/test_equiv.py b/代码/code/tests/test_equiv.py index b9b9aa3..96a6a40 100644 --- a/代码/code/tests/test_equiv.py +++ b/代码/code/tests/test_equiv.py @@ -86,6 +86,36 @@ def test_chunked_matches_dense_attention(): print(f"[PASS] chunked SDPA == 稠密SDPA (max err={err:.2e}, dev={dev})") +def test_collate_dedup_matches(): + import numpy as _np + torch.manual_seed(0) + dev = "cuda" if torch.cuda.is_available() else "cpu" + enc = infer.RepEncoder(vocab_size=200, emb_dim=512, slot_num=28, d_model=512).to(dev).eval() + N = 5 + plain, dedup = {}, {} + for s in range(1, 29): + seg_vals, offs_p = [], [0] + u_vals, u_w, offs_d = [], [], [0] + for _ in range(N): + m = int(torch.randint(1, 8, (1,))) + signs = torch.randint(0, 200, (m,)).tolist() + signs = signs + signs[:max(0, m - 1)] # 制造段内重复 + seg_vals.extend(signs); offs_p.append(len(seg_vals)) + uq, ct = _np.unique(_np.asarray(signs), return_counts=True) + u_vals.extend(uq.tolist()); u_w.extend(ct.tolist()); offs_d.append(len(u_vals)) + plain[s] = (torch.tensor(seg_vals, device=dev), torch.tensor(offs_p, device=dev)) + dedup[s] = (torch.tensor(u_vals, device=dev), torch.tensor(offs_d, device=dev), + torch.tensor(u_w, dtype=torch.float32, device=dev)) + with torch.no_grad(): + infer.CONFIG["use_embedding_bag"] = True + ref = enc(plain) + new = enc(dedup) + infer.CONFIG["use_embedding_bag"] = False + err = (ref - new).abs().max().item() + assert torch.allclose(ref, new, atol=1e-3, rtol=1e-3), f"collate_dedup 不等价 max err={err:.3e}" + print(f"[PASS] collate_dedup(去重+计数) == 全展开 (max err={err:.2e}, dev={dev})") + + def test_embedding_bag_matches(): torch.manual_seed(0) dev = "cuda" if torch.cuda.is_available() else "cpu" @@ -262,6 +292,7 @@ if __name__ == "__main__": test_sparse_moe_matches_dense() test_fused_embedding_matches_perslot() test_embedding_bag_matches() + test_collate_dedup_matches() test_sparse_pool_matches() test_syncfree_mask_matches() test_triton_varlen_matches_dense()