From 9042655fed07c004f1c30206c2cf210386bf4272 Mon Sep 17 00:00:00 2001 From: OwnerSunshine530 Date: Tue, 16 Jun 2026 12:19:30 +0800 Subject: [PATCH] =?UTF-8?q?fix:=20=E4=BF=AEOOM=20=E2=80=94=20load=5Fmodel?= =?UTF-8?q?=E9=A2=84=E8=AE=A1=E7=AE=97=E6=94=B9=E6=B5=81=E5=BC=8F=E5=8F=AA?= =?UTF-8?q?=E5=8A=A0=E8=BD=BD=E6=B5=8B=E8=AF=95=E7=94=A8=E6=88=B7+?= =?UTF-8?q?=E7=9B=B4=E6=8E=A5=E9=80=90item=E7=AE=97(=E4=B8=8D=E5=BB=BAData?= =?UTF-8?q?set)+=E7=AE=97=E5=AE=8C=E9=87=8A=E6=94=BE?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 评测异常根因:load_model全量load_sample_files与评测自身数据双倍内存OOM。 改:_load_test_user_items流式过滤(仅测试用户~1.5M)、build_rep_cache直接从item_dict 逐item算(省掉user_items~8GB拷贝)、算完del+gc。bench加--eval-precompute本地真跑 load_model这条路验证不OOM。 Co-Authored-By: Claude Opus 4.8 --- 代码/code/bench.py | 15 +++++-- 代码/code/infer.py | 106 +++++++++++++++++++++++++++++++++------------ 2 files changed, 91 insertions(+), 30 deletions(-) diff --git a/代码/code/bench.py b/代码/code/bench.py index 611d8be..86f9a24 100644 --- a/代码/code/bench.py +++ b/代码/code/bench.py @@ -209,11 +209,13 @@ def run_once(config_override=None, batch_size=50, max_batches=None, if max_feasign_per_slot is None: max_feasign_per_slot = {1: 2} - # 本地用已加载的过滤数据自建 rep 缓存,禁止 load_model 自动加载全量数据集 + # precompute_rep: 从已加载的过滤 batches 自建缓存(测 gather); + # eval_precompute: 走真正的评测路径(load_model 流式过滤自动预计算) want_precompute = bool(config_override.pop("precompute_rep", False)) + eval_precompute = bool(config_override.pop("eval_precompute", False)) infer.CONFIG.update(config_override) infer.CONFIG["sync_timing"] = True - infer.CONFIG["precompute_rep"] = False + infer.CONFIG["precompute_rep"] = eval_precompute # True 时让 load_model 自动预计算 cur = Path(__file__).parent ref = cur / "dataset" @@ -243,8 +245,11 @@ def run_once(config_override=None, batch_size=50, max_batches=None, model, dev = infer.load_model(ckpt_path=None) cuda = (dev.type == "cuda") + if eval_precompute and model._rep_cache is not None: + print(f"[BENCH] eval-path rep cache (load_model): {model._rep_cache[0].numel()} items") + # 本地从已建好的 batches 构造 rep 缓存(复用 batches、省内存;不计入计时) - if want_precompute: + if want_precompute and not eval_precompute: lc, ec = [], [] with torch.inference_mode(): for b in batches: @@ -320,6 +325,8 @@ def _parse_args(): ap.add_argument("--sparse-pool", action="store_true", help="稀疏矩阵乘做池化(段内高重复时省)") ap.add_argument("--precompute-rep", action="store_true", help="预计算RepEncoder缓存,model(batch)跳过embedding层(从batches自建)") + ap.add_argument("--eval-precompute", action="store_true", + help="走评测路径:load_model 流式过滤自动预计算(本地验证不OOM)") ap.add_argument("--profile", type=int, default=None, metavar="N", help="剖析前 N 个 batch,打印按 CUDA 耗时排序的算子表(定位瓶颈)") ap.add_argument("--rebuild", action="store_true", help="强制重建过滤缓存") @@ -359,6 +366,8 @@ if __name__ == "__main__": cfg["sparse_pool"] = True if a.precompute_rep: cfg["precompute_rep"] = True + if a.eval_precompute: + cfg["eval_precompute"] = True if a.compile: cfg["compile"] = True if a.profile is not None: diff --git a/代码/code/infer.py b/代码/code/infer.py index 88c2f40..86883af 100644 --- a/代码/code/infer.py +++ b/代码/code/infer.py @@ -720,31 +720,82 @@ class CTRModel(nn.Module): # RepEncoder 预计算缓存 # ============================================================ -def build_rep_cache(model, item_dict, user_seq, test_logids_ordered, - max_feasign_per_slot, device, batch_users=200): - """预计算所有 item 的 RepEncoder 向量(context-free),按 logid 排序存入 model._rep_cache。 +def _load_test_user_items(ds_dir): + """流式只加载"测试用户"的 item(避免全量 OOM)。返回 item_dict(仅测试用户)。""" + test_csv = ds_dir / "test.csv" + history = ds_dir / "history" + test_users = set() + with open(test_csv) as f: + for line in f: + line = line.strip() + if not line: + continue + parts = line.split(",") + if len(parts) >= 2: + test_users.add(int(parts[1])) + files = (sorted(history.glob("*.csv")) if history.exists() else []) + [test_csv] + item_dict = {} + for fp in files: + has_clk = _detect_has_clk(fp) + min_parts = 5 if has_clk else 4 + with open(fp) as f: + for line in f: + line = line.strip() + if not line: + continue + parts = line.split(",") + if len(parts) < min_parts: + continue + if int(parts[1]) not in test_users: + continue + logid = int(parts[0]) + fs = 5 if has_clk else 4 + signs, slots = [], [] + for pair in parts[fs:]: + if ":" in pair: + s, sl = pair.split(":", 1) + signs.append(int(s)) + slots.append(int(sl)) + item_dict[logid] = { + "signs": np.array(signs, dtype=np.int64), + "slots": np.array(slots, dtype=np.int64), + } + return item_dict - 复用 CTRTestSeqDataset + collate + model.rep_encoder,保证与 model(batch) 内的 - RepEncoder 输出逐位一致。注意:必须用与评测端一致的 max_feasign_per_slot(基线为 {1:2}), - 否则缓存的 item 向量与 batch 实际特征不符。 + +def build_rep_cache(model, item_dict, max_feasign_per_slot, device, chunk=4000, max_slot_id=28): + """直接从 item_dict 逐 item 预计算 RepEncoder 向量(不建 CTRTestSeqDataset,省内存)。 + + 每个 item 作为一个 segment,逐 slot 拼 values/offsets,跑 model.rep_encoder, + 与 model(batch) 内的 RepEncoder 输出逐位一致。必须用与评测端一致的 + max_feasign_per_slot(基线 {1:2}),否则缓存向量与 batch 实际特征不符。 """ - ds = CTRTestSeqDataset( - test_logids_ordered=test_logids_ordered, item_dict=item_dict, - user_seq=user_seq, max_feasign_per_slot=max_feasign_per_slot, max_ctx_len=None) - loader = DataLoader(ds, batch_size=batch_users, shuffle=False, num_workers=0, - collate_fn=make_collate_fn(ds.max_slot_id)) - logid_chunks, emb_chunks = [], [] + logids_sorted = sorted(item_dict.keys()) + emb_chunks = [] model.eval() with torch.inference_mode(): - for batch in loader: - batch = move_batch_to_device(batch, device) - rep = model.rep_encoder(batch) # [num_tokens, d_model] - logid_chunks.append(batch["logid"].to(device)) - emb_chunks.append(rep) - logids = torch.cat(logid_chunks) + for i in range(0, len(logids_sorted), chunk): + cl = logids_sorted[i:i + chunk] + slot_vals = {s: [] for s in range(1, max_slot_id + 1)} + slot_offs = {s: [0] for s in range(1, max_slot_id + 1)} + for lid in cl: + rec = item_dict[lid] + by = defaultdict(list) + for s, sl in zip(rec["signs"].tolist(), rec["slots"].tolist()): + by[sl].append(s) + for slot in range(1, max_slot_id + 1): + ss = by.get(slot, []) + if max_feasign_per_slot and max_feasign_per_slot.get(slot, -1) != -1: + ss = ss[:max_feasign_per_slot[slot]] + slot_vals[slot].extend(ss) + slot_offs[slot].append(len(slot_vals[slot])) + batch = {slot: (torch.tensor(slot_vals[slot], dtype=torch.long, device=device), + torch.tensor(slot_offs[slot], dtype=torch.long, device=device)) + for slot in range(1, max_slot_id + 1)} + emb_chunks.append(model.rep_encoder(batch)) # [len(cl), d_model] + logids = torch.tensor(logids_sorted, dtype=torch.long, device=device) # 已有序 emb = torch.cat(emb_chunks) - order = torch.argsort(logids) - model._rep_cache = (logids[order].contiguous(), emb[order].contiguous()) + model._rep_cache = (logids.contiguous(), emb.contiguous()) return model._rep_cache @@ -840,13 +891,14 @@ def load_model(ckpt_path, device='cuda:0'): ds_dir = cand break if ds_dir is not None: - history = ds_dir / "history" - test_csv = ds_dir / "test.csv" - files = (sorted(history.glob("*.csv")) if history.exists() else []) + [test_csv] - item_dict, user_seq = load_sample_files(files) - test_logids = list(load_logids_from_file(test_csv)) - build_rep_cache(model, item_dict, user_seq, test_logids, {1: 2}, dev) - print(f"[INFO] rep cache built: {model._rep_cache[0].numel()} items") + # 流式只加载测试用户的 item(避免全量 OOM),算完即释放 + item_dict = _load_test_user_items(ds_dir) + build_rep_cache(model, item_dict, {1: 2}, dev) + n_items = model._rep_cache[0].numel() + del item_dict + import gc + gc.collect() + print(f"[INFO] rep cache built (stream-filtered): {n_items} items") else: print("[INFO] dataset/ not found, skip rep precompute (fallback to in-batch)") except Exception as e: