diff --git a/代码/code/infer.py b/代码/code/infer.py index a4ac157..3b13cd5 100644 --- a/代码/code/infer.py +++ b/代码/code/infer.py @@ -500,6 +500,23 @@ def load_model(ckpt_path, device='cuda:0'): model = model.half() model.rep_encoder.emb = model.rep_encoder.emb.to(torch.float32) print("[INFO] Model converted to FP16 (embedding kept in FP32)") + + # === INT8 动态量化:所有 Linear 层权重 INT8,matmul 2x 加速 === + try: + from torch.ao.quantization import quantize_dynamic + # 排除 embedding 层,仅量化 Linear + model.seq_encoder = quantize_dynamic( + model.seq_encoder, {nn.Linear}, dtype=torch.qint8 + ) + model.linear = quantize_dynamic( + model.linear, {nn.Linear}, dtype=torch.qint8 + ) + model.rep_encoder.linear = quantize_dynamic( + model.rep_encoder.linear, {nn.Linear}, dtype=torch.qint8 + ) + print("[INFO] INT8 dynamic quantization applied to Linear layers") + except Exception as e: + print(f"[WARNING] INT8 quantization failed ({e}), keeping FP16") else: print(f"[WARNING] Checkpoint {ckpt_path} not found, using random weights")