From 3c9da9a47df671bace215c55fa99efcf4ab14e6d Mon Sep 17 00:00:00 2001 From: OwnerSunshine530 Date: Sat, 20 Jun 2026 01:45:05 +0800 Subject: [PATCH] =?UTF-8?q?fix:=20INT8=20MoE=20int32=E7=BB=93=E6=9E=9C?= =?UTF-8?q?=E5=85=88=E8=BD=ACfp32=E5=8F=8D=E9=87=8F=E5=8C=96=E5=86=8Dfp16(?= =?UTF-8?q?=E7=9B=B4=E6=8E=A5.half()=E6=BA=A2=E5=87=BA830=E4=B8=87>65504?= =?UTF-8?q?=E8=87=B4NaN)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- 代码/code/infer.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/代码/code/infer.py b/代码/code/infer.py index c71ed22..36d7153 100644 --- a/代码/code/infer.py +++ b/代码/code/infer.py @@ -734,17 +734,18 @@ class SMoE(nn.Module): Np = xf.shape[0] xs = (xf.abs().amax() / 127.0).clamp_min(1e-8) xq = (xf / xs).round().clamp(-127, 127).to(torch.int8) - h = torch._int_mm(xq, self.W1_cat_i8).to(torch.float16) # [Np, E*F] int32->fp16 - h = h * (xs * self.w1_scale) + self.b1_cat - h = torch.relu(h) + # int32 结果可达 ~830万,超 fp16 上限 → 先转 fp32 反量化(×小 scale 拉回),再 fp16 + h = torch._int_mm(xq, self.W1_cat_i8).to(torch.float32) # [Np, E*F] + h = h * (xs.float() * self.w1_scale.float()) + h = torch.relu(h + self.b1_cat.float()).to(torch.float16) w = torch.zeros(Np, E, dtype=torch.float16, device=x.device) w[:N].scatter_(1, topk_idx.reshape(-1, k), topk_score.reshape(-1, k).to(torch.float16)) hw = (h.view(Np, E, F) * w.unsqueeze(-1)).reshape(Np, E * F) hs = (hw.abs().amax() / 127.0).clamp_min(1e-8) hq = (hw / hs).round().clamp(-127, 127).to(torch.int8) - o = torch._int_mm(hq, self.W2_cat_i8).to(torch.float16) # [Np, D] - o = o * (hs * self.w2_scale) + w @ self.b2 - return o[:N].reshape(B, S, D), o.new_zeros(()) + o = torch._int_mm(hq, self.W2_cat_i8).to(torch.float32) # [Np, D] + o = o * (hs.float() * self.w2_scale.float()) + (w @ self.b2).float() + return o[:N].reshape(B, S, D).to(torch.float16), o.new_zeros(()) def _forward_sparse(self, x): """真稀疏 MoE:每 token 只算 top-k expert(按 expert 排序 + capacity 分桶 + cutlass baddbmm)。