feat: 真稀疏MoE(capacity分组,只算top-k,cutlass baddbmm,无host同步)
按expert排序token+固定capacity分桶,每桶dense baddbmm,减GEMM~3x。argsort/where/scatter/index_add无.item()/bincount同步(不同于loop MoE)。超容量token丢弃(capacity_factor控)。等价测试(大capacity无丢弃==dense)。bench --moe-sparse/--moe-cap。默认关待验证。

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
@@ -192,6 +192,25 @@ def test_varlen_matches_dense_attention():
|
||||
print(f"[PASS] varlen(嵌套张量) == 稠密SDPA (max err={err:.2e})")
|
||||
|
||||
|
||||
def test_sparse_moe_matches_dense():
    """Check that the sparse MoE path agrees with the dense path.

    Runs the same ``infer.SMoE`` module on the same input twice: once with
    ``infer.CONFIG["moe_sparse"]`` off (reference) and once with it on and
    ``moe_capacity`` raised to 8.0, which is intended to be large enough that
    no token is dropped. The two outputs must agree within
    ``atol=rtol=1e-3``.

    ``infer.CONFIG`` is shared global state, so the values it held on entry
    are restored in a ``finally`` block; a failing forward pass therefore
    cannot leave the sparse settings switched on for later tests.

    Raises:
        AssertionError: if the sparse and dense outputs differ beyond the
            tolerance.
    """
    torch.manual_seed(0)
    dev = "cuda" if torch.cuda.is_available() else "cpu"
    m = infer.SMoE(d_model=512, dim_ff=1024, num_experts=8, k=2).to(dev).eval()
    x = torch.randn(1, 200, 512, device=dev)

    # Snapshot the current settings. The fallbacks are the values this test
    # previously hard-coded on exit, used only if a key is not present yet.
    # NOTE(review): assumes infer.CONFIG is dict-like (supports .get) — confirm.
    prev_sparse = infer.CONFIG.get("moe_sparse", False)
    prev_capacity = infer.CONFIG.get("moe_capacity", 1.25)
    try:
        with torch.no_grad():
            infer.CONFIG["moe_sparse"] = False
            ref, _ = m(x)  # second return value is not needed here
            infer.CONFIG["moe_sparse"] = True
            infer.CONFIG["moe_capacity"] = 8.0  # large enough that no token is dropped
            new, _ = m(x)
    finally:
        # Restore even when a forward pass raises, so other tests see the
        # configuration they started with.
        infer.CONFIG["moe_sparse"] = prev_sparse
        infer.CONFIG["moe_capacity"] = prev_capacity

    err = (ref - new).abs().max().item()
    assert torch.allclose(ref, new, atol=1e-3, rtol=1e-3), f"sparse MoE 不等价 max err={err:.3e}"
    print(f"[PASS] sparse MoE(大capacity) == dense (max err={err:.2e}, dev={dev})")
|
||||
|
||||
|
||||
def test_fused_embedding_matches_perslot():
|
||||
torch.manual_seed(0)
|
||||
dev = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
@@ -240,6 +259,7 @@ def test_flex_matches_dense_attention():
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_moe_dense_matches_loop()
|
||||
test_sparse_moe_matches_dense()
|
||||
test_fused_embedding_matches_perslot()
|
||||
test_embedding_bag_matches()
|
||||
test_sparse_pool_matches()
|
||||
|
||||
Reference in New Issue
Block a user