diff --git a/代码/code/infer.py b/代码/code/infer.py index d2109c5..bdf89ae 100644 --- a/代码/code/infer.py +++ b/代码/code/infer.py @@ -343,24 +343,20 @@ class SMoE(nn.Module): x_flat = x.reshape(-1, D) # [B*S, D] idx_flat = topk_idx.reshape(-1, self.k) # [B*S, k] score_flat = topk_score.reshape(-1, self.k) + out_flat = out.reshape(-1, D) # 提前 reshape,避免循环内重复 for i in range(self.num_experts): # 找到被路由到 expert i 的 token mask = (idx_flat == i) # [B*S, k] + # 注:k=2 时几乎所有 expert 都分到 token,移除 .any() 检查避免 GPU 同步 - if not mask.any(): + token_idx, k_idx = mask.nonzero(as_tuple=True) + if token_idx.numel() == 0: continue - # 哪些 token 命中了 expert i - token_idx, k_idx = mask.nonzero(as_tuple=True) - selected_x = x_flat[token_idx] # [N, D] - expert_out = self.experts[i](selected_x) # [N, D] - weight = score_flat[token_idx, k_idx].unsqueeze(-1) - - out_flat = out.reshape(-1, D) out_flat[token_idx] += expert_out * weight importance = probs.sum(dim=(0,1)) # [E] @@ -443,8 +439,7 @@ class CTRModel(nn.Module): x=seq_input, extension={"mask": seq_mask.unsqueeze(0).unsqueeze(0)}, ) - encoder_output_dim = encoder_output.shape[-1] - encoder_output = encoder_output.reshape(1, -1, encoder_output_dim).squeeze(0) + encoder_output = encoder_output.squeeze(0) pred = self.linear(encoder_output) pred_logits = torch.clamp(pred, min=-15.0, max=15.0) return pred_logits, moe_loss