diff --git a/代码/code/bench.py b/代码/code/bench.py index a800af4..6c5fdbe 100644 --- a/代码/code/bench.py +++ b/代码/code/bench.py @@ -349,6 +349,7 @@ def _parse_args(): ap.add_argument("--logit-bias", type=float, default=None, help="PCOC校准:logit偏移(本地验证PCOC→1.0)") ap.add_argument("--moe-sparse", action="store_true", help="真稀疏MoE(只算top-k,capacity分组)") ap.add_argument("--moe-cap", type=float, default=None, help="MoE capacity factor") + ap.add_argument("--moe-int8", action="store_true", help="INT8 dense MoE(torch._int_mm)") ap.add_argument("--sparse-pool", action="store_true", help="稀疏矩阵乘做池化(段内高重复时省)") ap.add_argument("--precompute-rep", action="store_true", help="预计算RepEncoder缓存,model(batch)跳过embedding层(从batches自建)") @@ -405,6 +406,8 @@ if __name__ == "__main__": cfg["logit_bias"] = a.logit_bias if a.moe_sparse: cfg["moe_sparse"] = True + if a.moe_int8: + cfg["moe_int8"] = True if a.moe_cap is not None: cfg["moe_capacity"] = a.moe_cap if a.sparse_pool: diff --git a/代码/code/infer.py b/代码/code/infer.py index 4e44ba8..c71ed22 100644 --- a/代码/code/infer.py +++ b/代码/code/infer.py @@ -157,6 +157,7 @@ CONFIG = { "moe_fused_weight": False, # True=top-k加权用scatter+mul+sum(评测慢,勿开) # 真稀疏MoE实测评测净负:lat 34.64->37.64s(本地快15%但argsort/scatter开销评测放大,如varlen) # +容量丢弃降AUC(0.7525->0.7507)。已退回 dense。 + "moe_int8": False, # True=INT8 dense MoE(torch._int_mm,2D拼接);计算减半但加quant kernel,有AUC风险 "moe_sparse": False, # True=真稀疏MoE(评测净负,勿开) "moe_capacity": 2.0, "skip_moe_loss": True, # 推理跳过 moe_loss(load-balance,推理无用),省 importance/std/mean kernel @@ -706,8 +707,45 @@ class SMoE(nn.Module): # baddbmm 用的转置权重([E,D,F] / [E,F,D]),预转 contiguous self.register_buffer("W1t", self.W1.transpose(1, 2).contiguous()) # [E,D,F] self.register_buffer("W2t", self.W2.transpose(1, 2).contiguous()) # [E,F,D] + # INT8:2D 拼接权重 W1_cat[D,E*F] / W2_cat[E*F,D](per-output-channel 量化)供 _int_mm + E, F, D = self.num_experts, self.W1.shape[1], self.W1.shape[2] + W1_cat = self.W1t.permute(1, 0, 2).reshape(D, E * F).float() # [D, E*F] + s1 = (W1_cat.abs().amax(0) / 127.0).clamp_min(1e-8) # [E*F] + self.register_buffer("W1_cat_i8", (W1_cat / s1).round().clamp(-127, 127).to(torch.int8).contiguous()) + self.register_buffer("w1_scale", s1.to(torch.float16)) + self.register_buffer("b1_cat", self.b1.reshape(E * F).to(torch.float16)) + W2_cat = self.W2t.reshape(E * F, D).float() # [E*F, D] + s2 = (W2_cat.abs().amax(0) / 127.0).clamp_min(1e-8) # [D] + self.register_buffer("W2_cat_i8", (W2_cat / s2).round().clamp(-127, 127).to(torch.int8).contiguous()) + self.register_buffer("w2_scale", s2.to(torch.float16)) self._stacked = True + def _forward_int8(self, x): + """INT8 dense MoE:两个 2D GEMM 用 torch._int_mm(A800 int8 tensor core), + top-k 加权折进第二个 GEMM。per-tensor 激活量化。计算减半,但 quant/dequant 加 kernel。""" + B, S, D = x.shape + topk_idx, topk_score, _ = self.gate(x) + N, E, k = B * S, self.num_experts, self.k + F = self.W1t.shape[2] + xf = x.reshape(N, D).to(torch.float16) + pad = (-N) % 16 # _int_mm 要求行数 %16 + if pad: + xf = torch.cat([xf, xf.new_zeros(pad, D)], 0) + Np = xf.shape[0] + xs = (xf.abs().amax() / 127.0).clamp_min(1e-8) + xq = (xf / xs).round().clamp(-127, 127).to(torch.int8) + h = torch._int_mm(xq, self.W1_cat_i8).to(torch.float16) # [Np, E*F] int32->fp16 + h = h * (xs * self.w1_scale) + self.b1_cat + h = torch.relu(h) + w = torch.zeros(Np, E, dtype=torch.float16, device=x.device) + w[:N].scatter_(1, topk_idx.reshape(-1, k), topk_score.reshape(-1, k).to(torch.float16)) + hw = (h.view(Np, E, F) * w.unsqueeze(-1)).reshape(Np, E * F) + hs = (hw.abs().amax() / 127.0).clamp_min(1e-8) + hq = (hw / hs).round().clamp(-127, 127).to(torch.int8) + o = torch._int_mm(hq, self.W2_cat_i8).to(torch.float16) # [Np, D] + o = o * (hs * self.w2_scale) + w @ self.b2 + return o[:N].reshape(B, S, D), o.new_zeros(()) + def _forward_sparse(self, x): """真稀疏 MoE:每 token 只算 top-k expert(按 expert 排序 + capacity 分桶 + cutlass baddbmm)。 全程无 host 同步(argsort/where/scatter/index_add)。超容量 token 被丢弃(capacity_factor 控)。""" @@ -749,6 +787,9 @@ class SMoE(nn.Module): if not self._stacked: self._stack_weights() + if CONFIG.get("moe_int8", False): + return self._forward_int8(x) + if CONFIG.get("moe_sparse", False): return self._forward_sparse(x)