feat: fused MoE — baddbmm(cutlass GEMM+bias融合)+跳过推理无用的moe_loss,减kernel
GEMM保留cutlass(triton GEMM难超),融bias epilogue省add kernel;moe_loss仅训练用, 推理跳过省importance/std/mean。延续减kernel方向(embedding_bag/triton已证评测赚)。 默认开,bench --no-moe-baddbmm/--no-skip-moe-loss 对照。AUC无损。 Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
@@ -325,6 +325,8 @@ def _parse_args():
|
||||
ap.add_argument("--emb-fp16", action="store_true", help="Embedding表转FP16(查表带宽减半,测AUC)")
|
||||
ap.add_argument("--dedup-emb", action="store_true", help="查表前对sign去重(减少大表随机访存)")
|
||||
ap.add_argument("--emb-bag", action="store_true", help="F.embedding_bag 融合查表+池化")
|
||||
ap.add_argument("--no-moe-baddbmm", action="store_true", help="关闭 MoE baddbmm(用 einsum 对照)")
|
||||
ap.add_argument("--no-skip-moe-loss", action="store_true", help="不跳过 moe_loss(对照)")
|
||||
ap.add_argument("--sparse-pool", action="store_true", help="稀疏矩阵乘做池化(段内高重复时省)")
|
||||
ap.add_argument("--precompute-rep", action="store_true",
|
||||
help="预计算RepEncoder缓存,model(batch)跳过embedding层(从batches自建)")
|
||||
@@ -373,6 +375,10 @@ if __name__ == "__main__":
|
||||
cfg["dedup_embedding"] = True
|
||||
if a.emb_bag:
|
||||
cfg["use_embedding_bag"] = True
|
||||
if a.no_moe_baddbmm:
|
||||
cfg["moe_baddbmm"] = False
|
||||
if a.no_skip_moe_loss:
|
||||
cfg["skip_moe_loss"] = False
|
||||
if a.sparse_pool:
|
||||
cfg["sparse_pool"] = True
|
||||
if a.precompute_rep:
|
||||
|
||||
+18
-3
@@ -143,6 +143,8 @@ CONFIG = {
|
||||
# synchronize,去掉同步点可能让被计时的 model(batch) 大幅缩短。本地force-sync看不出,
|
||||
# 须靠提交验证。AUC中性、MoE仅占2%算力故风险极低。
|
||||
"vectorize_moe": True, # True=稠密向量化MoE(无同步点);False=原逐expert循环(.nonzero同步)
|
||||
"moe_baddbmm": True, # MoE FFN 用 baddbmm(cutlass GEMM+bias epilogue融合),省 bias add kernel
|
||||
"skip_moe_loss": True, # 推理跳过 moe_loss(load-balance,推理无用),省 importance/std/mean kernel
|
||||
"fuse_embedding": True, # True=28个slot的查表+池化融合为1次(减per-batch kernel启动)
|
||||
"syncfree_mask": True, # True=用searchsorted构造因果mask(无同步);False=repeat_interleave(同步)
|
||||
"emb_fp16": True, # True=Embedding表转FP16(查表带宽减半,实测AUC 0.75932≈无损)
|
||||
@@ -684,6 +686,9 @@ class SMoE(nn.Module):
|
||||
self.register_buffer("b1", torch.stack([e.fc1.bias for e in self.experts]).contiguous()) # [E,F]
|
||||
self.register_buffer("W2", torch.stack([e.fc2.weight for e in self.experts]).contiguous()) # [E,D,F]
|
||||
self.register_buffer("b2", torch.stack([e.fc2.bias for e in self.experts]).contiguous()) # [E,D]
|
||||
# baddbmm 用的转置权重([E,D,F] / [E,F,D]),预转 contiguous
|
||||
self.register_buffer("W1t", self.W1.transpose(1, 2).contiguous()) # [E,D,F]
|
||||
self.register_buffer("W2t", self.W2.transpose(1, 2).contiguous()) # [E,F,D]
|
||||
self._stacked = True
|
||||
|
||||
def forward(self, x):
|
||||
@@ -698,10 +703,17 @@ class SMoE(nn.Module):
|
||||
topk_idx, topk_score, probs = self.gate(x)
|
||||
|
||||
xf = x.reshape(-1, D) # [N, D]
|
||||
# 稠密计算所有 expert(GPU 友好、无 Python 循环/同步/gather-scatter):
|
||||
h = torch.einsum("nd,efd->enf", xf, self.W1) + self.b1.unsqueeze(1) # [E,N,F]
|
||||
Nt = xf.shape[0]
|
||||
if CONFIG.get("moe_baddbmm", True):
|
||||
# cutlass GEMM + bias epilogue 融合(省 bias add kernel)
|
||||
xe = xf.unsqueeze(0).expand(self.num_experts, -1, -1) # [E,N,D]
|
||||
h = torch.baddbmm(self.b1.unsqueeze(1), xe, self.W1t) # [E,N,F]
|
||||
h = F.relu(h)
|
||||
o = torch.einsum("enf,edf->end", h, self.W2) + self.b2.unsqueeze(1) # [E,N,D]
|
||||
o = torch.baddbmm(self.b2.unsqueeze(1), h, self.W2t) # [E,N,D]
|
||||
else:
|
||||
h = torch.einsum("nd,efd->enf", xf, self.W1) + self.b1.unsqueeze(1)
|
||||
h = F.relu(h)
|
||||
o = torch.einsum("enf,edf->end", h, self.W2) + self.b2.unsqueeze(1)
|
||||
|
||||
# 按每个 token 的 top-k 选取并加权(与逐 expert 循环数学等价)
|
||||
o = o.permute(1, 0, 2) # [N, E, D]
|
||||
@@ -710,6 +722,9 @@ class SMoE(nn.Module):
|
||||
sel = torch.gather(o, 1, idx.unsqueeze(-1).expand(-1, -1, D)) # [N, k, D]
|
||||
out = (sel * sc.unsqueeze(-1)).sum(dim=1).reshape(B, S, D)
|
||||
|
||||
if CONFIG.get("skip_moe_loss", True):
|
||||
moe_loss = out.new_zeros(()) # 推理无用,跳过 importance/std/mean
|
||||
else:
|
||||
importance = probs.sum(dim=(0, 1)) # [E]
|
||||
moe_loss = (importance.std() / (importance.mean() + 1e-6))
|
||||
return out, moe_loss
|
||||
|
||||
Reference in New Issue
Block a user