diff --git a/代码/code/infer.py b/代码/code/infer.py index 9915f44..6077051 100644 --- a/代码/code/infer.py +++ b/代码/code/infer.py @@ -144,6 +144,7 @@ CONFIG = { # 须靠提交验证。AUC中性、MoE仅占2%算力故风险极低。 "vectorize_moe": True, # True=稠密向量化MoE(无同步点);False=原逐expert循环(.nonzero同步) "moe_baddbmm": True, # MoE FFN 用 baddbmm(cutlass GEMM+bias epilogue融合),省 bias add kernel + "moe_fused_weight": True, # top-k加权用scatter+mul+sum(在[E,N,D]上),省permute大clone+gather;数学等价 # 真稀疏MoE实测评测净负:lat 34.64->37.64s(本地快15%但argsort/scatter开销评测放大,如varlen) # +容量丢弃降AUC(0.7525->0.7507)。已退回 dense。 "moe_sparse": False, # True=真稀疏MoE(评测净负,勿开) @@ -758,11 +759,19 @@ class SMoE(nn.Module): o = torch.einsum("enf,edf->end", h, self.W2) + self.b2.unsqueeze(1) # 按每个 token 的 top-k 选取并加权(与逐 expert 循环数学等价) - o = o.permute(1, 0, 2) # [N, E, D] - idx = topk_idx.reshape(-1, self.k) # [N, k] - sc = topk_score.reshape(-1, self.k) # [N, k] - sel = torch.gather(o, 1, idx.unsqueeze(-1).expand(-1, -1, D)) # [N, k, D] - out = (sel * sc.unsqueeze(-1)).sum(dim=1).reshape(B, S, D) + if CONFIG.get("moe_fused_weight", True): + # 稀疏权重 [N,E],直接在 [E,N,D] 上加权求和(省掉 permute 的大 clone + gather) + idx = topk_idx.reshape(-1, self.k) # [N, k] + sc = topk_score.reshape(-1, self.k).to(o.dtype) # [N, k] + wfull = torch.zeros(Nt, self.num_experts, dtype=o.dtype, device=o.device) + wfull.scatter_(1, idx, sc) # [N,E] top-k 处=分数(索引互异,无冲突) + out = (o * wfull.t().unsqueeze(-1)).sum(0).reshape(B, S, D) # [E,N,D]*[E,N,1]->[N,D] + else: + o = o.permute(1, 0, 2) # [N, E, D] + idx = topk_idx.reshape(-1, self.k) # [N, k] + sc = topk_score.reshape(-1, self.k) # [N, k] + sel = torch.gather(o, 1, idx.unsqueeze(-1).expand(-1, -1, D)) # [N, k, D] + out = (sel * sc.unsqueeze(-1)).sum(dim=1).reshape(B, S, D) if CONFIG.get("skip_moe_loss", True): moe_loss = out.new_zeros(()) # 推理无用,跳过 importance/std/mean