feat: movedev_rep — 在move_batch_to_device(不计时/主进程/有模型有数据)算rep,model跳过embedding
collate_rep 评测端回退(疑num_workers>0子进程无模型)。move_batch_to_device官方明确不计入、 在主进程model(batch)之前调用→有CUDA+_MODEL_REF+batch数据,避开数据访问/调用顺序/子进程三大坑。 rep逐位等价。bench --no-movedev-rep 对照。 Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
@@ -330,6 +330,8 @@ def _parse_args():
|
||||
help="走评测路径:load_model 流式过滤自动预计算(本地验证不OOM)")
|
||||
ap.add_argument("--no-collate-rep", action="store_true",
|
||||
help="关闭 collate 内算 rep(用于对照基准)")
|
||||
ap.add_argument("--no-movedev-rep", action="store_true",
|
||||
help="关闭 move_batch_to_device 内算 rep(用于对照基准)")
|
||||
ap.add_argument("--profile", type=int, default=None, metavar="N",
|
||||
help="剖析前 N 个 batch,打印按 CUDA 耗时排序的算子表(定位瓶颈)")
|
||||
ap.add_argument("--rebuild", action="store_true", help="强制重建过滤缓存")
|
||||
@@ -373,6 +375,8 @@ if __name__ == "__main__":
|
||||
cfg["eval_precompute"] = True
|
||||
if a.no_collate_rep:
|
||||
cfg["collate_rep"] = False
|
||||
if a.no_movedev_rep:
|
||||
cfg["movedev_rep"] = False
|
||||
if a.compile:
|
||||
cfg["compile"] = True
|
||||
if a.profile is not None:
|
||||
|
||||
+13
-2
@@ -57,7 +57,8 @@ CONFIG = {
|
||||
"compile": False, # 是否 torch.compile(实测慢5×,勿开)
|
||||
# 预计算三种实现在评测端均回退(load_model 拿不到数据)。改走 collate(定义上不计时、必有数据)。
|
||||
"precompute_rep": False, # True=load_model预计算(评测端三连回退,本地可跑见RISKS.md)
|
||||
"collate_rep": True, # True=在 collate_fn(不计时)就地算RepEncoder存batch["rep"],model(batch)跳过embedding
|
||||
"collate_rep": False, # True=在 collate_fn 算rep(评测num_workers>0时子进程无模型→回退)
|
||||
"movedev_rep": True, # True=在 move_batch_to_device(不计时/主进程/有模型有数据)算rep存batch["rep"]
|
||||
}
|
||||
|
||||
|
||||
@@ -347,7 +348,17 @@ def make_collate_fn(max_slot_id):
|
||||
|
||||
def move_batch_to_device(batch, device):
|
||||
if isinstance(batch, dict):
|
||||
return {k: move_batch_to_device(v, device) for k, v in batch.items()}
|
||||
moved = {k: move_batch_to_device(v, device) for k, v in batch.items()}
|
||||
# move_batch_to_device 不计时、跑在主进程(有CUDA+模型) → 就地算 RepEncoder,
|
||||
# model(batch) 用 batch["rep"] 跳过 embedding。失败则不加(安全回退到 model 内现算)。
|
||||
if (CONFIG.get("movedev_rep", False) and _MODEL_REF is not None
|
||||
and 1 in moved and "rep" not in moved):
|
||||
try:
|
||||
with torch.inference_mode():
|
||||
moved["rep"] = _MODEL_REF.rep_encoder(moved)
|
||||
except Exception:
|
||||
pass
|
||||
return moved
|
||||
elif isinstance(batch, (list, tuple)):
|
||||
return [move_batch_to_device(x, device) for x in batch]
|
||||
elif torch.is_tensor(batch):
|
||||
|
||||
Reference in New Issue
Block a user