fix: INT8 MoE int32结果先转fp32反量化再fp16(直接.half()溢出830万>65504致NaN)

This commit is contained in:
OwnerSunshine530
2026-06-20 01:45:05 +08:00
parent 84db692f07
commit 3c9da9a47d
+7 -6
View File
@@ -734,17 +734,18 @@ class SMoE(nn.Module):
Np = xf.shape[0] Np = xf.shape[0]
xs = (xf.abs().amax() / 127.0).clamp_min(1e-8) xs = (xf.abs().amax() / 127.0).clamp_min(1e-8)
xq = (xf / xs).round().clamp(-127, 127).to(torch.int8) xq = (xf / xs).round().clamp(-127, 127).to(torch.int8)
h = torch._int_mm(xq, self.W1_cat_i8).to(torch.float16) # [Np, E*F] int32->fp16 # int32 结果可达 ~830万,超 fp16 上限 → 先转 fp32 反量化(×小 scale 拉回),再 fp16
h = h * (xs * self.w1_scale) + self.b1_cat h = torch._int_mm(xq, self.W1_cat_i8).to(torch.float32) # [Np, E*F]
h = torch.relu(h) h = h * (xs.float() * self.w1_scale.float())
h = torch.relu(h + self.b1_cat.float()).to(torch.float16)
w = torch.zeros(Np, E, dtype=torch.float16, device=x.device) w = torch.zeros(Np, E, dtype=torch.float16, device=x.device)
w[:N].scatter_(1, topk_idx.reshape(-1, k), topk_score.reshape(-1, k).to(torch.float16)) w[:N].scatter_(1, topk_idx.reshape(-1, k), topk_score.reshape(-1, k).to(torch.float16))
hw = (h.view(Np, E, F) * w.unsqueeze(-1)).reshape(Np, E * F) hw = (h.view(Np, E, F) * w.unsqueeze(-1)).reshape(Np, E * F)
hs = (hw.abs().amax() / 127.0).clamp_min(1e-8) hs = (hw.abs().amax() / 127.0).clamp_min(1e-8)
hq = (hw / hs).round().clamp(-127, 127).to(torch.int8) hq = (hw / hs).round().clamp(-127, 127).to(torch.int8)
o = torch._int_mm(hq, self.W2_cat_i8).to(torch.float16) # [Np, D] o = torch._int_mm(hq, self.W2_cat_i8).to(torch.float32) # [Np, D]
o = o * (hs * self.w2_scale) + w @ self.b2 o = o * (hs.float() * self.w2_scale.float()) + (w @ self.b2).float()
return o[:N].reshape(B, S, D), o.new_zeros(()) return o[:N].reshape(B, S, D).to(torch.float16), o.new_zeros(())
def _forward_sparse(self, x): def _forward_sparse(self, x):
"""真稀疏 MoE:每 token 只算 top-k expert(按 expert 排序 + capacity 分桶 + cutlass baddbmm)。 """真稀疏 MoE:每 token 只算 top-k expert(按 expert 排序 + capacity 分桶 + cutlass baddbmm)。