fix: 修OOM — load_model预计算改流式只加载测试用户+直接逐item算(不建Dataset)+算完释放

评测异常根因:load_model全量load_sample_files与评测自身数据双倍内存OOM。
改:_load_test_user_items流式过滤(仅测试用户~1.5M)、build_rep_cache直接从item_dict
逐item算(省掉user_items~8GB拷贝)、算完del+gc。bench加--eval-precompute本地真跑
load_model这条路验证不OOM。

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
OwnerSunshine530
2026-06-16 12:19:30 +08:00
parent db5d0b222a
commit 9042655fed
2 changed files with 91 additions and 30 deletions
+12 -3
View File
@@ -209,11 +209,13 @@ def run_once(config_override=None, batch_size=50, max_batches=None,
if max_feasign_per_slot is None:
max_feasign_per_slot = {1: 2}
# 本地用已加载的过滤数据自建 rep 缓存,禁止 load_model 自动加载全量数据集
# precompute_rep: 从已加载的过滤 batches 自建缓存(测 gather);
# eval_precompute: 走真正的评测路径(load_model 流式过滤自动预计算)
want_precompute = bool(config_override.pop("precompute_rep", False))
eval_precompute = bool(config_override.pop("eval_precompute", False))
infer.CONFIG.update(config_override)
infer.CONFIG["sync_timing"] = True
infer.CONFIG["precompute_rep"] = False
infer.CONFIG["precompute_rep"] = eval_precompute # True 时让 load_model 自动预计算
cur = Path(__file__).parent
ref = cur / "dataset"
@@ -243,8 +245,11 @@ def run_once(config_override=None, batch_size=50, max_batches=None,
model, dev = infer.load_model(ckpt_path=None)
cuda = (dev.type == "cuda")
if eval_precompute and model._rep_cache is not None:
print(f"[BENCH] eval-path rep cache (load_model): {model._rep_cache[0].numel()} items")
# 本地从已建好的 batches 构造 rep 缓存(复用 batches、省内存;不计入计时)
if want_precompute:
if want_precompute and not eval_precompute:
lc, ec = [], []
with torch.inference_mode():
for b in batches:
@@ -320,6 +325,8 @@ def _parse_args():
ap.add_argument("--sparse-pool", action="store_true", help="稀疏矩阵乘做池化(段内高重复时省)")
ap.add_argument("--precompute-rep", action="store_true",
help="预计算RepEncoder缓存,model(batch)跳过embedding层(从batches自建)")
ap.add_argument("--eval-precompute", action="store_true",
help="走评测路径:load_model 流式过滤自动预计算(本地验证不OOM)")
ap.add_argument("--profile", type=int, default=None, metavar="N",
help="剖析前 N 个 batch,打印按 CUDA 耗时排序的算子表(定位瓶颈)")
ap.add_argument("--rebuild", action="store_true", help="强制重建过滤缓存")
@@ -359,6 +366,8 @@ if __name__ == "__main__":
cfg["sparse_pool"] = True
if a.precompute_rep:
cfg["precompute_rep"] = True
if a.eval_precompute:
cfg["eval_precompute"] = True
if a.compile:
cfg["compile"] = True
if a.profile is not None: