feat: emb_fp16 选项(Embedding表转FP16,查表带宽减半);bench --emb-fp16

embedding查表是显存带宽瓶颈(profile 16%);FP16表读一半字节。按token量算应
能等比例翻译到评测。代价:embedding权重存FP16微小精度损失,须先测AUC。默认关。

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
OwnerSunshine530
2026-06-15 12:26:55 +08:00
parent cb2913cda8
commit adc99b5b41
2 changed files with 7 additions and 2 deletions
+4 -2
View File
@@ -50,6 +50,7 @@ CONFIG = {
"vectorize_moe": True, # True=稠密向量化MoE(无同步点)False=原逐expert循环(.nonzero同步)
"fuse_embedding": True, # True=28个slot的查表+池化融合为1次(减per-batch kernel启动)
"syncfree_mask": True, # True=用searchsorted构造因果mask(无同步)False=repeat_interleave(同步)
"emb_fp16": False, # True=Embedding表也转FP16(查表带宽减半,可能微动AUC)False=保留FP32
"compile": False, # 是否 torch.compile(实测慢5×,勿开)
}
@@ -700,8 +701,9 @@ def load_model(ckpt_path, device='cuda:0'):
if CONFIG["fp16"]:
model = model.half()
# Embedding 始终保留 FP32int 索引查表,不受浮点精度影响
model.rep_encoder.emb = model.rep_encoder.emb.to(torch.float32)
# 默认 Embedding 保留 FP32emb_fp16=True 时保持 FP16(查表带宽减半
if not CONFIG.get("emb_fp16", False):
model.rep_encoder.emb = model.rep_encoder.emb.to(torch.float32)
# 额外保留 FP32 的精度敏感模块(输入/输出自动转换)
for name, module in model.named_modules():
if name and any(name.startswith(p) for p in CONFIG["keep_fp32_modules"]):