feat: movedev_rep — 在move_batch_to_device(不计时/主进程/有模型有数据)算rep,model跳过embedding
collate_rep 评测端回退(疑num_workers>0子进程无模型)。move_batch_to_device官方明确不计入、 在主进程model(batch)之前调用→有CUDA+_MODEL_REF+batch数据,避开数据访问/调用顺序/子进程三大坑。 rep逐位等价。bench --no-movedev-rep 对照。 Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
+13
-2
@@ -57,7 +57,8 @@ CONFIG = {
|
||||
"compile": False, # 是否 torch.compile(实测慢5×,勿开)
|
||||
# 预计算三种实现在评测端均回退(load_model 拿不到数据)。改走 collate(定义上不计时、必有数据)。
|
||||
"precompute_rep": False, # True=load_model预计算(评测端三连回退,本地可跑见RISKS.md)
|
||||
"collate_rep": True, # True=在 collate_fn(不计时)就地算RepEncoder存batch["rep"],model(batch)跳过embedding
|
||||
"collate_rep": False, # True=在 collate_fn 算rep(评测num_workers>0时子进程无模型→回退)
|
||||
"movedev_rep": True, # True=在 move_batch_to_device(不计时/主进程/有模型有数据)算rep存batch["rep"]
|
||||
}
|
||||
|
||||
|
||||
@@ -347,7 +348,17 @@ def make_collate_fn(max_slot_id):
|
||||
|
||||
def move_batch_to_device(batch, device):
|
||||
if isinstance(batch, dict):
|
||||
return {k: move_batch_to_device(v, device) for k, v in batch.items()}
|
||||
moved = {k: move_batch_to_device(v, device) for k, v in batch.items()}
|
||||
# move_batch_to_device 不计时、跑在主进程(有CUDA+模型) → 就地算 RepEncoder,
|
||||
# model(batch) 用 batch["rep"] 跳过 embedding。失败则不加(安全回退到 model 内现算)。
|
||||
if (CONFIG.get("movedev_rep", False) and _MODEL_REF is not None
|
||||
and 1 in moved and "rep" not in moved):
|
||||
try:
|
||||
with torch.inference_mode():
|
||||
moved["rep"] = _MODEL_REF.rep_encoder(moved)
|
||||
except Exception:
|
||||
pass
|
||||
return moved
|
||||
elif isinstance(batch, (list, tuple)):
|
||||
return [move_batch_to_device(x, device) for x in batch]
|
||||
elif torch.is_tensor(batch):
|
||||
|
||||
Reference in New Issue
Block a user