perf: SMoE 消除 GPU 同步 + CTRModel 去冗余 reshape
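1. SMoE: 移除 `if not mask.any()` 检查(每次 forward 触发 64 次 GPU→CPU 同步)
   - k=2 时每个 expert 都分到 token,该检查从不跳过
   - 改用 `token_idx.numel() == 0` 判断(只读取张量元数据,本身不触发同步;注意其前面的 `mask.nonzero()` 在 CUDA 上仍会同步一次,因此每个 expert 的同步从 2 次减为 1 次,而非完全消除)
2. SMoE: `out_flat = out.reshape(-1, D)` 提到循环外(省 7 次重复 reshape)
3. CTRModel: `encoder_output.reshape(1, -1, encoder_output_dim).squeeze(0)` → `encoder_output.squeeze(0)`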
This commit is contained in:
+5
-10
@@ -343,24 +343,20 @@ class SMoE(nn.Module):
|
|||||||
x_flat = x.reshape(-1, D) # [B*S, D]
|
x_flat = x.reshape(-1, D) # [B*S, D]
|
||||||
idx_flat = topk_idx.reshape(-1, self.k) # [B*S, k]
|
idx_flat = topk_idx.reshape(-1, self.k) # [B*S, k]
|
||||||
score_flat = topk_score.reshape(-1, self.k)
|
score_flat = topk_score.reshape(-1, self.k)
|
||||||
|
out_flat = out.reshape(-1, D) # 提前 reshape,避免循环内重复
|
||||||
|
|
||||||
for i in range(self.num_experts):
|
for i in range(self.num_experts):
|
||||||
# 找到被路由到 expert i 的 token
|
# 找到被路由到 expert i 的 token
|
||||||
mask = (idx_flat == i) # [B*S, k]
|
mask = (idx_flat == i) # [B*S, k]
|
||||||
|
# 注:k=2 时几乎所有 expert 都分到 token,移除 .any() 检查避免 GPU 同步
|
||||||
|
|
||||||
if not mask.any():
|
token_idx, k_idx = mask.nonzero(as_tuple=True)
|
||||||
|
if token_idx.numel() == 0:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# 哪些 token 命中了 expert i
|
|
||||||
token_idx, k_idx = mask.nonzero(as_tuple=True)
|
|
||||||
|
|
||||||
selected_x = x_flat[token_idx] # [N, D]
|
selected_x = x_flat[token_idx] # [N, D]
|
||||||
|
|
||||||
expert_out = self.experts[i](selected_x) # [N, D]
|
expert_out = self.experts[i](selected_x) # [N, D]
|
||||||
|
|
||||||
weight = score_flat[token_idx, k_idx].unsqueeze(-1)
|
weight = score_flat[token_idx, k_idx].unsqueeze(-1)
|
||||||
|
|
||||||
out_flat = out.reshape(-1, D)
|
|
||||||
out_flat[token_idx] += expert_out * weight
|
out_flat[token_idx] += expert_out * weight
|
||||||
|
|
||||||
importance = probs.sum(dim=(0,1)) # [E]
|
importance = probs.sum(dim=(0,1)) # [E]
|
||||||
@@ -443,8 +439,7 @@ class CTRModel(nn.Module):
|
|||||||
x=seq_input,
|
x=seq_input,
|
||||||
extension={"mask": seq_mask.unsqueeze(0).unsqueeze(0)},
|
extension={"mask": seq_mask.unsqueeze(0).unsqueeze(0)},
|
||||||
)
|
)
|
||||||
encoder_output_dim = encoder_output.shape[-1]
|
encoder_output = encoder_output.squeeze(0)
|
||||||
encoder_output = encoder_output.reshape(1, -1, encoder_output_dim).squeeze(0)
|
|
||||||
pred = self.linear(encoder_output)
|
pred = self.linear(encoder_output)
|
||||||
pred_logits = torch.clamp(pred, min=-15.0, max=15.0)
|
pred_logits = torch.clamp(pred, min=-15.0, max=15.0)
|
||||||
return pred_logits, moe_loss
|
return pred_logits, moe_loss
|
||||||
|
|||||||
Reference in New Issue
Block a user