feat: fused MoE — baddbmm(cutlass GEMM+bias融合)+跳过推理无用的moe_loss,减kernel

GEMM保留cutlass(triton GEMM难超),融bias epilogue省add kernel;moe_loss仅训练用,
推理跳过省importance/std/mean。延续减kernel方向(embedding_bag/triton已证评测赚)。
默认开,bench --no-moe-baddbmm/--no-skip-moe-loss 对照。AUC无损。

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
OwnerSunshine530
2026-06-17 14:27:59 +08:00
parent 6bb51a1057
commit 575b32f263
2 changed files with 27 additions and 6 deletions
+6
View File
@@ -325,6 +325,8 @@ def _parse_args():
ap.add_argument("--emb-fp16", action="store_true", help="Embedding表转FP16(查表带宽减半,测AUC)")
ap.add_argument("--dedup-emb", action="store_true", help="查表前对sign去重(减少大表随机访存)")
ap.add_argument("--emb-bag", action="store_true", help="F.embedding_bag 融合查表+池化")
ap.add_argument("--no-moe-baddbmm", action="store_true", help="关闭 MoE baddbmm(用 einsum 对照)")
ap.add_argument("--no-skip-moe-loss", action="store_true", help="不跳过 moe_loss(对照)")
ap.add_argument("--sparse-pool", action="store_true", help="稀疏矩阵乘做池化(段内高重复时省)")
ap.add_argument("--precompute-rep", action="store_true",
help="预计算RepEncoder缓存,model(batch)跳过embedding层(从batches自建)")
@@ -373,6 +375,10 @@ if __name__ == "__main__":
cfg["dedup_embedding"] = True
if a.emb_bag:
cfg["use_embedding_bag"] = True
if a.no_moe_baddbmm:
cfg["moe_baddbmm"] = False
if a.no_skip_moe_loss:
cfg["skip_moe_loss"] = False
if a.sparse_pool:
cfg["sparse_pool"] = True
if a.precompute_rep: