From 66256660102276ef7afc45a66266d5b5349a0022 Mon Sep 17 00:00:00 2001 From: OwnerSunshine530 Date: Mon, 15 Jun 2026 15:15:13 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20sparse=5Fpool=20=E9=80=89=E9=A1=B9=20?= =?UTF-8?q?=E2=80=94=20(=E6=AE=B5=C3=97=E5=94=AF=E4=B8=80)=E7=A8=80?= =?UTF-8?q?=E7=96=8F=E7=9F=A9=E9=98=B5=E4=B9=98=E5=81=9A=E6=B1=A0=E5=8C=96?= =?UTF-8?q?,=E9=81=BF=E5=85=8Dmaterialize[M,emb]?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 针对 profile 的 dedup展开(15%)+segment_reduce(6.6%)。段内高重复(slot19)塌缩 为单个带权项。CONFIG.sparse_pool;bench --sparse-pool;等价测试已加。默认关,待验证。 Co-Authored-By: Claude Opus 4.8 --- 代码/code/bench.py | 3 +++ 代码/code/infer.py | 25 ++++++++++++++++++++----- 代码/code/tests/test_equiv.py | 27 +++++++++++++++++++++++++++ 3 files changed, 50 insertions(+), 5 deletions(-) diff --git a/代码/code/bench.py b/代码/code/bench.py index ea890b1..2d0c056 100644 --- a/代码/code/bench.py +++ b/代码/code/bench.py @@ -299,6 +299,7 @@ def _parse_args(): ap.add_argument("--compile", action="store_true", help="开启 torch.compile") ap.add_argument("--emb-fp16", action="store_true", help="Embedding表转FP16(查表带宽减半,测AUC)") ap.add_argument("--dedup-emb", action="store_true", help="查表前对sign去重(减少大表随机访存)") + ap.add_argument("--sparse-pool", action="store_true", help="稀疏矩阵乘做池化(段内高重复时省)") ap.add_argument("--profile", type=int, default=None, metavar="N", help="剖析前 N 个 batch,打印按 CUDA 耗时排序的算子表(定位瓶颈)") ap.add_argument("--rebuild", action="store_true", help="强制重建过滤缓存") @@ -334,6 +335,8 @@ if __name__ == "__main__": cfg["emb_fp16"] = True if a.dedup_emb: cfg["dedup_embedding"] = True + if a.sparse_pool: + cfg["sparse_pool"] = True if a.compile: cfg["compile"] = True if a.profile is not None: diff --git a/代码/code/infer.py b/代码/code/infer.py index 339e765..cce23e0 100644 --- a/代码/code/infer.py +++ b/代码/code/infer.py @@ -53,6 +53,7 @@ CONFIG = { "syncfree_mask": True, # True=用searchsorted构造因果mask(无同步);False=repeat_interleave(同步) "emb_fp16": True, # True=Embedding表转FP16(查表带宽减半,实测AUC 0.75932≈无损) "dedup_embedding": True, # True=查表前对sign去重(只查唯一值再展开),本地7.80->6.49s,AUC逐位等价 + "sparse_pool": False, # True=用(段×唯一)稀疏矩阵乘做池化,避免materialize整个[M,512](段内高重复时省) "compile": False, # 是否 torch.compile(实测慢5×,勿开) } @@ -381,13 +382,27 @@ class RepEncoder(nn.Module): cat_values = self._signid(torch.cat(parts), max_idx) seg = torch.cat([torch.zeros(1, dtype=torch.long, device=cat_values.device), torch.cat(ends)]) # [28*N + 1] - if CONFIG.get("dedup_embedding", False): - # 去重:只对唯一 sign 查大表,再按逆索引展开(数学逐位等价,省随机访存) + if CONFIG.get("sparse_pool", False): + # 稀疏池化:pooled = W @ emb_unique,W[段,唯一]=该段内该唯一sign出现次数。 + # 段内高重复(slot19)塌缩成单个带权项,避免 materialize 整个 [M,emb]。 uniq, inv = torch.unique(cat_values, return_inverse=True) - emb = self.emb(uniq).to(target_dtype)[inv] + emb_unique = self.emb(uniq).float() # 小表;sparse.mm 用 fp32 稳 + M = cat_values.numel() + num_seg = seg.numel() - 1 + seg_id = torch.searchsorted( + seg, torch.arange(M, device=cat_values.device), right=True) - 1 + W = torch.sparse_coo_tensor( + torch.stack([seg_id, inv]), + torch.ones(M, device=cat_values.device, dtype=torch.float32), + size=(num_seg, uniq.numel())).coalesce() + pooled = torch.sparse.mm(W, emb_unique).to(target_dtype) # [28*N, emb] else: - emb = self.emb(cat_values).to(target_dtype) - pooled = torch.segment_reduce(emb, reduce='sum', offsets=seg, initial=0) # [28*N, emb] + if CONFIG.get("dedup_embedding", False): + uniq, inv = torch.unique(cat_values, return_inverse=True) + emb = self.emb(uniq).to(target_dtype)[inv] + else: + emb = self.emb(cat_values).to(target_dtype) + pooled = torch.segment_reduce(emb, reduce='sum', offsets=seg, initial=0) # [28*N, emb] pooled = pooled.view(self.slot_num, N, self.emb_dim).permute(1, 0, 2).reshape( N, self.slot_num * self.emb_dim) return self.linear(self.input_norm(pooled)) diff --git a/代码/code/tests/test_equiv.py b/代码/code/tests/test_equiv.py index 4ecd922..1e7b7af 100644 --- a/代码/code/tests/test_equiv.py +++ b/代码/code/tests/test_equiv.py @@ -86,6 +86,32 @@ def test_chunked_matches_dense_attention(): print(f"[PASS] chunked SDPA == 稠密SDPA (max err={err:.2e}, dev={dev})") +def test_sparse_pool_matches(): + torch.manual_seed(0) + dev = "cuda" if torch.cuda.is_available() else "cpu" + slot_num, emb_dim, d_model = 28, 512, 512 + enc = infer.RepEncoder(vocab_size=200, emb_dim=emb_dim, slot_num=slot_num, + d_model=d_model).to(dev).eval() + N = 6 + batch = {} + for s in range(1, slot_num + 1): + counts = torch.randint(0, 8, (N,)) + # 故意制造段内重复:值域很小,重复率高 + vals = torch.randint(0, 30, (int(counts.sum()),), device=dev) + offs = torch.cat([torch.zeros(1, dtype=torch.long), counts.cumsum(0)]).to(dev) + batch[s] = (vals, offs) + with torch.no_grad(): + infer.CONFIG["sparse_pool"] = False + infer.CONFIG["dedup_embedding"] = True + ref = enc(batch) + infer.CONFIG["sparse_pool"] = True + new = enc(batch) + infer.CONFIG["sparse_pool"] = False + err = (ref - new).abs().max().item() + assert torch.allclose(ref, new, atol=2e-2, rtol=2e-2), f"sparse_pool 不等价 max err={err:.3e}" + print(f"[PASS] sparse_pool == segment_reduce (max err={err:.2e}, dev={dev})") + + def test_syncfree_mask_matches(): dev = "cuda" if torch.cuda.is_available() else "cpu" rep = infer.RepEncoder(vocab_size=100, emb_dim=8, slot_num=28, d_model=8) @@ -169,6 +195,7 @@ def test_flex_matches_dense_attention(): if __name__ == "__main__": test_moe_dense_matches_loop() test_fused_embedding_matches_perslot() + test_sparse_pool_matches() test_syncfree_mask_matches() test_chunked_matches_dense_attention() test_varlen_matches_dense_attention()