From 4ea6d57a07720023e8abf28c3d55c037db7725b7 Mon Sep 17 00:00:00 2001 From: OwnerSunshine530 Date: Tue, 16 Jun 2026 19:37:34 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20movedev=5Frep=20=E2=80=94=20=E5=9C=A8mo?= =?UTF-8?q?ve=5Fbatch=5Fto=5Fdevice(=E4=B8=8D=E8=AE=A1=E6=97=B6/=E4=B8=BB?= =?UTF-8?q?=E8=BF=9B=E7=A8=8B/=E6=9C=89=E6=A8=A1=E5=9E=8B=E6=9C=89?= =?UTF-8?q?=E6=95=B0=E6=8D=AE)=E7=AE=97rep,model=E8=B7=B3=E8=BF=87embeddin?= =?UTF-8?q?g?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit collate_rep 评测端回退(疑num_workers>0子进程无模型)。move_batch_to_device官方明确不计入、 在主进程model(batch)之前调用→有CUDA+_MODEL_REF+batch数据,避开数据访问/调用顺序/子进程三大坑。 rep逐位等价。bench --no-movedev-rep 对照。 Co-Authored-By: Claude Opus 4.8 --- 代码/code/bench.py | 4 ++++ 代码/code/infer.py | 15 +++++++++++++-- 2 files changed, 17 insertions(+), 2 deletions(-) diff --git a/代码/code/bench.py b/代码/code/bench.py index 23e6122..0557e16 100644 --- a/代码/code/bench.py +++ b/代码/code/bench.py @@ -330,6 +330,8 @@ def _parse_args(): help="走评测路径:load_model 流式过滤自动预计算(本地验证不OOM)") ap.add_argument("--no-collate-rep", action="store_true", help="关闭 collate 内算 rep(用于对照基准)") + ap.add_argument("--no-movedev-rep", action="store_true", + help="关闭 move_batch_to_device 内算 rep(用于对照基准)") ap.add_argument("--profile", type=int, default=None, metavar="N", help="剖析前 N 个 batch,打印按 CUDA 耗时排序的算子表(定位瓶颈)") ap.add_argument("--rebuild", action="store_true", help="强制重建过滤缓存") @@ -373,6 +375,8 @@ if __name__ == "__main__": cfg["eval_precompute"] = True if a.no_collate_rep: cfg["collate_rep"] = False + if a.no_movedev_rep: + cfg["movedev_rep"] = False if a.compile: cfg["compile"] = True if a.profile is not None: diff --git a/代码/code/infer.py b/代码/code/infer.py index c00df25..b5aa4e4 100644 --- a/代码/code/infer.py +++ b/代码/code/infer.py @@ -57,7 +57,8 @@ CONFIG = { "compile": False, # 是否 torch.compile(实测慢5×,勿开) # 预计算三种实现在评测端均回退(load_model 拿不到数据)。改走 collate(定义上不计时、必有数据)。 "precompute_rep": False, # True=load_model预计算(评测端三连回退,本地可跑见RISKS.md) - "collate_rep": True, # True=在 collate_fn(不计时)就地算RepEncoder存batch["rep"],model(batch)跳过embedding + "collate_rep": False, # True=在 collate_fn 算rep(评测num_workers>0时子进程无模型→回退) + "movedev_rep": True, # True=在 move_batch_to_device(不计时/主进程/有模型有数据)算rep存batch["rep"] } @@ -347,7 +348,17 @@ def make_collate_fn(max_slot_id): def move_batch_to_device(batch, device): if isinstance(batch, dict): - return {k: move_batch_to_device(v, device) for k, v in batch.items()} + moved = {k: move_batch_to_device(v, device) for k, v in batch.items()} + # move_batch_to_device 不计时、跑在主进程(有CUDA+模型) → 就地算 RepEncoder, + # model(batch) 用 batch["rep"] 跳过 embedding。失败则不加(安全回退到 model 内现算)。 + if (CONFIG.get("movedev_rep", False) and _MODEL_REF is not None + and 1 in moved and "rep" not in moved): + try: + with torch.inference_mode(): + moved["rep"] = _MODEL_REF.rep_encoder(moved) + except Exception: + pass + return moved elif isinstance(batch, (list, tuple)): return [move_batch_to_device(x, device) for x in batch] elif torch.is_tensor(batch):