feat: movedev_rep — 在move_batch_to_device(不计时/主进程/有模型有数据)算rep,model跳过embedding

collate_rep 评测端回退(疑num_workers>0子进程无模型)。move_batch_to_device官方明确不计入、
在主进程model(batch)之前调用→有CUDA+_MODEL_REF+batch数据,避开数据访问/调用顺序/子进程三大坑。
rep逐位等价。bench --no-movedev-rep 对照。

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
OwnerSunshine530
2026-06-16 19:37:34 +08:00
parent e1ad26867e
commit 4ea6d57a07
2 changed files with 17 additions and 2 deletions
+13 -2
View File
@@ -57,7 +57,8 @@ CONFIG = {
"compile": False, # 是否 torch.compile(实测慢5×,勿开)
# 预计算三种实现在评测端均回退(load_model 拿不到数据)。改走 collate(定义上不计时、必有数据)。
"precompute_rep": False, # True=load_model预计算(评测端三连回退,本地可跑见RISKS.md)
"collate_rep": True, # True=在 collate_fn(不计时)就地算RepEncoder存batch["rep"],model(batch)跳过embedding
"collate_rep": False, # True=在 collate_fn 算rep(评测num_workers>0时子进程无模型→回退)
"movedev_rep": True, # True=在 move_batch_to_device(不计时/主进程/有模型有数据)算rep存batch["rep"]
}
@@ -347,7 +348,17 @@ def make_collate_fn(max_slot_id):
def move_batch_to_device(batch, device):
if isinstance(batch, dict):
return {k: move_batch_to_device(v, device) for k, v in batch.items()}
moved = {k: move_batch_to_device(v, device) for k, v in batch.items()}
# move_batch_to_device 不计时、跑在主进程(有CUDA+模型) → 就地算 RepEncoder
# model(batch) 用 batch["rep"] 跳过 embedding。失败则不加(安全回退到 model 内现算)。
if (CONFIG.get("movedev_rep", False) and _MODEL_REF is not None
and 1 in moved and "rep" not in moved):
try:
with torch.inference_mode():
moved["rep"] = _MODEL_REF.rep_encoder(moved)
except Exception:
pass
return moved
elif isinstance(batch, (list, tuple)):
return [move_batch_to_device(x, device) for x in batch]
elif torch.is_tensor(batch):