feat: MoE Top-1 路由 + (p1+p2) 权重补偿

- 每个 token 仅路由到 Top-1 expert(相比原先的 k=2 路由,expert FFN 计算量约减少 50%)
- gate 仍输出 top-2 概率(p1、p2),并以 p1+p2 作为 Top-1 expert 输出的权重
- 该权重补偿使输出幅度近似 k=2 时的水平,避免 PCOC 偏移
- 本次改动属于参数调整层面的修正,并非原方案本身存在错误
This commit is contained in:
2026-06-13 13:32:04 +08:00
parent b991f9e78e
commit c081620ffd
+17 -18
View File
@@ -320,7 +320,7 @@ class TopKGate(nn.Module):
return topk_idx, topk_score, probs return topk_idx, topk_score, probs
class SMoE(nn.Module): class SMoE(nn.Module):
def __init__(self, d_model, dim_ff, num_experts, k=2): def __init__(self, d_model, dim_ff, num_experts, k=1):
super().__init__() super().__init__()
self.num_experts = num_experts self.num_experts = num_experts
self.k = k self.k = k
@@ -329,35 +329,34 @@ class SMoE(nn.Module):
Expert(d_model, dim_ff) for _ in range(num_experts) Expert(d_model, dim_ff) for _ in range(num_experts)
]) ])
self.gate = TopKGate(d_model, num_experts, k=k) self.gate = TopKGate(d_model, num_experts, k=2) # gate 内部用 k=2 获取补偿权重
def forward(self, x): def forward(self, x):
# x: [B,S,D] # x: [B,S,D]
B, S, D = x.shape B, S, D = x.shape
topk_idx, topk_score, probs = self.gate(x) topk_idx, topk_score, probs = self.gate(x)
# topk_idx: [B, S, 2], topk_score: [B, S, 2]
# 仅路由到 Top-1 expert,但用 (p1+p2) 作为权重补偿
route_idx = topk_idx[:, :, :1] # [B, S, 1] — 只取 top-1
weight_sum = topk_score.sum(dim=-1) # [B, S] — p1 + p2 作为总权重
out = torch.zeros_like(x) out = torch.zeros_like(x)
# flatten x_flat = x.reshape(-1, D)
x_flat = x.reshape(-1, D) # [B*S, D] idx_flat = route_idx.reshape(-1) # [B*S]
idx_flat = topk_idx.reshape(-1, self.k) # [B*S, k] weight_flat = weight_sum.reshape(-1) # [B*S]
score_flat = topk_score.reshape(-1, self.k) out_flat = out.reshape(-1, D)
out_flat = out.reshape(-1, D) # 提前 reshape,避免循环内重复
for i in range(self.num_experts): for i in range(self.num_experts):
# 找到被路由到 expert i 的 token mask = (idx_flat == i) # [B*S]
mask = (idx_flat == i) # [B*S, k] token_idx = mask.nonzero(as_tuple=True)[0]
# 注:k=2 时几乎所有 expert 都分到 token,移除 .any() 检查避免 GPU 同步
token_idx, k_idx = mask.nonzero(as_tuple=True)
if token_idx.numel() == 0: if token_idx.numel() == 0:
continue continue
selected_x = x_flat[token_idx]
selected_x = x_flat[token_idx] # [N, D] expert_out = self.experts[i](selected_x)
expert_out = self.experts[i](selected_x) # [N, D] out_flat[token_idx] = expert_out * weight_flat[token_idx].unsqueeze(-1)
weight = score_flat[token_idx, k_idx].unsqueeze(-1)
out_flat[token_idx] += expert_out * weight
importance = probs.sum(dim=(0,1)) # [E] importance = probs.sum(dim=(0,1)) # [E]
moe_loss = (importance.std() / (importance.mean() + 1e-6)) moe_loss = (importance.std() / (importance.mean() + 1e-6))
@@ -384,7 +383,7 @@ class TransformerEncoder(nn.Module):
self.act = getattr(F, act) self.act = getattr(F, act)
self.attention_fn = attention_fn self.attention_fn = attention_fn
self.moe = nn.ModuleList([ self.moe = nn.ModuleList([
SMoE(d_model, dim_ff, num_experts=8, k=2) SMoE(d_model, dim_ff, num_experts=8, k=1) # Top-1 路由 + (p1+p2) 权重补偿
for _ in range(num_layers) for _ in range(num_layers)
]) ])