fix: INT8 MoE int32结果先转fp32反量化再fp16(直接.half()溢出830万>65504致NaN)
This commit is contained in:
+7
-6
@@ -734,17 +734,18 @@ class SMoE(nn.Module):
|
|||||||
Np = xf.shape[0]
|
Np = xf.shape[0]
|
||||||
xs = (xf.abs().amax() / 127.0).clamp_min(1e-8)
|
xs = (xf.abs().amax() / 127.0).clamp_min(1e-8)
|
||||||
xq = (xf / xs).round().clamp(-127, 127).to(torch.int8)
|
xq = (xf / xs).round().clamp(-127, 127).to(torch.int8)
|
||||||
h = torch._int_mm(xq, self.W1_cat_i8).to(torch.float16) # [Np, E*F] int32->fp16
|
# int32 结果可达 ~830万,超 fp16 上限 → 先转 fp32 反量化(×小 scale 拉回),再 fp16
|
||||||
h = h * (xs * self.w1_scale) + self.b1_cat
|
h = torch._int_mm(xq, self.W1_cat_i8).to(torch.float32) # [Np, E*F]
|
||||||
h = torch.relu(h)
|
h = h * (xs.float() * self.w1_scale.float())
|
||||||
|
h = torch.relu(h + self.b1_cat.float()).to(torch.float16)
|
||||||
w = torch.zeros(Np, E, dtype=torch.float16, device=x.device)
|
w = torch.zeros(Np, E, dtype=torch.float16, device=x.device)
|
||||||
w[:N].scatter_(1, topk_idx.reshape(-1, k), topk_score.reshape(-1, k).to(torch.float16))
|
w[:N].scatter_(1, topk_idx.reshape(-1, k), topk_score.reshape(-1, k).to(torch.float16))
|
||||||
hw = (h.view(Np, E, F) * w.unsqueeze(-1)).reshape(Np, E * F)
|
hw = (h.view(Np, E, F) * w.unsqueeze(-1)).reshape(Np, E * F)
|
||||||
hs = (hw.abs().amax() / 127.0).clamp_min(1e-8)
|
hs = (hw.abs().amax() / 127.0).clamp_min(1e-8)
|
||||||
hq = (hw / hs).round().clamp(-127, 127).to(torch.int8)
|
hq = (hw / hs).round().clamp(-127, 127).to(torch.int8)
|
||||||
o = torch._int_mm(hq, self.W2_cat_i8).to(torch.float16) # [Np, D]
|
o = torch._int_mm(hq, self.W2_cat_i8).to(torch.float32) # [Np, D]
|
||||||
o = o * (hs * self.w2_scale) + w @ self.b2
|
o = o * (hs.float() * self.w2_scale.float()) + (w @ self.b2).float()
|
||||||
return o[:N].reshape(B, S, D), o.new_zeros(())
|
return o[:N].reshape(B, S, D).to(torch.float16), o.new_zeros(())
|
||||||
|
|
||||||
def _forward_sparse(self, x):
|
def _forward_sparse(self, x):
|
||||||
"""真稀疏 MoE:每 token 只算 top-k expert(按 expert 排序 + capacity 分桶 + cutlass baddbmm)。
|
"""真稀疏 MoE:每 token 只算 top-k expert(按 expert 排序 + capacity 分桶 + cutlass baddbmm)。
|
||||||
|
|||||||
Reference in New Issue
Block a user