From 0a971e67ac26e671b332a55d71ab9aac3dee25d3 Mon Sep 17 00:00:00 2001 From: OwnerSunshine530 Date: Sun, 14 Jun 2026 22:47:17 +0800 Subject: [PATCH] =?UTF-8?q?fix:=20=E7=BC=93=E5=AD=98=E6=94=B9=E7=94=A8?= =?UTF-8?q?=E6=96=87=E6=9C=ACCSV(=E9=80=90=E8=A1=8C=E5=86=99)=E6=9B=BF?= =?UTF-8?q?=E4=BB=A3pickle=EF=BC=8C=E9=81=BF=E5=85=8D=E5=AE=B9=E5=99=A8cgr?= =?UTF-8?q?oup=20OOM=E9=9D=99=E9=BB=98=E6=9D=80=E8=BF=9B=E7=A8=8B?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit pickle.dump 150万记录的memo瞬间撑爆容器内存上限被杀;改为流式逐行写 保留的历史行到 cache_filtered_history.csv,读回用 load_sample_files。 Co-Authored-By: Claude Opus 4.8 --- 代码/code/bench.py | 255 ++++++++++++++++++--------------------------- 1 file changed, 104 insertions(+), 151 deletions(-) diff --git a/代码/code/bench.py b/代码/code/bench.py index 558e501..8bbbb1d 100644 --- a/代码/code/bench.py +++ b/代码/code/bench.py @@ -3,15 +3,19 @@ 不进提交包。**以子进程方式运行**(AI Studio 内核禁止 import torch): %cd /home/aistudio/code - !python bench.py --smoke 50 # 冒烟:只跑前 50 batch - !python bench.py # 默认基线 - !python bench.py --fp32 # FP32 天花板(Task 3) - !python bench.py --rebuild # 强制重建过滤缓存 + !python bench.py --diag # 诊断:序列长度分布 + sign-id 超界比例 + !python bench.py --smoke 50 # 冒烟:只跑前 50 batch + !python bench.py # 默认基线 + !python bench.py --fp32 # FP32 天花板 + !python bench.py --rebuild # 强制重建过滤缓存 -关键设计——只保留“测试用户”的数据: -不同用户被因果 mask 完全隔离,非测试用户的前向输出不参与打分;过滤掉它们 -对测试样本的 AUC/PCOC 没有任何影响,却能把数据量从 924 万条降到一小部分, -避免 CTRTestSeqDataset 构造时 OOM。过滤后的数据缓存到磁盘,后续秒级复用。 +只保留“测试用户”的数据:不同用户被因果 mask 完全隔离,非测试用户的前向输出 +不参与打分;过滤掉它们对测试样本的 AUC/PCOC 没有任何影响,却能把数据量从 +924 万条降到一小部分。 + +缓存用**文本 CSV**而非 pickle:容器 cgroup 内存有限,pickle.dump 大对象的 memo +会瞬间撑爆内存被静默 OOM-kill;逐行写 CSV 内存几乎不涨,再用 load_sample_files +读回,稳。 """ import os import sys @@ -46,116 +50,114 @@ def _test_user_ids(test_csv): return users -def _load_filtered(history_dir, test_csv, test_users): - """流式读取所有文件,只保留 userid ∈ test_users 的记录(不持有完整字典,防 OOM)。 - - 解析逻辑与 infer.load_sample_files 完全一致,只是多了一道用户过滤。 +def _stream_build(ref, cache_csv_path=None): + """流式过滤:构建 item_dict/user_seq;若给 cache_csv_path,同时把保留的历史行 + 原样逐行写入(低内存文本缓存,test.csv 直接复用、不进缓存)。 """ - files = (sorted(history_dir.glob("*.csv")) if history_dir.exists() else []) + [test_csv] + test_csv = ref / "test.csv" + history = ref / "history" + test_users = _test_user_ids(test_csv) + files = (sorted(history.glob("*.csv")) if history.exists() else []) + [test_csv] print(f"[BENCH] 流式过滤加载 {len(files)} 个文件(仅保留 {len(test_users)} 个测试用户)...") + item_dict = {} user_logs = defaultdict(list) - for fp in files: - has_clk = infer._detect_has_clk(fp) - min_parts = 5 if has_clk else 4 - kept = 0 - with open(fp) as f: - for line in f: - line = line.strip() - if not line: - continue - parts = line.split(",") - if len(parts) < min_parts: - continue - userid = int(parts[1]) - if userid not in test_users: - continue - logid = int(parts[0]) - adid = int(parts[2]) - if has_clk: - clk = int(parts[3]) - timestamp = int(parts[4]) - fs = 5 - else: - clk = 0 - timestamp = int(parts[3]) - fs = 4 - signs, slots = [], [] - for pair in parts[fs:]: - if ":" in pair: - s, sl = pair.split(":", 1) - signs.append(int(s)) - slots.append(int(sl)) - item_dict[logid] = { - "logid": logid, "userid": userid, "adid": adid, - "clk": clk, "timestamp": timestamp, - "signs": np.array(signs, dtype=np.int64), - "slots": np.array(slots, dtype=np.int64), - } - user_logs[userid].append((timestamp, logid)) - kept += 1 - print(f" {fp.name}: has_clk={has_clk}, kept={kept}") + cf = open(cache_csv_path, "w") if cache_csv_path else None + try: + for fp in files: + has_clk = infer._detect_has_clk(fp) + min_parts = 5 if has_clk else 4 + is_test = (Path(fp).name == test_csv.name) + kept = 0 + with open(fp) as f: + for raw in f: + line = raw.strip() + if not line: + continue + parts = line.split(",") + if len(parts) < min_parts: + continue + userid = int(parts[1]) + if userid not in test_users: + continue + if cf is not None and not is_test: # 只缓存历史行 + cf.write(raw if raw.endswith("\n") else raw + "\n") + logid = int(parts[0]) + adid = int(parts[2]) + if has_clk: + clk = int(parts[3]) + timestamp = int(parts[4]) + fs = 5 + else: + clk = 0 + timestamp = int(parts[3]) + fs = 4 + signs, slots = [], [] + for pair in parts[fs:]: + if ":" in pair: + s, sl = pair.split(":", 1) + signs.append(int(s)) + slots.append(int(sl)) + item_dict[logid] = { + "logid": logid, "userid": userid, "adid": adid, + "clk": clk, "timestamp": timestamp, + "signs": np.array(signs, dtype=np.int64), + "slots": np.array(slots, dtype=np.int64), + } + user_logs[userid].append((timestamp, logid)) + kept += 1 + print(f" {Path(fp).name}: has_clk={has_clk}, kept={kept}") + finally: + if cf is not None: + cf.flush() + os.fsync(cf.fileno()) + cf.close() user_seq = {} for u, logs in user_logs.items(): logs.sort(key=lambda x: x[0]) user_seq[u] = [lid for _, lid in logs] print(f"[BENCH] 过滤后:{len(item_dict)} 条记录,{len(user_seq)} 个用户") + if cache_csv_path: + print(f"[BENCH] 已缓存历史行 -> {cache_csv_path}(下次快速读取)") return item_dict, user_seq -def _cache_path(cur): - return cur / "bench_filtered_cache.pkl" - - -def _build_filtered(ref): - test_csv = ref / "test.csv" - history = ref / "history" - test_users = _test_user_ids(test_csv) - return _load_filtered(history, test_csv, test_users) - - -def _load_cache(cache): - import pickle - with open(cache, "rb") as f: - d = pickle.load(f) - return d["item_dict"], d["user_seq"] - - -def _save_cache(cache, item_dict, user_seq): - """原子写 + fsync + 写后校验;任何异常都不留毒文件。 - - 用 pickle 而非 torch.save:AI Studio overlay 文件系统对 torch 的 zip/mmap - 读取会间歇性报 [Errno 38]。pickle.dump 大对象较慢但顺序写更稳。 - """ - import pickle - try: - with open(cache, "wb") as f: - pickle.dump({"item_dict": item_dict, "user_seq": user_seq}, f, - protocol=pickle.HIGHEST_PROTOCOL) - f.flush() - os.fsync(f.fileno()) - print(f"[BENCH] 已缓存 -> {cache}(下次秒级读取;读不出会自动重建)") - except Exception as e: - print(f"[BENCH][WARN] 缓存写入失败({e}),本次不缓存(不影响结果)") - try: - os.remove(cache) - except OSError: - pass - - def _get_data(cur, ref, rebuild=False): - """取过滤后的 (item_dict, user_seq),优先读磁盘缓存。""" - cache = _cache_path(cur) - if cache.exists() and not rebuild: - print(f"[BENCH] 读取过滤缓存:{cache}") + """取过滤后的 (item_dict, user_seq),优先读 CSV 缓存。""" + cache_csv = cur / "cache_filtered_history.csv" + test_csv = ref / "test.csv" + if cache_csv.exists() and not rebuild: + print(f"[BENCH] 读取过滤缓存(CSV):{cache_csv}") try: - return _load_cache(cache) + return infer.load_sample_files([str(cache_csv), str(test_csv)]) except Exception as e: print(f"[BENCH][WARN] 缓存读取失败({e}),重新构建") - item_dict, user_seq = _build_filtered(ref) - _save_cache(cache, item_dict, user_seq) - return item_dict, user_seq + return _stream_build(ref, cache_csv_path=str(cache_csv)) + + +def run_diag(rebuild=False): + """诊断:测试用户序列长度分布 + sign-id 是否超界(判断上下文与 modulo 的价值)。""" + cur = Path(__file__).parent + ref = cur / "dataset" + item_dict, user_seq = _get_data(cur, ref, rebuild=rebuild) + lens = np.array([len(v) for v in user_seq.values()]) if user_seq else np.array([0]) + print(f"[DIAG] 测试用户数={len(user_seq)} 总记录数={len(item_dict)}") + print(f"[DIAG] 每用户序列长度 min/median/mean/max = " + f"{int(lens.min())}/{int(np.median(lens))}/{lens.mean():.1f}/{int(lens.max())}") + print(f"[DIAG] 序列长度>1 的用户占比 = {(lens > 1).mean():.1%}") + VOCAB = 5_000_000 + mx, over, tot = 0, 0, 0 + for rec in item_dict.values(): + s = rec["signs"] + if s.size: + m = int(s.max()) + if m > mx: + mx = m + over += int((s >= VOCAB).sum()) + tot += int(s.size) + print(f"[DIAG] max_sign_id={mx} vocab={VOCAB} " + f"超界sign占比={over}/{tot}={(over / max(tot, 1)):.2%}") def run_once(config_override=None, batch_size=50, max_batches=None, @@ -174,7 +176,6 @@ def run_once(config_override=None, batch_size=50, max_batches=None, test_csv = ref / "test.csv" label_file = ref / "label_data.txt" - # ----- 取数据(过滤+缓存)----- item_dict, user_seq = _get_data(cur, ref, rebuild=rebuild) test_logids = infer.load_logids_from_file(test_csv) ds = infer.CTRTestSeqDataset( @@ -191,15 +192,12 @@ def run_once(config_override=None, batch_size=50, max_batches=None, if max_batches is not None and len(batches) >= max_batches: break - # 释放构造期内存,降低推理峰值 del item_dict, user_seq, ds, loader import gc gc.collect() - # ----- 加载模型 ----- model, dev = infer.load_model(ckpt_path=None) - # ----- 推理 + 同步计时 ----- logid2p = {} t_sum = 0.0 cuda = (dev.type == "cuda") @@ -218,7 +216,6 @@ def run_once(config_override=None, batch_size=50, max_batches=None, for lid, p in zip(b["logid"][pm].cpu().tolist(), probs[pm].cpu().tolist()): logid2p[lid] = p - # ----- 按 test.csv 顺序写 predict.txt 并打分 ----- order = [int(l.split(",")[0]) for l in open(test_csv) if l.strip()] missing = [lid for lid in order if lid not in logid2p] if missing: @@ -238,54 +235,10 @@ def run_once(config_override=None, batch_size=50, max_batches=None, return res -def run_diag(rebuild=False): - """诊断:测试用户序列长度分布 + sign-id 是否超界(判断上下文与 modulo 的价值)。 - - 先打印诊断,再写缓存——避免缓存写入卡住时看不到诊断结果。 - """ - cur = Path(__file__).parent - ref = cur / "dataset" - cache = _cache_path(cur) - loaded = False - item_dict = user_seq = None - if cache.exists() and not rebuild: - print(f"[BENCH] 读取过滤缓存:{cache}") - try: - item_dict, user_seq = _load_cache(cache) - loaded = True - except Exception as e: - print(f"[BENCH][WARN] 缓存读取失败({e}),重新构建") - if not loaded: - item_dict, user_seq = _build_filtered(ref) - - lens = np.array([len(v) for v in user_seq.values()]) if user_seq else np.array([0]) - print(f"[DIAG] 测试用户数={len(user_seq)} 总记录数={len(item_dict)}") - print(f"[DIAG] 每用户序列长度 min/median/mean/max = " - f"{int(lens.min())}/{int(np.median(lens))}/{lens.mean():.1f}/{int(lens.max())}") - print(f"[DIAG] 序列长度>1 的用户占比 = {(lens > 1).mean():.1%} " - f"(占比低=大量测试样本没有历史上下文 → 生成式模型发挥不出来)") - VOCAB = 5_000_000 - mx, over, tot = 0, 0, 0 - for rec in item_dict.values(): - s = rec["signs"] - if s.size: - m = int(s.max()) - if m > mx: - mx = m - over += int((s >= VOCAB).sum()) - tot += int(s.size) - print(f"[DIAG] max_sign_id={mx} vocab={VOCAB} " - f"超界sign占比={over}/{tot}={(over / max(tot, 1)):.2%} " - f"(占比高=clamp 在污染 embedding → modulo 可能找回 AUC)") - - if not loaded: - _save_cache(_cache_path(cur), item_dict, user_seq) - - def _parse_args(): import argparse ap = argparse.ArgumentParser(description="CTI 推理测量闭环(子进程跑:!python bench.py ...)") - ap.add_argument("--diag", action="store_true", help="只跑诊断(序列长度分布 + sign-id 超界比例),不推理") + ap.add_argument("--diag", action="store_true", help="只跑诊断,不推理") ap.add_argument("--smoke", type=int, default=None, help="只跑前 N 个 batch(冒烟)") ap.add_argument("--bs", type=int, default=50, help="batch_size(本地参考)") ap.add_argument("--fp32", action="store_true", help="FP32 天花板 = 关 fp16 + 关 expert 合并")