feat: F.embedding_bag 融合查表+池化(单kernel,免[M,512]中间) — 攻最大块(dedup index25%+segment11%=36%)

triton版profile:attention已优化出top,新大头=embedding池化36%+MoE22%+add18%。
embedding_bag一个kernel做查表+按段求和。等价测试+bench --emb-bag。默认关待验证。

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
OwnerSunshine530
2026-06-17 13:30:47 +08:00
parent 1083aca9fa
commit 74bb95a7bd
3 changed files with 35 additions and 1 deletions
+25
View File
@@ -86,6 +86,30 @@ def test_chunked_matches_dense_attention():
print(f"[PASS] chunked SDPA == 稠密SDPA (max err={err:.2e}, dev={dev})")
def test_embedding_bag_matches():
torch.manual_seed(0)
dev = "cuda" if torch.cuda.is_available() else "cpu"
slot_num, emb_dim, d_model = 28, 512, 512
enc = infer.RepEncoder(vocab_size=200, emb_dim=emb_dim, slot_num=slot_num,
d_model=d_model).to(dev).eval()
N = 6
batch = {}
for s in range(1, slot_num + 1):
counts = torch.randint(0, 8, (N,))
vals = torch.randint(0, 200, (int(counts.sum()),), device=dev)
offs = torch.cat([torch.zeros(1, dtype=torch.long), counts.cumsum(0)]).to(dev)
batch[s] = (vals, offs)
with torch.no_grad():
infer.CONFIG["use_embedding_bag"] = False
ref = enc(batch)
infer.CONFIG["use_embedding_bag"] = True
new = enc(batch)
infer.CONFIG["use_embedding_bag"] = False
err = (ref - new).abs().max().item()
assert torch.allclose(ref, new, atol=1e-3, rtol=1e-3), f"embedding_bag 不等价 max err={err:.3e}"
print(f"[PASS] embedding_bag == segment_reduce (max err={err:.2e}, dev={dev})")
def test_sparse_pool_matches():
torch.manual_seed(0)
dev = "cuda" if torch.cuda.is_available() else "cpu"
@@ -217,6 +241,7 @@ def test_flex_matches_dense_attention():
if __name__ == "__main__":
test_moe_dense_matches_loop()
test_fused_embedding_matches_perslot()
test_embedding_bag_matches()
test_sparse_pool_matches()
test_syncfree_mask_matches()
test_triton_varlen_matches_dense()