perf: SMoE 消除 GPU 同步 + CTRModel 去冗余 reshape

1. SMoE: 移除 if not mask.any()(64次GPU→CPU同步/forward)
   - k=2时每个expert都分到token,检查从不跳过
   - 改用 token_idx.numel()==0 判断(元数据操作,不同步)
2. SMoE: out_flat reshape 提到循环外(省7次重复)
3. CTRModel: encoder_output.reshape().squeeze() → .squeeze()
This commit is contained in:
2026-06-13 13:16:01 +08:00
parent 7e0876c671
commit da37245a9b
+5 -10
View File
@@ -343,24 +343,20 @@ class SMoE(nn.Module):
x_flat = x.reshape(-1, D) # [B*S, D] x_flat = x.reshape(-1, D) # [B*S, D]
idx_flat = topk_idx.reshape(-1, self.k) # [B*S, k] idx_flat = topk_idx.reshape(-1, self.k) # [B*S, k]
score_flat = topk_score.reshape(-1, self.k) score_flat = topk_score.reshape(-1, self.k)
out_flat = out.reshape(-1, D) # 提前 reshape,避免循环内重复
for i in range(self.num_experts): for i in range(self.num_experts):
# 找到被路由到 expert i 的 token # 找到被路由到 expert i 的 token
mask = (idx_flat == i) # [B*S, k] mask = (idx_flat == i) # [B*S, k]
# 注:k=2 时几乎所有 expert 都分到 token,移除 .any() 检查避免 GPU 同步
if not mask.any(): token_idx, k_idx = mask.nonzero(as_tuple=True)
if token_idx.numel() == 0:
continue continue
# 哪些 token 命中了 expert i
token_idx, k_idx = mask.nonzero(as_tuple=True)
selected_x = x_flat[token_idx] # [N, D] selected_x = x_flat[token_idx] # [N, D]
expert_out = self.experts[i](selected_x) # [N, D] expert_out = self.experts[i](selected_x) # [N, D]
weight = score_flat[token_idx, k_idx].unsqueeze(-1) weight = score_flat[token_idx, k_idx].unsqueeze(-1)
out_flat = out.reshape(-1, D)
out_flat[token_idx] += expert_out * weight out_flat[token_idx] += expert_out * weight
importance = probs.sum(dim=(0,1)) # [E] importance = probs.sum(dim=(0,1)) # [E]
@@ -443,8 +439,7 @@ class CTRModel(nn.Module):
x=seq_input, x=seq_input,
extension={"mask": seq_mask.unsqueeze(0).unsqueeze(0)}, extension={"mask": seq_mask.unsqueeze(0).unsqueeze(0)},
) )
encoder_output_dim = encoder_output.shape[-1] encoder_output = encoder_output.squeeze(0)
encoder_output = encoder_output.reshape(1, -1, encoder_output_dim).squeeze(0)
pred = self.linear(encoder_output) pred = self.linear(encoder_output)
pred_logits = torch.clamp(pred, min=-15.0, max=15.0) pred_logits = torch.clamp(pred, min=-15.0, max=15.0)
return pred_logits, moe_loss return pred_logits, moe_loss