diff --git a/代码/code/bench.py b/代码/code/bench.py index 0c4cb4f..27308dd 100644 --- a/代码/code/bench.py +++ b/代码/code/bench.py @@ -324,6 +324,7 @@ def _parse_args(): ap.add_argument("--compile", action="store_true", help="开启 torch.compile") ap.add_argument("--emb-fp16", action="store_true", help="Embedding表转FP16(查表带宽减半,测AUC)") ap.add_argument("--dedup-emb", action="store_true", help="查表前对sign去重(减少大表随机访存)") + ap.add_argument("--emb-bag", action="store_true", help="F.embedding_bag 融合查表+池化") ap.add_argument("--sparse-pool", action="store_true", help="稀疏矩阵乘做池化(段内高重复时省)") ap.add_argument("--precompute-rep", action="store_true", help="预计算RepEncoder缓存,model(batch)跳过embedding层(从batches自建)") @@ -370,6 +371,8 @@ if __name__ == "__main__": cfg["emb_fp16"] = True if a.dedup_emb: cfg["dedup_embedding"] = True + if a.emb_bag: + cfg["use_embedding_bag"] = True if a.sparse_pool: cfg["sparse_pool"] = True if a.precompute_rep: diff --git a/代码/code/infer.py b/代码/code/infer.py index 5854344..f66540f 100644 --- a/代码/code/infer.py +++ b/代码/code/infer.py @@ -145,6 +145,7 @@ CONFIG = { "fuse_embedding": True, # True=28个slot的查表+池化融合为1次(减per-batch kernel启动) "syncfree_mask": True, # True=用searchsorted构造因果mask(无同步);False=repeat_interleave(同步) "emb_fp16": True, # True=Embedding表转FP16(查表带宽减半,实测AUC 0.75932≈无损) + "use_embedding_bag": False, # True=用 F.embedding_bag 融合查表+池化(单kernel,免[M,512]中间),攻最大块 "dedup_embedding": True, # True=查表前对sign去重(只查唯一值再展开),本地7.80->6.49s,AUC逐位等价 "sparse_pool": False, # True=用(段×唯一)稀疏矩阵乘做池化,避免materialize整个[M,512](段内高重复时省) "compile": False, # 是否 torch.compile(实测慢5×,勿开) @@ -520,7 +521,12 @@ class RepEncoder(nn.Module): cat_values = self._signid(torch.cat(parts), max_idx) seg = torch.cat([torch.zeros(1, dtype=torch.long, device=cat_values.device), torch.cat(ends)]) # [28*N + 1] - if CONFIG.get("sparse_pool", False): + if CONFIG.get("use_embedding_bag", False): + # F.embedding_bag 融合"查表+按段求和",单 kernel,免 [M,emb] 中间。 + pooled = F.embedding_bag( + cat_values, self.emb.weight, + offsets=seg[:-1].contiguous(), mode="sum").to(target_dtype) + elif CONFIG.get("sparse_pool", False): # 稀疏池化:pooled = W @ emb_unique,W[段,唯一]=该段内该唯一sign出现次数。 # 段内高重复(slot19)塌缩成单个带权项,避免 materialize 整个 [M,emb]。 uniq, inv = torch.unique(cat_values, return_inverse=True) diff --git a/代码/code/tests/test_equiv.py b/代码/code/tests/test_equiv.py index a4fa482..49a81b8 100644 --- a/代码/code/tests/test_equiv.py +++ b/代码/code/tests/test_equiv.py @@ -86,6 +86,30 @@ def test_chunked_matches_dense_attention(): print(f"[PASS] chunked SDPA == 稠密SDPA (max err={err:.2e}, dev={dev})") +def test_embedding_bag_matches(): + torch.manual_seed(0) + dev = "cuda" if torch.cuda.is_available() else "cpu" + slot_num, emb_dim, d_model = 28, 512, 512 + enc = infer.RepEncoder(vocab_size=200, emb_dim=emb_dim, slot_num=slot_num, + d_model=d_model).to(dev).eval() + N = 6 + batch = {} + for s in range(1, slot_num + 1): + counts = torch.randint(0, 8, (N,)) + vals = torch.randint(0, 200, (int(counts.sum()),), device=dev) + offs = torch.cat([torch.zeros(1, dtype=torch.long), counts.cumsum(0)]).to(dev) + batch[s] = (vals, offs) + with torch.no_grad(): + infer.CONFIG["use_embedding_bag"] = False + ref = enc(batch) + infer.CONFIG["use_embedding_bag"] = True + new = enc(batch) + infer.CONFIG["use_embedding_bag"] = False + err = (ref - new).abs().max().item() + assert torch.allclose(ref, new, atol=1e-3, rtol=1e-3), f"embedding_bag 不等价 max err={err:.3e}" + print(f"[PASS] embedding_bag == segment_reduce (max err={err:.2e}, dev={dev})") + + def test_sparse_pool_matches(): torch.manual_seed(0) dev = "cuda" if torch.cuda.is_available() else "cpu" @@ -217,6 +241,7 @@ def test_flex_matches_dense_attention(): if __name__ == "__main__": test_moe_dense_matches_loop() test_fused_embedding_matches_perslot() + test_embedding_bag_matches() test_sparse_pool_matches() test_syncfree_mask_matches() test_triton_varlen_matches_dense()