From 96462444f662b39d44d0b13a46c1800744661707 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=88=98=E8=88=AA=E5=AE=87?= <3364451258@qq.com> Date: Sat, 13 Jun 2026 13:53:45 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20INT8=20=E5=8A=A8=E6=80=81=E9=87=8F?= =?UTF-8?q?=E5=8C=96=E6=89=80=E6=9C=89=20Linear=20=E5=B1=82=EF=BC=88torch.?= =?UTF-8?q?ao.quantization=EF=BC=89?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 仅量化 Linear 权重(不影响 Embedding) - INT8 权重读带宽减半 vs FP16 - try-except 保护:CUDA 后端不可用时回退 FP16 --- 代码/code/infer.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/代码/code/infer.py b/代码/code/infer.py index a4ac157..3b13cd5 100644 --- a/代码/code/infer.py +++ b/代码/code/infer.py @@ -500,6 +500,23 @@ def load_model(ckpt_path, device='cuda:0'): model = model.half() model.rep_encoder.emb = model.rep_encoder.emb.to(torch.float32) print("[INFO] Model converted to FP16 (embedding kept in FP32)") + + # === INT8 动态量化:所有 Linear 层权重 INT8,matmul 2x 加速 === + try: + from torch.ao.quantization import quantize_dynamic + # 排除 embedding 层,仅量化 Linear + model.seq_encoder = quantize_dynamic( + model.seq_encoder, {nn.Linear}, dtype=torch.qint8 + ) + model.linear = quantize_dynamic( + model.linear, {nn.Linear}, dtype=torch.qint8 + ) + model.rep_encoder.linear = quantize_dynamic( + model.rep_encoder.linear, {nn.Linear}, dtype=torch.qint8 + ) + print("[INFO] INT8 dynamic quantization applied to Linear layers") + except Exception as e: + print(f"[WARNING] INT8 quantization failed ({e}), keeping FP16") else: print(f"[WARNING] Checkpoint {ckpt_path} not found, using random weights")