feat: fused MoE — baddbmm(cutlass GEMM+bias融合)+跳过推理无用的moe_loss,减kernel

GEMM保留cutlass(triton GEMM难超),融bias epilogue省add kernel;moe_loss仅训练用,
推理跳过省importance/std/mean。延续减kernel方向(embedding_bag/triton已证评测赚)。
默认开,bench --no-moe-baddbmm/--no-skip-moe-loss 对照。AUC无损。

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
OwnerSunshine530
2026-06-17 14:27:59 +08:00
parent 6bb51a1057
commit 575b32f263
2 changed files with 27 additions and 6 deletions
+21 -6
View File
@@ -143,6 +143,8 @@ CONFIG = {
# synchronize,去掉同步点可能让被计时的 model(batch) 大幅缩短。本地force-sync看不出,
# 须靠提交验证。AUC中性、MoE仅占2%算力故风险极低。
"vectorize_moe": True, # True=稠密向量化MoE(无同步点)False=原逐expert循环(.nonzero同步)
"moe_baddbmm": True, # MoE FFN 用 baddbmm(cutlass GEMM+bias epilogue融合),省 bias add kernel
"skip_moe_loss": True, # 推理跳过 moe_loss(load-balance,推理无用),省 importance/std/mean kernel
"fuse_embedding": True, # True=28个slot的查表+池化融合为1次(减per-batch kernel启动)
"syncfree_mask": True, # True=用searchsorted构造因果mask(无同步)False=repeat_interleave(同步)
"emb_fp16": True, # True=Embedding表转FP16(查表带宽减半,实测AUC 0.75932≈无损)
@@ -684,6 +686,9 @@ class SMoE(nn.Module):
self.register_buffer("b1", torch.stack([e.fc1.bias for e in self.experts]).contiguous()) # [E,F]
self.register_buffer("W2", torch.stack([e.fc2.weight for e in self.experts]).contiguous()) # [E,D,F]
self.register_buffer("b2", torch.stack([e.fc2.bias for e in self.experts]).contiguous()) # [E,D]
# baddbmm 用的转置权重([E,D,F] / [E,F,D]),预转 contiguous
self.register_buffer("W1t", self.W1.transpose(1, 2).contiguous()) # [E,D,F]
self.register_buffer("W2t", self.W2.transpose(1, 2).contiguous()) # [E,F,D]
self._stacked = True
def forward(self, x):
@@ -698,10 +703,17 @@ class SMoE(nn.Module):
topk_idx, topk_score, probs = self.gate(x)
xf = x.reshape(-1, D) # [N, D]
# 稠密计算所有 expert(GPU 友好、无 Python 循环/同步/gather-scatter):
h = torch.einsum("nd,efd->enf", xf, self.W1) + self.b1.unsqueeze(1) # [E,N,F]
h = F.relu(h)
o = torch.einsum("enf,edf->end", h, self.W2) + self.b2.unsqueeze(1) # [E,N,D]
Nt = xf.shape[0]
if CONFIG.get("moe_baddbmm", True):
# cutlass GEMM + bias epilogue 融合(省 bias add kernel)
xe = xf.unsqueeze(0).expand(self.num_experts, -1, -1) # [E,N,D]
h = torch.baddbmm(self.b1.unsqueeze(1), xe, self.W1t) # [E,N,F]
h = F.relu(h)
o = torch.baddbmm(self.b2.unsqueeze(1), h, self.W2t) # [E,N,D]
else:
h = torch.einsum("nd,efd->enf", xf, self.W1) + self.b1.unsqueeze(1)
h = F.relu(h)
o = torch.einsum("enf,edf->end", h, self.W2) + self.b2.unsqueeze(1)
# 按每个 token 的 top-k 选取并加权(与逐 expert 循环数学等价)
o = o.permute(1, 0, 2) # [N, E, D]
@@ -710,8 +722,11 @@ class SMoE(nn.Module):
sel = torch.gather(o, 1, idx.unsqueeze(-1).expand(-1, -1, D)) # [N, k, D]
out = (sel * sc.unsqueeze(-1)).sum(dim=1).reshape(B, S, D)
importance = probs.sum(dim=(0, 1)) # [E]
moe_loss = (importance.std() / (importance.mean() + 1e-6))
if CONFIG.get("skip_moe_loss", True):
moe_loss = out.new_zeros(()) # 推理无用,跳过 importance/std/mean
else:
importance = probs.sum(dim=(0, 1)) # [E]
moe_loss = (importance.std() / (importance.mean() + 1e-6))
return out, moe_loss