feat: 真稀疏MoE(capacity分组,只算top-k,cutlass baddbmm,无host同步)

按expert排序token+固定capacity分桶,每桶dense baddbmm,减GEMM~3x。argsort/where/
scatter/index_add无.item()/bincount同步(不同于loop MoE)。超容量token丢弃(capacity_factor控)。
等价测试(大capacity无丢弃==dense)。bench --moe-sparse/--moe-cap。默认关待验证。

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
OwnerSunshine530
2026-06-17 21:05:55 +08:00
parent aacfe904fd
commit b397c142fa
3 changed files with 64 additions and 0 deletions
+6
View File
@@ -347,6 +347,8 @@ def _parse_args():
ap.add_argument("--no-moe-baddbmm", action="store_true", help="关闭 MoE baddbmm(用 einsum 对照)") ap.add_argument("--no-moe-baddbmm", action="store_true", help="关闭 MoE baddbmm(用 einsum 对照)")
ap.add_argument("--no-skip-moe-loss", action="store_true", help="不跳过 moe_loss(对照)") ap.add_argument("--no-skip-moe-loss", action="store_true", help="不跳过 moe_loss(对照)")
ap.add_argument("--logit-bias", type=float, default=None, help="PCOC校准:logit偏移(本地验证PCOC→1.0)") ap.add_argument("--logit-bias", type=float, default=None, help="PCOC校准:logit偏移(本地验证PCOC→1.0)")
ap.add_argument("--moe-sparse", action="store_true", help="真稀疏MoE(只算top-k,capacity分组)")
ap.add_argument("--moe-cap", type=float, default=None, help="MoE capacity factor")
ap.add_argument("--sparse-pool", action="store_true", help="稀疏矩阵乘做池化(段内高重复时省)") ap.add_argument("--sparse-pool", action="store_true", help="稀疏矩阵乘做池化(段内高重复时省)")
ap.add_argument("--precompute-rep", action="store_true", ap.add_argument("--precompute-rep", action="store_true",
help="预计算RepEncoder缓存,model(batch)跳过embedding层(从batches自建)") help="预计算RepEncoder缓存,model(batch)跳过embedding层(从batches自建)")
@@ -401,6 +403,10 @@ if __name__ == "__main__":
cfg["skip_moe_loss"] = False cfg["skip_moe_loss"] = False
if a.logit_bias is not None: if a.logit_bias is not None:
cfg["logit_bias"] = a.logit_bias cfg["logit_bias"] = a.logit_bias
if a.moe_sparse:
cfg["moe_sparse"] = True
if a.moe_cap is not None:
cfg["moe_capacity"] = a.moe_cap
if a.sparse_pool: if a.sparse_pool:
cfg["sparse_pool"] = True cfg["sparse_pool"] = True
if a.precompute_rep: if a.precompute_rep:
+38
View File
@@ -144,6 +144,8 @@ CONFIG = {
# 须靠提交验证。AUC中性、MoE仅占2%算力故风险极低。 # 须靠提交验证。AUC中性、MoE仅占2%算力故风险极低。
"vectorize_moe": True, # True=稠密向量化MoE(无同步点)False=原逐expert循环(.nonzero同步) "vectorize_moe": True, # True=稠密向量化MoE(无同步点)False=原逐expert循环(.nonzero同步)
"moe_baddbmm": True, # MoE FFN 用 baddbmm(cutlass GEMM+bias epilogue融合),省 bias add kernel "moe_baddbmm": True, # MoE FFN 用 baddbmm(cutlass GEMM+bias epilogue融合),省 bias add kernel
"moe_sparse": False, # True=真稀疏MoE(只算top-k,capacity分组),减GEMM~3x;风险:开销/容量丢弃AUC
"moe_capacity": 1.25, # 每expert容量 = ceil(Nk/E*factor);越大越不丢token但计算越多
"skip_moe_loss": True, # 推理跳过 moe_loss(load-balance,推理无用),省 importance/std/mean kernel "skip_moe_loss": True, # 推理跳过 moe_loss(load-balance,推理无用),省 importance/std/mean kernel
# PCOC 校准:本地拟合-0.1067(本地PCOC1.109),但评测PCOC稳定1.059,按斜率换算评测最优≈-0.059。 # PCOC 校准:本地拟合-0.1067(本地PCOC1.109),但评测PCOC稳定1.059,按斜率换算评测最优≈-0.059。
"logit_bias": -0.06, # logit 加常数偏移使评测 PCOC→~1.0(单调,AUC不变,免费+~0.33分) "logit_bias": -0.06, # logit 加常数偏移使评测 PCOC→~1.0(单调,AUC不变,免费+~0.33分)
@@ -693,6 +695,39 @@ class SMoE(nn.Module):
self.register_buffer("W2t", self.W2.transpose(1, 2).contiguous()) # [E,F,D] self.register_buffer("W2t", self.W2.transpose(1, 2).contiguous()) # [E,F,D]
self._stacked = True self._stacked = True
def _forward_sparse(self, x):
"""真稀疏 MoE:每 token 只算 top-k expert(按 expert 排序 + capacity 分桶 + cutlass baddbmm)。
全程无 host 同步(argsort/where/scatter/index_add)。超容量 token 被丢弃(capacity_factor 控)。"""
import math
B, S, D = x.shape
topk_idx, topk_score, _ = self.gate(x)
N, k, E = B * S, self.k, self.num_experts
xf = x.reshape(N, D)
flat_e = topk_idx.reshape(-1) # [Nk] 每 pair 的 expert
flat_s = topk_score.reshape(-1) # [Nk]
Nk = flat_e.numel()
flat_t = torch.arange(N, device=x.device).repeat_interleave(k) # [Nk] token id
order = torch.argsort(flat_e) # 按 expert 排序(GPU sort,无 host 同步)
se, st, ss = flat_e[order], flat_t[order], flat_s[order]
xs = xf[st] # [Nk, D]
expert_start = torch.searchsorted(se.contiguous(),
torch.arange(E, device=x.device)) # [E]
pos_within = torch.arange(Nk, device=x.device) - expert_start[se] # 每 token 在其 expert 内位置
C = int(math.ceil(Nk / E * CONFIG.get("moe_capacity", 1.25)))
valid = pos_within < C
slot = se * C + pos_within
slot_safe = torch.where(valid, slot, torch.full_like(slot, E * C)) # 超容量→dummy 槽
buf = torch.zeros(E * C + 1, D, dtype=xs.dtype, device=x.device)
buf[slot_safe] = xs # scatterdummy 槽不读)
h = torch.baddbmm(self.b1.unsqueeze(1), buf[:E * C].view(E, C, D), self.W1t) # [E,C,F]
h = F.relu(h)
o = torch.baddbmm(self.b2.unsqueeze(1), h, self.W2t) # [E,C,D]
o_full = torch.cat([o.reshape(E * C, D),
torch.zeros(1, D, dtype=o.dtype, device=x.device)]) # [E*C+1, D]
out_s = o_full[slot_safe] * ss.unsqueeze(-1) # [Nk, D]dummy→0
out = torch.zeros(N, D, dtype=x.dtype, device=x.device).index_add_(0, st, out_s)
return out.view(B, S, D), out.new_zeros(())
def forward(self, x): def forward(self, x):
# x: [B,S,D] # x: [B,S,D]
if not CONFIG.get("vectorize_moe", True): if not CONFIG.get("vectorize_moe", True):
@@ -701,6 +736,9 @@ class SMoE(nn.Module):
if not self._stacked: if not self._stacked:
self._stack_weights() self._stack_weights()
if CONFIG.get("moe_sparse", False):
return self._forward_sparse(x)
B, S, D = x.shape B, S, D = x.shape
topk_idx, topk_score, probs = self.gate(x) topk_idx, topk_score, probs = self.gate(x)
+20
View File
@@ -192,6 +192,25 @@ def test_varlen_matches_dense_attention():
print(f"[PASS] varlen(嵌套张量) == 稠密SDPA (max err={err:.2e})") print(f"[PASS] varlen(嵌套张量) == 稠密SDPA (max err={err:.2e})")
def test_sparse_moe_matches_dense():
# 大 capacity(无丢弃)下,稀疏 MoE 应与 dense 数学等价
torch.manual_seed(0)
dev = "cuda" if torch.cuda.is_available() else "cpu"
m = infer.SMoE(d_model=512, dim_ff=1024, num_experts=8, k=2).to(dev).eval()
x = torch.randn(1, 200, 512, device=dev)
with torch.no_grad():
infer.CONFIG["moe_sparse"] = False
ref, _ = m(x)
infer.CONFIG["moe_sparse"] = True
infer.CONFIG["moe_capacity"] = 8.0 # 足够大,不丢 token
new, _ = m(x)
infer.CONFIG["moe_sparse"] = False
infer.CONFIG["moe_capacity"] = 1.25
err = (ref - new).abs().max().item()
assert torch.allclose(ref, new, atol=1e-3, rtol=1e-3), f"sparse MoE 不等价 max err={err:.3e}"
print(f"[PASS] sparse MoE(大capacity) == dense (max err={err:.2e}, dev={dev})")
def test_fused_embedding_matches_perslot(): def test_fused_embedding_matches_perslot():
torch.manual_seed(0) torch.manual_seed(0)
dev = "cuda" if torch.cuda.is_available() else "cpu" dev = "cuda" if torch.cuda.is_available() else "cpu"
@@ -240,6 +259,7 @@ def test_flex_matches_dense_attention():
if __name__ == "__main__": if __name__ == "__main__":
test_moe_dense_matches_loop() test_moe_dense_matches_loop()
test_sparse_moe_matches_dense()
test_fused_embedding_matches_perslot() test_fused_embedding_matches_perslot()
test_embedding_bag_matches() test_embedding_bag_matches()
test_sparse_pool_matches() test_sparse_pool_matches()