perf: MoE top-k加权改scatter+mul+sum(在[E,N,D]上),省permute大clone+gather(profile clone 8%)

数学等价(top-k索引互异,scatter无冲突),零AUC风险。延续'减kernel'方向。
moe_fused_weight默认开,test_moe_dense_matches_loop已覆盖。

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
OwnerSunshine530
2026-06-19 20:22:16 +08:00
parent 6278d4a050
commit 9f73505caa
+14 -5
View File
@@ -144,6 +144,7 @@ CONFIG = {
# 须靠提交验证。AUC中性、MoE仅占2%算力故风险极低。 # 须靠提交验证。AUC中性、MoE仅占2%算力故风险极低。
"vectorize_moe": True, # True=稠密向量化MoE(无同步点)False=原逐expert循环(.nonzero同步) "vectorize_moe": True, # True=稠密向量化MoE(无同步点)False=原逐expert循环(.nonzero同步)
"moe_baddbmm": True, # MoE FFN 用 baddbmm(cutlass GEMM+bias epilogue融合),省 bias add kernel "moe_baddbmm": True, # MoE FFN 用 baddbmm(cutlass GEMM+bias epilogue融合),省 bias add kernel
"moe_fused_weight": True, # top-k加权用scatter+mul+sum(在[E,N,D]上),省permute大clone+gather;数学等价
# 真稀疏MoE实测评测净负:lat 34.64->37.64s(本地快15%但argsort/scatter开销评测放大,如varlen) # 真稀疏MoE实测评测净负:lat 34.64->37.64s(本地快15%但argsort/scatter开销评测放大,如varlen)
# +容量丢弃降AUC(0.7525->0.7507)。已退回 dense。 # +容量丢弃降AUC(0.7525->0.7507)。已退回 dense。
"moe_sparse": False, # True=真稀疏MoE(评测净负,勿开) "moe_sparse": False, # True=真稀疏MoE(评测净负,勿开)
@@ -758,11 +759,19 @@ class SMoE(nn.Module):
o = torch.einsum("enf,edf->end", h, self.W2) + self.b2.unsqueeze(1) o = torch.einsum("enf,edf->end", h, self.W2) + self.b2.unsqueeze(1)
# 按每个 token 的 top-k 选取并加权(与逐 expert 循环数学等价) # 按每个 token 的 top-k 选取并加权(与逐 expert 循环数学等价)
o = o.permute(1, 0, 2) # [N, E, D] if CONFIG.get("moe_fused_weight", True):
idx = topk_idx.reshape(-1, self.k) # [N, k] # 稀疏权重 [N,E],直接在 [E,N,D] 上加权求和(省掉 permute 的大 clone + gather
sc = topk_score.reshape(-1, self.k) # [N, k] idx = topk_idx.reshape(-1, self.k) # [N, k]
sel = torch.gather(o, 1, idx.unsqueeze(-1).expand(-1, -1, D)) # [N, k, D] sc = topk_score.reshape(-1, self.k).to(o.dtype) # [N, k]
out = (sel * sc.unsqueeze(-1)).sum(dim=1).reshape(B, S, D) wfull = torch.zeros(Nt, self.num_experts, dtype=o.dtype, device=o.device)
wfull.scatter_(1, idx, sc) # [N,E] top-k 处=分数(索引互异,无冲突)
out = (o * wfull.t().unsqueeze(-1)).sum(0).reshape(B, S, D) # [E,N,D]*[E,N,1]->[N,D]
else:
o = o.permute(1, 0, 2) # [N, E, D]
idx = topk_idx.reshape(-1, self.k) # [N, k]
sc = topk_score.reshape(-1, self.k) # [N, k]
sel = torch.gather(o, 1, idx.unsqueeze(-1).expand(-1, -1, D)) # [N, k, D]
out = (sel * sc.unsqueeze(-1)).sum(dim=1).reshape(B, S, D)
if CONFIG.get("skip_moe_loss", True): if CONFIG.get("skip_moe_loss", True):
moe_loss = out.new_zeros(()) # 推理无用,跳过 importance/std/mean moe_loss = out.new_zeros(()) # 推理无用,跳过 importance/std/mean