From 9f73505caa766b3bc94f8a4a83f90160579f4671 Mon Sep 17 00:00:00 2001 From: OwnerSunshine530 Date: Fri, 19 Jun 2026 20:22:16 +0800 Subject: [PATCH] =?UTF-8?q?perf:=20MoE=20top-k=E5=8A=A0=E6=9D=83=E6=94=B9s?= =?UTF-8?q?catter+mul+sum(=E5=9C=A8[E,N,D]=E4=B8=8A),=E7=9C=81permute?= =?UTF-8?q?=E5=A4=A7clone+gather(profile=20clone=208%)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 数学等价(top-k索引互异,scatter无冲突),零AUC风险。延续'减kernel'方向。 moe_fused_weight默认开,test_moe_dense_matches_loop已覆盖。 Co-Authored-By: Claude Opus 4.8 --- 代码/code/infer.py | 19 ++++++++++++++----- 1 file changed, 14 insertions(+), 5 deletions(-) diff --git a/代码/code/infer.py b/代码/code/infer.py index 9915f44..6077051 100644 --- a/代码/code/infer.py +++ b/代码/code/infer.py @@ -144,6 +144,7 @@ CONFIG = { # 须靠提交验证。AUC中性、MoE仅占2%算力故风险极低。 "vectorize_moe": True, # True=稠密向量化MoE(无同步点);False=原逐expert循环(.nonzero同步) "moe_baddbmm": True, # MoE FFN 用 baddbmm(cutlass GEMM+bias epilogue融合),省 bias add kernel + "moe_fused_weight": True, # top-k加权用scatter+mul+sum(在[E,N,D]上),省permute大clone+gather;数学等价 # 真稀疏MoE实测评测净负:lat 34.64->37.64s(本地快15%但argsort/scatter开销评测放大,如varlen) # +容量丢弃降AUC(0.7525->0.7507)。已退回 dense。 "moe_sparse": False, # True=真稀疏MoE(评测净负,勿开) @@ -758,11 +759,19 @@ class SMoE(nn.Module): o = torch.einsum("enf,edf->end", h, self.W2) + self.b2.unsqueeze(1) # 按每个 token 的 top-k 选取并加权(与逐 expert 循环数学等价) - o = o.permute(1, 0, 2) # [N, E, D] - idx = topk_idx.reshape(-1, self.k) # [N, k] - sc = topk_score.reshape(-1, self.k) # [N, k] - sel = torch.gather(o, 1, idx.unsqueeze(-1).expand(-1, -1, D)) # [N, k, D] - out = (sel * sc.unsqueeze(-1)).sum(dim=1).reshape(B, S, D) + if CONFIG.get("moe_fused_weight", True): + # 稀疏权重 [N,E],直接在 [E,N,D] 上加权求和(省掉 permute 的大 clone + gather) + idx = topk_idx.reshape(-1, self.k) # [N, k] + sc = topk_score.reshape(-1, self.k).to(o.dtype) # [N, k] + wfull = torch.zeros(Nt, self.num_experts, dtype=o.dtype, device=o.device) + wfull.scatter_(1, idx, sc) # [N,E] top-k 处=分数(索引互异,无冲突) + out = (o * wfull.t().unsqueeze(-1)).sum(0).reshape(B, S, D) # [E,N,D]*[E,N,1]->[N,D] + else: + o = o.permute(1, 0, 2) # [N, E, D] + idx = topk_idx.reshape(-1, self.k) # [N, k] + sc = topk_score.reshape(-1, self.k) # [N, k] + sel = torch.gather(o, 1, idx.unsqueeze(-1).expand(-1, -1, D)) # [N, k, D] + out = (sel * sc.unsqueeze(-1)).sum(dim=1).reshape(B, S, D) if CONFIG.get("skip_moe_loss", True): moe_loss = out.new_zeros(()) # 推理无用,跳过 importance/std/mean