feat: INT8 dense MoE(torch._int_mm,2D拼接W1_cat/W2_cat,top-k加权折进GEMM2,per-tensor激活量化)

dense MoE两个batched GEMM重写成2D GEMM以用A800 int8 tensor core;计算减半。
quant/dequant是真compute本地可见→本地bench即可判生死。默认关,bench --moe-int8。

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
OwnerSunshine530
2026-06-20 01:35:55 +08:00
parent 112ea014aa
commit 84db692f07
2 changed files with 44 additions and 0 deletions
+3
View File
@@ -349,6 +349,7 @@ def _parse_args():
ap.add_argument("--logit-bias", type=float, default=None, help="PCOC校准:logit偏移(本地验证PCOC→1.0)")
ap.add_argument("--moe-sparse", action="store_true", help="真稀疏MoE(只算top-k,capacity分组)")
ap.add_argument("--moe-cap", type=float, default=None, help="MoE capacity factor")
ap.add_argument("--moe-int8", action="store_true", help="INT8 dense MoE(torch._int_mm)")
ap.add_argument("--sparse-pool", action="store_true", help="稀疏矩阵乘做池化(段内高重复时省)")
ap.add_argument("--precompute-rep", action="store_true",
help="预计算RepEncoder缓存,model(batch)跳过embedding层(从batches自建)")
@@ -405,6 +406,8 @@ if __name__ == "__main__":
cfg["logit_bias"] = a.logit_bias
if a.moe_sparse:
cfg["moe_sparse"] = True
if a.moe_int8:
cfg["moe_int8"] = True
if a.moe_cap is not None:
cfg["moe_capacity"] = a.moe_cap
if a.sparse_pool:
+41
View File
@@ -157,6 +157,7 @@ CONFIG = {
"moe_fused_weight": False, # True=top-k加权用scatter+mul+sum(评测慢,勿开)
# 真稀疏MoE实测评测净负:lat 34.64->37.64s(本地快15%但argsort/scatter开销评测放大,如varlen)
# +容量丢弃降AUC(0.7525->0.7507)。已退回 dense。
"moe_int8": False, # True=INT8 dense MoE(torch._int_mm,2D拼接);计算减半但加quant kernel,有AUC风险
"moe_sparse": False, # True=真稀疏MoE(评测净负,勿开)
"moe_capacity": 2.0,
"skip_moe_loss": True, # 推理跳过 moe_loss(load-balance,推理无用),省 importance/std/mean kernel
@@ -706,8 +707,45 @@ class SMoE(nn.Module):
# baddbmm 用的转置权重([E,D,F] / [E,F,D]),预转 contiguous
self.register_buffer("W1t", self.W1.transpose(1, 2).contiguous()) # [E,D,F]
self.register_buffer("W2t", self.W2.transpose(1, 2).contiguous()) # [E,F,D]
# INT8: 2D concatenated weights W1_cat[D,E*F] / W2_cat[E*F,D] (per-output-channel quantization) for _int_mm
E, F, D = self.num_experts, self.W1.shape[1], self.W1.shape[2]
W1_cat = self.W1t.permute(1, 0, 2).reshape(D, E * F).float() # [D, E*F]
s1 = (W1_cat.abs().amax(0) / 127.0).clamp_min(1e-8) # [E*F]
self.register_buffer("W1_cat_i8", (W1_cat / s1).round().clamp(-127, 127).to(torch.int8).contiguous())
self.register_buffer("w1_scale", s1.to(torch.float16))
self.register_buffer("b1_cat", self.b1.reshape(E * F).to(torch.float16))
W2_cat = self.W2t.reshape(E * F, D).float() # [E*F, D]
s2 = (W2_cat.abs().amax(0) / 127.0).clamp_min(1e-8) # [D]
self.register_buffer("W2_cat_i8", (W2_cat / s2).round().clamp(-127, 127).to(torch.int8).contiguous())
self.register_buffer("w2_scale", s2.to(torch.float16))
self._stacked = True
def _forward_int8(self, x):
"""INT8 dense MoE:两个 2D GEMM 用 torch._int_mmA800 int8 tensor core),
top-k 加权折进第二个 GEMM。per-tensor 激活量化。计算减半,但 quant/dequant 加 kernel。"""
B, S, D = x.shape
topk_idx, topk_score, _ = self.gate(x)
N, E, k = B * S, self.num_experts, self.k
F = self.W1t.shape[2]
xf = x.reshape(N, D).to(torch.float16)
pad = (-N) % 16 # _int_mm 要求行数 %16
if pad:
xf = torch.cat([xf, xf.new_zeros(pad, D)], 0)
Np = xf.shape[0]
xs = (xf.abs().amax() / 127.0).clamp_min(1e-8)
xq = (xf / xs).round().clamp(-127, 127).to(torch.int8)
h = torch._int_mm(xq, self.W1_cat_i8).to(torch.float16) # [Np, E*F] int32->fp16
h = h * (xs * self.w1_scale) + self.b1_cat
h = torch.relu(h)
w = torch.zeros(Np, E, dtype=torch.float16, device=x.device)
w[:N].scatter_(1, topk_idx.reshape(-1, k), topk_score.reshape(-1, k).to(torch.float16))
hw = (h.view(Np, E, F) * w.unsqueeze(-1)).reshape(Np, E * F)
hs = (hw.abs().amax() / 127.0).clamp_min(1e-8)
hq = (hw / hs).round().clamp(-127, 127).to(torch.int8)
o = torch._int_mm(hq, self.W2_cat_i8).to(torch.float16) # [Np, D]
o = o * (hs * self.w2_scale) + w @ self.b2
return o[:N].reshape(B, S, D), o.new_zeros(())
def _forward_sparse(self, x):
"""真稀疏 MoE:每 token 只算 top-k expert(按 expert 排序 + capacity 分桶 + cutlass baddbmm)。
全程无 host 同步(argsort/where/scatter/index_add)。超容量 token 被丢弃(capacity_factor 控)。"""
@@ -749,6 +787,9 @@ class SMoE(nn.Module):
if not self._stacked:
self._stack_weights()
if CONFIG.get("moe_int8", False):
return self._forward_int8(x)
if CONFIG.get("moe_sparse", False):
return self._forward_sparse(x)