feat: emb_fp16 选项(Embedding表转FP16,查表带宽减半);bench --emb-fp16
embedding查表是显存带宽瓶颈(profile 16%);FP16表读一半字节。按token量算应 能等比例翻译到评测。代价:embedding权重存FP16微小精度损失,须先测AUC。默认关。 Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
@@ -296,6 +296,7 @@ def _parse_args():
|
|||||||
ap.add_argument("--moe", choices=["dense", "loop"], default=None,
|
ap.add_argument("--moe", choices=["dense", "loop"], default=None,
|
||||||
help="MoE实现:dense=向量化(新), loop=逐expert循环(原)")
|
help="MoE实现:dense=向量化(新), loop=逐expert循环(原)")
|
||||||
ap.add_argument("--compile", action="store_true", help="开启 torch.compile")
|
ap.add_argument("--compile", action="store_true", help="开启 torch.compile")
|
||||||
|
ap.add_argument("--emb-fp16", action="store_true", help="Embedding表转FP16(查表带宽减半,测AUC)")
|
||||||
ap.add_argument("--profile", type=int, default=None, metavar="N",
|
ap.add_argument("--profile", type=int, default=None, metavar="N",
|
||||||
help="剖析前 N 个 batch,打印按 CUDA 耗时排序的算子表(定位瓶颈)")
|
help="剖析前 N 个 batch,打印按 CUDA 耗时排序的算子表(定位瓶颈)")
|
||||||
ap.add_argument("--rebuild", action="store_true", help="强制重建过滤缓存")
|
ap.add_argument("--rebuild", action="store_true", help="强制重建过滤缓存")
|
||||||
@@ -325,6 +326,8 @@ if __name__ == "__main__":
|
|||||||
cfg["attn"] = a.attn
|
cfg["attn"] = a.attn
|
||||||
if a.moe is not None:
|
if a.moe is not None:
|
||||||
cfg["vectorize_moe"] = (a.moe == "dense")
|
cfg["vectorize_moe"] = (a.moe == "dense")
|
||||||
|
if a.emb_fp16:
|
||||||
|
cfg["emb_fp16"] = True
|
||||||
if a.compile:
|
if a.compile:
|
||||||
cfg["compile"] = True
|
cfg["compile"] = True
|
||||||
if a.profile is not None:
|
if a.profile is not None:
|
||||||
|
|||||||
+3
-1
@@ -50,6 +50,7 @@ CONFIG = {
|
|||||||
"vectorize_moe": True, # True=稠密向量化MoE(无同步点);False=原逐expert循环(.nonzero同步)
|
"vectorize_moe": True, # True=稠密向量化MoE(无同步点);False=原逐expert循环(.nonzero同步)
|
||||||
"fuse_embedding": True, # True=28个slot的查表+池化融合为1次(减per-batch kernel启动)
|
"fuse_embedding": True, # True=28个slot的查表+池化融合为1次(减per-batch kernel启动)
|
||||||
"syncfree_mask": True, # True=用searchsorted构造因果mask(无同步);False=repeat_interleave(同步)
|
"syncfree_mask": True, # True=用searchsorted构造因果mask(无同步);False=repeat_interleave(同步)
|
||||||
|
"emb_fp16": False, # True=Embedding表也转FP16(查表带宽减半,可能微动AUC);False=保留FP32
|
||||||
"compile": False, # 是否 torch.compile(实测慢5×,勿开)
|
"compile": False, # 是否 torch.compile(实测慢5×,勿开)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -700,7 +701,8 @@ def load_model(ckpt_path, device='cuda:0'):
|
|||||||
|
|
||||||
if CONFIG["fp16"]:
|
if CONFIG["fp16"]:
|
||||||
model = model.half()
|
model = model.half()
|
||||||
# Embedding 始终保留 FP32(int 索引查表,不受浮点精度影响)
|
# 默认 Embedding 保留 FP32;emb_fp16=True 时保持 FP16(查表带宽减半)
|
||||||
|
if not CONFIG.get("emb_fp16", False):
|
||||||
model.rep_encoder.emb = model.rep_encoder.emb.to(torch.float32)
|
model.rep_encoder.emb = model.rep_encoder.emb.to(torch.float32)
|
||||||
# 额外保留 FP32 的精度敏感模块(输入/输出自动转换)
|
# 额外保留 FP32 的精度敏感模块(输入/输出自动转换)
|
||||||
for name, module in model.named_modules():
|
for name, module in model.named_modules():
|
||||||
|
|||||||
Reference in New Issue
Block a user