diff --git a/代码/code/infer.py b/代码/code/infer.py index c71ed22..36d7153 100644 --- a/代码/code/infer.py +++ b/代码/code/infer.py @@ -734,17 +734,18 @@ class SMoE(nn.Module): Np = xf.shape[0] xs = (xf.abs().amax() / 127.0).clamp_min(1e-8) xq = (xf / xs).round().clamp(-127, 127).to(torch.int8) - h = torch._int_mm(xq, self.W1_cat_i8).to(torch.float16) # [Np, E*F] int32->fp16 - h = h * (xs * self.w1_scale) + self.b1_cat - h = torch.relu(h) + # int32 结果可达 ~830万,超 fp16 上限 → 先转 fp32 反量化(×小 scale 拉回),再 fp16 + h = torch._int_mm(xq, self.W1_cat_i8).to(torch.float32) # [Np, E*F] + h = h * (xs.float() * self.w1_scale.float()) + h = torch.relu(h + self.b1_cat.float()).to(torch.float16) w = torch.zeros(Np, E, dtype=torch.float16, device=x.device) w[:N].scatter_(1, topk_idx.reshape(-1, k), topk_score.reshape(-1, k).to(torch.float16)) hw = (h.view(Np, E, F) * w.unsqueeze(-1)).reshape(Np, E * F) hs = (hw.abs().amax() / 127.0).clamp_min(1e-8) hq = (hw / hs).round().clamp(-127, 127).to(torch.int8) - o = torch._int_mm(hq, self.W2_cat_i8).to(torch.float16) # [Np, D] - o = o * (hs * self.w2_scale) + w @ self.b2 - return o[:N].reshape(B, S, D), o.new_zeros(()) + o = torch._int_mm(hq, self.W2_cat_i8).to(torch.float32) # [Np, D] + o = o * (hs.float() * self.w2_scale.float()) + (w @ self.b2).float() + return o[:N].reshape(B, S, D).to(torch.float16), o.new_zeros(()) def _forward_sparse(self, x): """真稀疏 MoE:每 token 只算 top-k expert(按 expert 排序 + capacity 分桶 + cutlass baddbmm)。