From da37245a9bfa86fa69ca45bff810fe065f0f3391 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=88=98=E8=88=AA=E5=AE=87?= <3364451258@qq.com> Date: Sat, 13 Jun 2026 13:16:01 +0800 Subject: [PATCH] =?UTF-8?q?perf:=20SMoE=20=E6=B6=88=E9=99=A4=20GPU=20?= =?UTF-8?q?=E5=90=8C=E6=AD=A5=20+=20CTRModel=20=E5=8E=BB=E5=86=97=E4=BD=99?= =?UTF-8?q?=20reshape?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 1. SMoE: 移除 if not mask.any()(64次GPU→CPU同步/forward) - k=2时每个expert都分到token,检查从不跳过 - 改用 token_idx.numel()==0 判断(元数据操作,不同步) 2. SMoE: out_flat reshape 提到循环外(省7次重复) 3. CTRModel: encoder_output.reshape().squeeze() → .squeeze() --- 代码/code/infer.py | 15 +++++---------- 1 file changed, 5 insertions(+), 10 deletions(-) diff --git a/代码/code/infer.py b/代码/code/infer.py index d2109c5..bdf89ae 100644 --- a/代码/code/infer.py +++ b/代码/code/infer.py @@ -343,24 +343,20 @@ class SMoE(nn.Module): x_flat = x.reshape(-1, D) # [B*S, D] idx_flat = topk_idx.reshape(-1, self.k) # [B*S, k] score_flat = topk_score.reshape(-1, self.k) + out_flat = out.reshape(-1, D) # 提前 reshape,避免循环内重复 for i in range(self.num_experts): # 找到被路由到 expert i 的 token mask = (idx_flat == i) # [B*S, k] + # 注:k=2 时几乎所有 expert 都分到 token,移除 .any() 检查避免 GPU 同步 - if not mask.any(): + token_idx, k_idx = mask.nonzero(as_tuple=True) + if token_idx.numel() == 0: continue - # 哪些 token 命中了 expert i - token_idx, k_idx = mask.nonzero(as_tuple=True) - selected_x = x_flat[token_idx] # [N, D] - expert_out = self.experts[i](selected_x) # [N, D] - weight = score_flat[token_idx, k_idx].unsqueeze(-1) - - out_flat = out.reshape(-1, D) out_flat[token_idx] += expert_out * weight importance = probs.sum(dim=(0,1)) # [E] @@ -443,8 +439,7 @@ class CTRModel(nn.Module): x=seq_input, extension={"mask": seq_mask.unsqueeze(0).unsqueeze(0)}, ) - encoder_output_dim = encoder_output.shape[-1] - encoder_output = encoder_output.reshape(1, -1, encoder_output_dim).squeeze(0) + encoder_output = encoder_output.squeeze(0) pred = self.linear(encoder_output) pred_logits = torch.clamp(pred, min=-15.0, max=15.0) return pred_logits, moe_loss