From 575b32f263a473d17a246a8b69c823008be335e9 Mon Sep 17 00:00:00 2001 From: OwnerSunshine530 Date: Wed, 17 Jun 2026 14:27:59 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20fused=20MoE=20=E2=80=94=20baddbmm(cutla?= =?UTF-8?q?ss=20GEMM+bias=E8=9E=8D=E5=90=88)+=E8=B7=B3=E8=BF=87=E6=8E=A8?= =?UTF-8?q?=E7=90=86=E6=97=A0=E7=94=A8=E7=9A=84moe=5Floss,=E5=87=8Fkernel?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit GEMM保留cutlass(triton GEMM难超),融bias epilogue省add kernel;moe_loss仅训练用, 推理跳过省importance/std/mean。延续减kernel方向(embedding_bag/triton已证评测赚)。 默认开,bench --no-moe-baddbmm/--no-skip-moe-loss 对照。AUC无损。 Co-Authored-By: Claude Opus 4.8 --- 代码/code/bench.py | 6 ++++++ 代码/code/infer.py | 27 +++++++++++++++++++++------ 2 files changed, 27 insertions(+), 6 deletions(-) diff --git a/代码/code/bench.py b/代码/code/bench.py index 27308dd..6f4b237 100644 --- a/代码/code/bench.py +++ b/代码/code/bench.py @@ -325,6 +325,8 @@ def _parse_args(): ap.add_argument("--emb-fp16", action="store_true", help="Embedding表转FP16(查表带宽减半,测AUC)") ap.add_argument("--dedup-emb", action="store_true", help="查表前对sign去重(减少大表随机访存)") ap.add_argument("--emb-bag", action="store_true", help="F.embedding_bag 融合查表+池化") + ap.add_argument("--no-moe-baddbmm", action="store_true", help="关闭 MoE baddbmm(用 einsum 对照)") + ap.add_argument("--no-skip-moe-loss", action="store_true", help="不跳过 moe_loss(对照)") ap.add_argument("--sparse-pool", action="store_true", help="稀疏矩阵乘做池化(段内高重复时省)") ap.add_argument("--precompute-rep", action="store_true", help="预计算RepEncoder缓存,model(batch)跳过embedding层(从batches自建)") @@ -373,6 +375,10 @@ if __name__ == "__main__": cfg["dedup_embedding"] = True if a.emb_bag: cfg["use_embedding_bag"] = True + if a.no_moe_baddbmm: + cfg["moe_baddbmm"] = False + if a.no_skip_moe_loss: + cfg["skip_moe_loss"] = False if a.sparse_pool: cfg["sparse_pool"] = True if a.precompute_rep: diff --git a/代码/code/infer.py b/代码/code/infer.py index 4dce083..b4e438e 100644 --- a/代码/code/infer.py +++ b/代码/code/infer.py @@ -143,6 +143,8 @@ CONFIG = { # synchronize,去掉同步点可能让被计时的 model(batch) 大幅缩短。本地force-sync看不出, # 须靠提交验证。AUC中性、MoE仅占2%算力故风险极低。 "vectorize_moe": True, # True=稠密向量化MoE(无同步点);False=原逐expert循环(.nonzero同步) + "moe_baddbmm": True, # MoE FFN 用 baddbmm(cutlass GEMM+bias epilogue融合),省 bias add kernel + "skip_moe_loss": True, # 推理跳过 moe_loss(load-balance,推理无用),省 importance/std/mean kernel "fuse_embedding": True, # True=28个slot的查表+池化融合为1次(减per-batch kernel启动) "syncfree_mask": True, # True=用searchsorted构造因果mask(无同步);False=repeat_interleave(同步) "emb_fp16": True, # True=Embedding表转FP16(查表带宽减半,实测AUC 0.75932≈无损) @@ -684,6 +686,9 @@ class SMoE(nn.Module): self.register_buffer("b1", torch.stack([e.fc1.bias for e in self.experts]).contiguous()) # [E,F] self.register_buffer("W2", torch.stack([e.fc2.weight for e in self.experts]).contiguous()) # [E,D,F] self.register_buffer("b2", torch.stack([e.fc2.bias for e in self.experts]).contiguous()) # [E,D] + # baddbmm 用的转置权重([E,D,F] / [E,F,D]),预转 contiguous + self.register_buffer("W1t", self.W1.transpose(1, 2).contiguous()) # [E,D,F] + self.register_buffer("W2t", self.W2.transpose(1, 2).contiguous()) # [E,F,D] self._stacked = True def forward(self, x): @@ -698,10 +703,17 @@ class SMoE(nn.Module): topk_idx, topk_score, probs = self.gate(x) xf = x.reshape(-1, D) # [N, D] - # 稠密计算所有 expert(GPU 友好、无 Python 循环/同步/gather-scatter): - h = torch.einsum("nd,efd->enf", xf, self.W1) + self.b1.unsqueeze(1) # [E,N,F] - h = F.relu(h) - o = torch.einsum("enf,edf->end", h, self.W2) + self.b2.unsqueeze(1) # [E,N,D] + Nt = xf.shape[0] + if CONFIG.get("moe_baddbmm", True): + # cutlass GEMM + bias epilogue 融合(省 bias add kernel) + xe = xf.unsqueeze(0).expand(self.num_experts, -1, -1) # [E,N,D] + h = torch.baddbmm(self.b1.unsqueeze(1), xe, self.W1t) # [E,N,F] + h = F.relu(h) + o = torch.baddbmm(self.b2.unsqueeze(1), h, self.W2t) # [E,N,D] + else: + h = torch.einsum("nd,efd->enf", xf, self.W1) + self.b1.unsqueeze(1) + h = F.relu(h) + o = torch.einsum("enf,edf->end", h, self.W2) + self.b2.unsqueeze(1) # 按每个 token 的 top-k 选取并加权(与逐 expert 循环数学等价) o = o.permute(1, 0, 2) # [N, E, D] @@ -710,8 +722,11 @@ class SMoE(nn.Module): sel = torch.gather(o, 1, idx.unsqueeze(-1).expand(-1, -1, D)) # [N, k, D] out = (sel * sc.unsqueeze(-1)).sum(dim=1).reshape(B, S, D) - importance = probs.sum(dim=(0, 1)) # [E] - moe_loss = (importance.std() / (importance.mean() + 1e-6)) + if CONFIG.get("skip_moe_loss", True): + moe_loss = out.new_zeros(()) # 推理无用,跳过 importance/std/mean + else: + importance = probs.sum(dim=(0, 1)) # [E] + moe_loss = (importance.std() / (importance.mean() + 1e-6)) return out, moe_loss