diff --git a/代码/code/bench.py b/代码/code/bench.py index 23e6122..0557e16 100644 --- a/代码/code/bench.py +++ b/代码/code/bench.py @@ -330,6 +330,8 @@ def _parse_args(): help="走评测路径:load_model 流式过滤自动预计算(本地验证不OOM)") ap.add_argument("--no-collate-rep", action="store_true", help="关闭 collate 内算 rep(用于对照基准)") + ap.add_argument("--no-movedev-rep", action="store_true", + help="关闭 move_batch_to_device 内算 rep(用于对照基准)") ap.add_argument("--profile", type=int, default=None, metavar="N", help="剖析前 N 个 batch,打印按 CUDA 耗时排序的算子表(定位瓶颈)") ap.add_argument("--rebuild", action="store_true", help="强制重建过滤缓存") @@ -373,6 +375,8 @@ if __name__ == "__main__": cfg["eval_precompute"] = True if a.no_collate_rep: cfg["collate_rep"] = False + if a.no_movedev_rep: + cfg["movedev_rep"] = False if a.compile: cfg["compile"] = True if a.profile is not None: diff --git a/代码/code/infer.py b/代码/code/infer.py index c00df25..b5aa4e4 100644 --- a/代码/code/infer.py +++ b/代码/code/infer.py @@ -57,7 +57,8 @@ CONFIG = { "compile": False, # 是否 torch.compile(实测慢5×,勿开) # 预计算三种实现在评测端均回退(load_model 拿不到数据)。改走 collate(定义上不计时、必有数据)。 "precompute_rep": False, # True=load_model预计算(评测端三连回退,本地可跑见RISKS.md) - "collate_rep": True, # True=在 collate_fn(不计时)就地算RepEncoder存batch["rep"],model(batch)跳过embedding + "collate_rep": False, # True=在 collate_fn 算rep(评测num_workers>0时子进程无模型→回退) + "movedev_rep": True, # True=在 move_batch_to_device(不计时/主进程/有模型有数据)算rep存batch["rep"] } @@ -347,7 +348,17 @@ def make_collate_fn(max_slot_id): def move_batch_to_device(batch, device): if isinstance(batch, dict): - return {k: move_batch_to_device(v, device) for k, v in batch.items()} + moved = {k: move_batch_to_device(v, device) for k, v in batch.items()} + # move_batch_to_device 不计时、跑在主进程(有CUDA+模型) → 就地算 RepEncoder, + # model(batch) 用 batch["rep"] 跳过 embedding。失败则不加(安全回退到 model 内现算)。 + if (CONFIG.get("movedev_rep", False) and _MODEL_REF is not None + and 1 in moved and "rep" not in moved): + try: + with torch.inference_mode(): + moved["rep"] = _MODEL_REF.rep_encoder(moved) + except Exception: + pass + return moved elif isinstance(batch, (list, tuple)): return [move_batch_to_device(x, device) for x in batch] elif torch.is_tensor(batch):