feat: F.embedding_bag 融合查表+池化(单kernel,免[M,512]中间) — 攻最大块(dedup index25%+segment11%=36%)

triton版profile:attention已优化出top,新大头=embedding池化36%+MoE22%+add18%。
embedding_bag一个kernel做查表+按段求和。等价测试+bench --emb-bag。默认关待验证。

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
OwnerSunshine530
2026-06-17 13:30:47 +08:00
parent 1083aca9fa
commit 74bb95a7bd
3 changed files with 35 additions and 1 deletions
+3
View File
@@ -324,6 +324,7 @@ def _parse_args():
ap.add_argument("--compile", action="store_true", help="开启 torch.compile")
ap.add_argument("--emb-fp16", action="store_true", help="Embedding表转FP16(查表带宽减半,测AUC)")
ap.add_argument("--dedup-emb", action="store_true", help="查表前对sign去重(减少大表随机访存)")
ap.add_argument("--emb-bag", action="store_true", help="F.embedding_bag 融合查表+池化")
ap.add_argument("--sparse-pool", action="store_true", help="稀疏矩阵乘做池化(段内高重复时省)")
ap.add_argument("--precompute-rep", action="store_true",
help="预计算RepEncoder缓存,model(batch)跳过embedding层(从batches自建)")
@@ -370,6 +371,8 @@ if __name__ == "__main__":
cfg["emb_fp16"] = True
if a.dedup_emb:
cfg["dedup_embedding"] = True
if a.emb_bag:
cfg["use_embedding_bag"] = True
if a.sparse_pool:
cfg["sparse_pool"] = True
if a.precompute_rep: