From adc99b5b41fe1037d0e7a8ab9acabd1931564ea8 Mon Sep 17 00:00:00 2001 From: OwnerSunshine530 Date: Mon, 15 Jun 2026 12:26:55 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20emb=5Ffp16=20=E9=80=89=E9=A1=B9(Embeddi?= =?UTF-8?q?ng=E8=A1=A8=E8=BD=ACFP16,=E6=9F=A5=E8=A1=A8=E5=B8=A6=E5=AE=BD?= =?UTF-8?q?=E5=87=8F=E5=8D=8A)=EF=BC=9Bbench=20--emb-fp16?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit embedding查表是显存带宽瓶颈(profile 16%);FP16表读一半字节。按token量算应 能等比例翻译到评测。代价:embedding权重存FP16微小精度损失,须先测AUC。默认关。 Co-Authored-By: Claude Opus 4.8 --- 代码/code/bench.py | 3 +++ 代码/code/infer.py | 6 ++++-- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/代码/code/bench.py b/代码/code/bench.py index d922812..adcdb21 100644 --- a/代码/code/bench.py +++ b/代码/code/bench.py @@ -296,6 +296,7 @@ def _parse_args(): ap.add_argument("--moe", choices=["dense", "loop"], default=None, help="MoE实现:dense=向量化(新), loop=逐expert循环(原)") ap.add_argument("--compile", action="store_true", help="开启 torch.compile") + ap.add_argument("--emb-fp16", action="store_true", help="Embedding表转FP16(查表带宽减半,测AUC)") ap.add_argument("--profile", type=int, default=None, metavar="N", help="剖析前 N 个 batch,打印按 CUDA 耗时排序的算子表(定位瓶颈)") ap.add_argument("--rebuild", action="store_true", help="强制重建过滤缓存") @@ -325,6 +326,8 @@ if __name__ == "__main__": cfg["attn"] = a.attn if a.moe is not None: cfg["vectorize_moe"] = (a.moe == "dense") + if a.emb_fp16: + cfg["emb_fp16"] = True if a.compile: cfg["compile"] = True if a.profile is not None: diff --git a/代码/code/infer.py b/代码/code/infer.py index 7564797..8ddb09a 100644 --- a/代码/code/infer.py +++ b/代码/code/infer.py @@ -50,6 +50,7 @@ CONFIG = { "vectorize_moe": True, # True=稠密向量化MoE(无同步点);False=原逐expert循环(.nonzero同步) "fuse_embedding": True, # True=28个slot的查表+池化融合为1次(减per-batch kernel启动) "syncfree_mask": True, # True=用searchsorted构造因果mask(无同步);False=repeat_interleave(同步) + "emb_fp16": False, # True=Embedding表也转FP16(查表带宽减半,可能微动AUC);False=保留FP32 "compile": False, # 是否 torch.compile(实测慢5×,勿开) } @@ -700,8 +701,9 @@ def load_model(ckpt_path, device='cuda:0'): if CONFIG["fp16"]: model = model.half() - # Embedding 始终保留 FP32(int 索引查表,不受浮点精度影响) - model.rep_encoder.emb = model.rep_encoder.emb.to(torch.float32) + # 默认 Embedding 保留 FP32;emb_fp16=True 时保持 FP16(查表带宽减半) + if not CONFIG.get("emb_fp16", False): + model.rep_encoder.emb = model.rep_encoder.emb.to(torch.float32) # 额外保留 FP32 的精度敏感模块(输入/输出自动转换) for name, module in model.named_modules(): if name and any(name.startswith(p) for p in CONFIG["keep_fp32_modules"]):