From 1249bbdbbc53d16371cd2dd6b9c6c7d107fdf30c Mon Sep 17 00:00:00 2001 From: OwnerSunshine530 Date: Mon, 15 Jun 2026 12:39:10 +0800 Subject: [PATCH] =?UTF-8?q?perf:=20emb=5Ffp16=20=E9=BB=98=E8=AE=A4?= =?UTF-8?q?=E5=BC=80=E5=90=AF(=E6=9C=AC=E5=9C=B0AUC=200.75932=E2=89=88?= =?UTF-8?q?=E6=97=A0=E6=8D=9F,=E6=9F=A5=E8=A1=A8=E5=B8=A6=E5=AE=BD?= =?UTF-8?q?=E5=87=8F=E5=8D=8A)=EF=BC=9B=E4=BF=AE=E6=AD=A3=E6=89=93?= =?UTF-8?q?=E5=8D=B0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-Authored-By: Claude Opus 4.8 --- 代码/code/infer.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/代码/code/infer.py b/代码/code/infer.py index 8ddb09a..a5b033c 100644 --- a/代码/code/infer.py +++ b/代码/code/infer.py @@ -50,7 +50,7 @@ CONFIG = { "vectorize_moe": True, # True=稠密向量化MoE(无同步点);False=原逐expert循环(.nonzero同步) "fuse_embedding": True, # True=28个slot的查表+池化融合为1次(减per-batch kernel启动) "syncfree_mask": True, # True=用searchsorted构造因果mask(无同步);False=repeat_interleave(同步) - "emb_fp16": False, # True=Embedding表也转FP16(查表带宽减半,可能微动AUC);False=保留FP32 + "emb_fp16": True, # True=Embedding表转FP16(查表带宽减半,实测AUC 0.75932≈无损) "compile": False, # 是否 torch.compile(实测慢5×,勿开) } @@ -708,8 +708,9 @@ def load_model(ckpt_path, device='cuda:0'): for name, module in model.named_modules(): if name and any(name.startswith(p) for p in CONFIG["keep_fp32_modules"]): _force_fp32_io(module) - print(f"[INFO] FP16 on; FP32-kept: " - f"{('rep_encoder.emb',) + tuple(CONFIG['keep_fp32_modules'])}") + emb_note = "emb=FP16" if CONFIG.get("emb_fp16", False) else "emb=FP32" + print(f"[INFO] FP16 on; {emb_note}; extra FP32-kept: " + f"{tuple(CONFIG['keep_fp32_modules'])}") else: model = model.float() print("[INFO] FP32 reference (no half)")