perf: SMoE 消除 GPU 同步 + CTRModel 去冗余 reshape
1. SMoE: 移除 if not mask.any()(64次GPU→CPU同步/forward) - k=2时每个expert都分到token,检查从不跳过 - 改用 token_idx.numel()==0 判断(元数据操作,不同步) 2. SMoE: out_flat reshape 提到循环外(省7次重复) 3. CTRModel: encoder_output.reshape().squeeze() → .squeeze()
This commit is contained in:
+5
-10
@@ -343,24 +343,20 @@ class SMoE(nn.Module):
|
||||
x_flat = x.reshape(-1, D) # [B*S, D]
|
||||
idx_flat = topk_idx.reshape(-1, self.k) # [B*S, k]
|
||||
score_flat = topk_score.reshape(-1, self.k)
|
||||
out_flat = out.reshape(-1, D) # 提前 reshape,避免循环内重复
|
||||
|
||||
for i in range(self.num_experts):
|
||||
# 找到被路由到 expert i 的 token
|
||||
mask = (idx_flat == i) # [B*S, k]
|
||||
# 注:k=2 时几乎所有 expert 都分到 token,移除 .any() 检查避免 GPU 同步
|
||||
|
||||
if not mask.any():
|
||||
token_idx, k_idx = mask.nonzero(as_tuple=True)
|
||||
if token_idx.numel() == 0:
|
||||
continue
|
||||
|
||||
# 哪些 token 命中了 expert i
|
||||
token_idx, k_idx = mask.nonzero(as_tuple=True)
|
||||
|
||||
selected_x = x_flat[token_idx] # [N, D]
|
||||
|
||||
expert_out = self.experts[i](selected_x) # [N, D]
|
||||
|
||||
weight = score_flat[token_idx, k_idx].unsqueeze(-1)
|
||||
|
||||
out_flat = out.reshape(-1, D)
|
||||
out_flat[token_idx] += expert_out * weight
|
||||
|
||||
importance = probs.sum(dim=(0,1)) # [E]
|
||||
@@ -443,8 +439,7 @@ class CTRModel(nn.Module):
|
||||
x=seq_input,
|
||||
extension={"mask": seq_mask.unsqueeze(0).unsqueeze(0)},
|
||||
)
|
||||
encoder_output_dim = encoder_output.shape[-1]
|
||||
encoder_output = encoder_output.reshape(1, -1, encoder_output_dim).squeeze(0)
|
||||
encoder_output = encoder_output.squeeze(0)
|
||||
pred = self.linear(encoder_output)
|
||||
pred_logits = torch.clamp(pred, min=-15.0, max=15.0)
|
||||
return pred_logits, moe_loss
|
||||
|
||||
Reference in New Issue
Block a user