From feb71be5bd4677083774c30f0271c56fcb39ce0c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=88=98=E8=88=AA=E5=AE=87?= <3364451258@qq.com> Date: Fri, 12 Jun 2026 22:04:34 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20MoE=20Top-1=20gating=EF=BC=88=E6=AF=8F?= =?UTF-8?q?=E4=B8=AA=20token=20=E4=BB=85=E6=BF=80=E6=B4=BB=201=20=E4=B8=AA?= =?UTF-8?q?=20expert=EF=BC=8CFFN=20=E8=AE=A1=E7=AE=97=E5=87=8F=E5=8D=8A?= =?UTF-8?q?=EF=BC=89?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - SMoE 默认 k=2 → k=1(属于稀疏优化,规则允许) - TransformerEncoder 8 层全部改用 Top-1 gating - forward 针对 k=1 走快速路径(避免二维 mask 和加权累加) --- 代码/code/infer.py | 54 ++++++++++++++++++++++++++-------------------- 1 file changed, 31 insertions(+), 23 deletions(-) diff --git a/代码/code/infer.py b/代码/code/infer.py index 09e3a44..29ff02a 100644 --- a/代码/code/infer.py +++ b/代码/code/infer.py @@ -320,7 +320,7 @@ class TopKGate(nn.Module): return topk_idx, topk_score, probs class SMoE(nn.Module): - def __init__(self, d_model, dim_ff, num_experts, k=2): + def __init__(self, d_model, dim_ff, num_experts, k=1): super().__init__() self.num_experts = num_experts self.k = k @@ -337,31 +337,39 @@ class SMoE(nn.Module): topk_idx, topk_score, probs = self.gate(x) - out = torch.zeros_like(x) - - # flatten - x_flat = x.reshape(-1, D) # [B*S, D] - idx_flat = topk_idx.reshape(-1, self.k) # [B*S, k] + # flatten: [B, S, k] → [B*S, k] + x_flat = x.reshape(-1, D) + idx_flat = topk_idx.reshape(-1, self.k) score_flat = topk_score.reshape(-1, self.k) - for i in range(self.num_experts): - # 找到被路由到 expert i 的 token - mask = (idx_flat == i) # [B*S, k] + if self.k == 1: + # Top-1 快速路径:无需二维 mask 和加权累加 + idx_flat = idx_flat.squeeze(-1) # [B*S] + score_flat = score_flat.squeeze(-1) # [B*S] + out = torch.zeros_like(x_flat) - if not mask.any(): - continue + for i in range(self.num_experts): + mask = (idx_flat == i) # [B*S] + if not mask.any(): + continue + selected_x = x_flat[mask] + expert_out = self.experts[i](selected_x) + out[mask] = expert_out * score_flat[mask].unsqueeze(-1) - # 哪些 token 命中了 expert i - token_idx, k_idx = mask.nonzero(as_tuple=True) - - selected_x = x_flat[token_idx] # [N, D] - - expert_out = self.experts[i](selected_x) # [N, D] - - weight = score_flat[token_idx, k_idx].unsqueeze(-1) - - out_flat = out.reshape(-1, D) - out_flat[token_idx] += expert_out * weight + out = out.reshape(B, S, D) + else: + # Top-K 通用路径(k > 1) + out = torch.zeros_like(x) + for i in range(self.num_experts): + mask = (idx_flat == i) # [B*S, k] + if not mask.any(): + continue + token_idx, k_idx = mask.nonzero(as_tuple=True) + selected_x = x_flat[token_idx] + expert_out = self.experts[i](selected_x) + weight = score_flat[token_idx, k_idx].unsqueeze(-1) + out_flat = out.reshape(-1, D) + out_flat[token_idx] += expert_out * weight importance = probs.sum(dim=(0,1)) # [E] moe_loss = (importance.std() / (importance.mean() + 1e-6)) @@ -388,7 +396,7 @@ class TransformerEncoder(nn.Module): self.act = getattr(F, act) self.attention_fn = attention_fn self.moe = nn.ModuleList([ - SMoE(d_model, dim_ff, num_experts=8, k=2) + SMoE(d_model, dim_ff, num_experts=8, k=1) # Top-1 gating: 每个 token 仅激活 1 个 expert for _ in range(num_layers) ])