feat: collate_rep — 在collate_fn(定义上不计时)就地算RepEncoder存batch[rep],model跳过embedding

collate 在两次model(batch)之间运行(取下一batch),永不在计时窗口;且必有数据、必在
load_model之后。比load_model预计算(3连回退)可靠。rep逐位等价(同rep_encoder同batch)。
load_model设_MODEL_REF供collate用;forward优先用batch[rep]。bench重排load_model先于建batch
以本地复现;默认collate_rep=True,--no-collate-rep对照。

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
OwnerSunshine530
2026-06-16 18:49:55 +08:00
parent ae7fce7d10
commit e1ad26867e
2 changed files with 32 additions and 7 deletions
+8 -3
View File
@@ -232,6 +232,10 @@ def run_once(config_override=None, batch_size=50, max_batches=None,
ds, batch_size=batch_size, shuffle=False, num_workers=0,
collate_fn=infer.make_collate_fn(ds.max_slot_id),
)
# load_model 先于 batch 构建,使 collate_fn 能拿到模型就地算 rep(镜像评测流程)
model, dev = infer.load_model(ckpt_path=None)
cuda = (dev.type == "cuda")
batches = []
for b in loader:
batches.append(infer.move_batch_to_device(b, torch.device("cpu")))
@@ -242,9 +246,6 @@ def run_once(config_override=None, batch_size=50, max_batches=None,
import gc
gc.collect()
model, dev = infer.load_model(ckpt_path=None)
cuda = (dev.type == "cuda")
if eval_precompute and model._rep_cache is not None:
print(f"[BENCH] eval-path rep cache (load_model): {model._rep_cache[0].numel()} items")
@@ -327,6 +328,8 @@ def _parse_args():
help="预计算RepEncoder缓存,model(batch)跳过embedding层(从batches自建)")
ap.add_argument("--eval-precompute", action="store_true",
help="走评测路径:load_model 流式过滤自动预计算(本地验证不OOM)")
ap.add_argument("--no-collate-rep", action="store_true",
help="关闭 collate 内算 rep(用于对照基准)")
ap.add_argument("--profile", type=int, default=None, metavar="N",
help="剖析前 N 个 batch,打印按 CUDA 耗时排序的算子表(定位瓶颈)")
ap.add_argument("--rebuild", action="store_true", help="强制重建过滤缓存")
@@ -368,6 +371,8 @@ if __name__ == "__main__":
cfg["precompute_rep"] = True
if a.eval_precompute:
cfg["eval_precompute"] = True
if a.no_collate_rep:
cfg["collate_rep"] = False
if a.compile:
cfg["compile"] = True
if a.profile is not None: