feat: INT8 动态量化所有 Linear 层(torch.ao.quantization)

- 仅量化 Linear 权重(不影响 Embedding)
- INT8 权重读带宽减半 vs FP16
- try-except 保护:CUDA 后端不可用时回退 FP16
This commit is contained in:
2026-06-13 13:53:45 +08:00
parent c081620ffd
commit 96462444f6
+17
View File
@@ -500,6 +500,23 @@ def load_model(ckpt_path, device='cuda:0'):
model = model.half() model = model.half()
model.rep_encoder.emb = model.rep_encoder.emb.to(torch.float32) model.rep_encoder.emb = model.rep_encoder.emb.to(torch.float32)
print("[INFO] Model converted to FP16 (embedding kept in FP32)") print("[INFO] Model converted to FP16 (embedding kept in FP32)")
# === INT8 动态量化:所有 Linear 层权重 INT8matmul 2x 加速 ===
try:
from torch.ao.quantization import quantize_dynamic
# 排除 embedding 层,仅量化 Linear
model.seq_encoder = quantize_dynamic(
model.seq_encoder, {nn.Linear}, dtype=torch.qint8
)
model.linear = quantize_dynamic(
model.linear, {nn.Linear}, dtype=torch.qint8
)
model.rep_encoder.linear = quantize_dynamic(
model.rep_encoder.linear, {nn.Linear}, dtype=torch.qint8
)
print("[INFO] INT8 dynamic quantization applied to Linear layers")
except Exception as e:
print(f"[WARNING] INT8 quantization failed ({e}), keeping FP16")
else: else:
print(f"[WARNING] Checkpoint {ckpt_path} not found, using random weights") print(f"[WARNING] Checkpoint {ckpt_path} not found, using random weights")