feat: INT8 动态量化所有 Linear 层(torch.ao.quantization)
- 仅量化 Linear 权重(不影响 Embedding) - INT8 权重读带宽减半 vs FP16 - try-except 保护:CUDA 后端不可用时回退 FP16
This commit is contained in:
@@ -500,6 +500,23 @@ def load_model(ckpt_path, device='cuda:0'):
|
|||||||
model = model.half()
|
model = model.half()
|
||||||
model.rep_encoder.emb = model.rep_encoder.emb.to(torch.float32)
|
model.rep_encoder.emb = model.rep_encoder.emb.to(torch.float32)
|
||||||
print("[INFO] Model converted to FP16 (embedding kept in FP32)")
|
print("[INFO] Model converted to FP16 (embedding kept in FP32)")
|
||||||
|
|
||||||
|
# === INT8 动态量化:所有 Linear 层权重 INT8,matmul 2x 加速 ===
|
||||||
|
try:
|
||||||
|
from torch.ao.quantization import quantize_dynamic
|
||||||
|
# 排除 embedding 层,仅量化 Linear
|
||||||
|
model.seq_encoder = quantize_dynamic(
|
||||||
|
model.seq_encoder, {nn.Linear}, dtype=torch.qint8
|
||||||
|
)
|
||||||
|
model.linear = quantize_dynamic(
|
||||||
|
model.linear, {nn.Linear}, dtype=torch.qint8
|
||||||
|
)
|
||||||
|
model.rep_encoder.linear = quantize_dynamic(
|
||||||
|
model.rep_encoder.linear, {nn.Linear}, dtype=torch.qint8
|
||||||
|
)
|
||||||
|
print("[INFO] INT8 dynamic quantization applied to Linear layers")
|
||||||
|
except Exception as e:
|
||||||
|
print(f"[WARNING] INT8 quantization failed ({e}), keeping FP16")
|
||||||
else:
|
else:
|
||||||
print(f"[WARNING] Checkpoint {ckpt_path} not found, using random weights")
|
print(f"[WARNING] Checkpoint {ckpt_path} not found, using random weights")
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user