feat: 真稀疏MoE(capacity分组,只算top-k,cutlass baddbmm,无host同步)

按expert排序token+固定capacity分桶,每桶dense baddbmm,减GEMM~3x。argsort/where/
scatter/index_add无.item()/bincount同步(不同于loop MoE)。超容量token丢弃(capacity_factor控)。
等价测试(大capacity无丢弃==dense)。bench --moe-sparse/--moe-cap。默认关待验证。

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
OwnerSunshine530
2026-06-17 21:05:55 +08:00
parent aacfe904fd
commit b397c142fa
3 changed files with 64 additions and 0 deletions
+20
View File
@@ -192,6 +192,25 @@ def test_varlen_matches_dense_attention():
print(f"[PASS] varlen(嵌套张量) == 稠密SDPA (max err={err:.2e})")
def test_sparse_moe_matches_dense():
    """Check that the sparse MoE path reproduces the dense MoE output.

    With a large capacity setting (so no token is dropped) the sparse MoE
    is expected to be mathematically equivalent to the dense one. The same
    ``infer.SMoE`` module is run twice on the same input, first with
    ``infer.CONFIG["moe_sparse"]`` off and then on, and the first element of
    each returned pair is compared with ``torch.allclose``.

    Raises:
        AssertionError: if the two outputs differ by more than
            ``atol=1e-3`` / ``rtol=1e-3``.
    """
    torch.manual_seed(0)
    dev = "cuda" if torch.cuda.is_available() else "cpu"
    m = infer.SMoE(d_model=512, dim_ff=1024, num_experts=8, k=2).to(dev).eval()
    x = torch.randn(1, 200, 512, device=dev)
    try:
        with torch.no_grad():
            # Reference pass with the sparse path switched off.
            infer.CONFIG["moe_sparse"] = False
            ref, _ = m(x)
            # Sparse pass; the capacity is large enough that no token is dropped.
            infer.CONFIG["moe_sparse"] = True
            infer.CONFIG["moe_capacity"] = 8.0
            new, _ = m(x)
    finally:
        # Always undo the global CONFIG changes, even when a forward pass
        # raises, so a failure here cannot leak sparse settings into the
        # tests that run afterwards in the same process.
        # NOTE(review): False / 1.25 are presumably the project defaults --
        # confirm against infer.CONFIG.
        infer.CONFIG["moe_sparse"] = False
        infer.CONFIG["moe_capacity"] = 1.25
    err = (ref - new).abs().max().item()
    assert torch.allclose(ref, new, atol=1e-3, rtol=1e-3), f"sparse MoE 不等价 max err={err:.3e}"
    print(f"[PASS] sparse MoE(大capacity) == dense (max err={err:.2e}, dev={dev})")
def test_fused_embedding_matches_perslot():
torch.manual_seed(0)
dev = "cuda" if torch.cuda.is_available() else "cpu"
@@ -240,6 +259,7 @@ def test_flex_matches_dense_attention():
if __name__ == "__main__":
test_moe_dense_matches_loop()
test_sparse_moe_matches_dense()
test_fused_embedding_matches_perslot()
test_embedding_bag_matches()
test_sparse_pool_matches()