feat: INT8 dense MoE(torch._int_mm,2D拼接W1_cat/W2_cat,top-k加权折进GEMM2,per-tensor激活量化)

dense MoE两个batched GEMM重写成2D GEMM以用A800 int8 tensor core;计算减半。
quant/dequant是真compute本地可见→本地bench即可判生死。默认关,bench --moe-int8。

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
OwnerSunshine530
2026-06-20 01:35:55 +08:00
parent 112ea014aa
commit 84db692f07
2 changed files with 44 additions and 0 deletions
+3
View File
@@ -349,6 +349,7 @@ def _parse_args():
ap.add_argument("--logit-bias", type=float, default=None, help="PCOC校准:logit偏移(本地验证PCOC→1.0)")
ap.add_argument("--moe-sparse", action="store_true", help="真稀疏MoE(只算top-k,capacity分组)")
ap.add_argument("--moe-cap", type=float, default=None, help="MoE capacity factor")
ap.add_argument("--moe-int8", action="store_true", help="INT8 dense MoE(torch._int_mm)")
ap.add_argument("--sparse-pool", action="store_true", help="稀疏矩阵乘做池化(段内高重复时省)")
ap.add_argument("--precompute-rep", action="store_true",
help="预计算RepEncoder缓存,model(batch)跳过embedding层(从batches自建)")
@@ -405,6 +406,8 @@ if __name__ == "__main__":
cfg["logit_bias"] = a.logit_bias
if a.moe_sparse:
cfg["moe_sparse"] = True
if a.moe_int8:
cfg["moe_int8"] = True
if a.moe_cap is not None:
cfg["moe_capacity"] = a.moe_cap
if a.sparse_pool: