revert: MoE k=1 → k=2(PCOC 从 1.059 炸到 2.075,Top-1 破坏输出校准)
保留 inference_mode + torch.compile(default)
This commit is contained in:
+23
-31
@@ -320,7 +320,7 @@ class TopKGate(nn.Module):
|
|||||||
return topk_idx, topk_score, probs
|
return topk_idx, topk_score, probs
|
||||||
|
|
||||||
class SMoE(nn.Module):
|
class SMoE(nn.Module):
|
||||||
def __init__(self, d_model, dim_ff, num_experts, k=1):
|
def __init__(self, d_model, dim_ff, num_experts, k=2):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.num_experts = num_experts
|
self.num_experts = num_experts
|
||||||
self.k = k
|
self.k = k
|
||||||
@@ -337,39 +337,31 @@ class SMoE(nn.Module):
|
|||||||
|
|
||||||
topk_idx, topk_score, probs = self.gate(x)
|
topk_idx, topk_score, probs = self.gate(x)
|
||||||
|
|
||||||
# flatten: [B, S, k] → [B*S, k]
|
out = torch.zeros_like(x)
|
||||||
x_flat = x.reshape(-1, D)
|
|
||||||
idx_flat = topk_idx.reshape(-1, self.k)
|
# flatten
|
||||||
|
x_flat = x.reshape(-1, D) # [B*S, D]
|
||||||
|
idx_flat = topk_idx.reshape(-1, self.k) # [B*S, k]
|
||||||
score_flat = topk_score.reshape(-1, self.k)
|
score_flat = topk_score.reshape(-1, self.k)
|
||||||
|
|
||||||
if self.k == 1:
|
for i in range(self.num_experts):
|
||||||
# Top-1 快速路径:无需二维 mask 和加权累加
|
# 找到被路由到 expert i 的 token
|
||||||
idx_flat = idx_flat.squeeze(-1) # [B*S]
|
mask = (idx_flat == i) # [B*S, k]
|
||||||
score_flat = score_flat.squeeze(-1) # [B*S]
|
|
||||||
out = torch.zeros_like(x_flat)
|
|
||||||
|
|
||||||
for i in range(self.num_experts):
|
if not mask.any():
|
||||||
mask = (idx_flat == i) # [B*S]
|
continue
|
||||||
if not mask.any():
|
|
||||||
continue
|
|
||||||
selected_x = x_flat[mask]
|
|
||||||
expert_out = self.experts[i](selected_x)
|
|
||||||
out[mask] = expert_out * score_flat[mask].unsqueeze(-1)
|
|
||||||
|
|
||||||
out = out.reshape(B, S, D)
|
# 哪些 token 命中了 expert i
|
||||||
else:
|
token_idx, k_idx = mask.nonzero(as_tuple=True)
|
||||||
# Top-K 通用路径(k > 1)
|
|
||||||
out = torch.zeros_like(x)
|
selected_x = x_flat[token_idx] # [N, D]
|
||||||
for i in range(self.num_experts):
|
|
||||||
mask = (idx_flat == i) # [B*S, k]
|
expert_out = self.experts[i](selected_x) # [N, D]
|
||||||
if not mask.any():
|
|
||||||
continue
|
weight = score_flat[token_idx, k_idx].unsqueeze(-1)
|
||||||
token_idx, k_idx = mask.nonzero(as_tuple=True)
|
|
||||||
selected_x = x_flat[token_idx]
|
out_flat = out.reshape(-1, D)
|
||||||
expert_out = self.experts[i](selected_x)
|
out_flat[token_idx] += expert_out * weight
|
||||||
weight = score_flat[token_idx, k_idx].unsqueeze(-1)
|
|
||||||
out_flat = out.reshape(-1, D)
|
|
||||||
out_flat[token_idx] += expert_out * weight
|
|
||||||
|
|
||||||
importance = probs.sum(dim=(0,1)) # [E]
|
importance = probs.sum(dim=(0,1)) # [E]
|
||||||
moe_loss = (importance.std() / (importance.mean() + 1e-6))
|
moe_loss = (importance.std() / (importance.mean() + 1e-6))
|
||||||
@@ -396,7 +388,7 @@ class TransformerEncoder(nn.Module):
|
|||||||
self.act = getattr(F, act)
|
self.act = getattr(F, act)
|
||||||
self.attention_fn = attention_fn
|
self.attention_fn = attention_fn
|
||||||
self.moe = nn.ModuleList([
|
self.moe = nn.ModuleList([
|
||||||
SMoE(d_model, dim_ff, num_experts=8, k=1) # Top-1 gating: 每个 token 仅激活 1 个 expert
|
SMoE(d_model, dim_ff, num_experts=8, k=2)
|
||||||
for _ in range(num_layers)
|
for _ in range(num_layers)
|
||||||
])
|
])
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user