feat: MoE Top-1 gating(每个 token 仅激活 1 个 expert,expert FFN 计算量相比 Top-2 减半)

- SMoE 默认 k=2 → k=1(属于稀疏优化,规则允许)
- TransformerEncoder 8 层全部改用 Top-1 gating
- forward 针对 k=1 走快速路径(避免二维 mask 和加权累加)
This commit is contained in:
2026-06-12 22:04:34 +08:00
parent bc6e8307c5
commit feb71be5bd
+25 -17
View File
@@ -320,7 +320,7 @@ class TopKGate(nn.Module):
return topk_idx, topk_score, probs
class SMoE(nn.Module):
def __init__(self, d_model, dim_ff, num_experts, k=2):
def __init__(self, d_model, dim_ff, num_experts, k=1):
super().__init__()
self.num_experts = num_experts
self.k = k
@@ -337,29 +337,37 @@ class SMoE(nn.Module):
topk_idx, topk_score, probs = self.gate(x)
out = torch.zeros_like(x)
# flatten
x_flat = x.reshape(-1, D) # [B*S, D]
idx_flat = topk_idx.reshape(-1, self.k) # [B*S, k]
# flatten: x [B, S, D] → [B*S, D];topk_idx / topk_score [B, S, k] → [B*S, k]
x_flat = x.reshape(-1, D)
idx_flat = topk_idx.reshape(-1, self.k)
score_flat = topk_score.reshape(-1, self.k)
for i in range(self.num_experts):
# 找到被路由到 expert i 的 token
mask = (idx_flat == i) # [B*S, k]
if self.k == 1:
# Top-1 快速路径:无需二维 mask 和加权累加
idx_flat = idx_flat.squeeze(-1) # [B*S]
score_flat = score_flat.squeeze(-1) # [B*S]
out = torch.zeros_like(x_flat)
for i in range(self.num_experts):
mask = (idx_flat == i) # [B*S]
if not mask.any():
continue
selected_x = x_flat[mask]
expert_out = self.experts[i](selected_x)
out[mask] = expert_out * score_flat[mask].unsqueeze(-1)
# 哪些 token 命中了 expert i
out = out.reshape(B, S, D)
else:
# Top-K 通用路径(k > 1)
out = torch.zeros_like(x)
for i in range(self.num_experts):
mask = (idx_flat == i) # [B*S, k]
if not mask.any():
continue
token_idx, k_idx = mask.nonzero(as_tuple=True)
selected_x = x_flat[token_idx] # [N, D]
expert_out = self.experts[i](selected_x) # [N, D]
selected_x = x_flat[token_idx]
expert_out = self.experts[i](selected_x)
weight = score_flat[token_idx, k_idx].unsqueeze(-1)
out_flat = out.reshape(-1, D)
out_flat[token_idx] += expert_out * weight
@@ -388,7 +396,7 @@ class TransformerEncoder(nn.Module):
self.act = getattr(F, act)
self.attention_fn = attention_fn
self.moe = nn.ModuleList([
SMoE(d_model, dim_ff, num_experts=8, k=2)
SMoE(d_model, dim_ff, num_experts=8, k=1) # Top-1 gating: 每个 token 仅激活 1 个 expert
for _ in range(num_layers)
])