feat/auc-recovery-plan #1
+104
-151
@@ -3,15 +3,19 @@
|
|||||||
不进提交包。**以子进程方式运行**(AI Studio 内核禁止 import torch):
|
不进提交包。**以子进程方式运行**(AI Studio 内核禁止 import torch):
|
||||||
|
|
||||||
%cd /home/aistudio/code
|
%cd /home/aistudio/code
|
||||||
!python bench.py --smoke 50 # 冒烟:只跑前 50 batch
|
!python bench.py --diag # 诊断:序列长度分布 + sign-id 超界比例
|
||||||
!python bench.py # 默认基线
|
!python bench.py --smoke 50 # 冒烟:只跑前 50 batch
|
||||||
!python bench.py --fp32 # FP32 天花板(Task 3)
|
!python bench.py # 默认基线
|
||||||
!python bench.py --rebuild # 强制重建过滤缓存
|
!python bench.py --fp32 # FP32 天花板
|
||||||
|
!python bench.py --rebuild # 强制重建过滤缓存
|
||||||
|
|
||||||
关键设计——只保留“测试用户”的数据:
|
只保留“测试用户”的数据:不同用户被因果 mask 完全隔离,非测试用户的前向输出
|
||||||
不同用户被因果 mask 完全隔离,非测试用户的前向输出不参与打分;过滤掉它们
|
不参与打分;过滤掉它们对测试样本的 AUC/PCOC 没有任何影响,却能把数据量从
|
||||||
对测试样本的 AUC/PCOC 没有任何影响,却能把数据量从 924 万条降到一小部分,
|
924 万条降到一小部分。
|
||||||
避免 CTRTestSeqDataset 构造时 OOM。过滤后的数据缓存到磁盘,后续秒级复用。
|
|
||||||
|
缓存用**文本 CSV**而非 pickle:容器 cgroup 内存有限,pickle.dump 大对象的 memo
|
||||||
|
会瞬间撑爆内存被静默 OOM-kill;逐行写 CSV 内存几乎不涨,再用 load_sample_files
|
||||||
|
读回,稳。
|
||||||
"""
|
"""
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
@@ -46,116 +50,114 @@ def _test_user_ids(test_csv):
|
|||||||
return users
|
return users
|
||||||
|
|
||||||
|
|
||||||
def _load_filtered(history_dir, test_csv, test_users):
|
def _stream_build(ref, cache_csv_path=None):
|
||||||
"""流式读取所有文件,只保留 userid ∈ test_users 的记录(不持有完整字典,防 OOM)。
|
"""流式过滤:构建 item_dict/user_seq;若给 cache_csv_path,同时把保留的历史行
|
||||||
|
原样逐行写入(低内存文本缓存,test.csv 直接复用、不进缓存)。
|
||||||
解析逻辑与 infer.load_sample_files 完全一致,只是多了一道用户过滤。
|
|
||||||
"""
|
"""
|
||||||
files = (sorted(history_dir.glob("*.csv")) if history_dir.exists() else []) + [test_csv]
|
test_csv = ref / "test.csv"
|
||||||
|
history = ref / "history"
|
||||||
|
test_users = _test_user_ids(test_csv)
|
||||||
|
files = (sorted(history.glob("*.csv")) if history.exists() else []) + [test_csv]
|
||||||
print(f"[BENCH] 流式过滤加载 {len(files)} 个文件(仅保留 {len(test_users)} 个测试用户)...")
|
print(f"[BENCH] 流式过滤加载 {len(files)} 个文件(仅保留 {len(test_users)} 个测试用户)...")
|
||||||
|
|
||||||
item_dict = {}
|
item_dict = {}
|
||||||
user_logs = defaultdict(list)
|
user_logs = defaultdict(list)
|
||||||
for fp in files:
|
cf = open(cache_csv_path, "w") if cache_csv_path else None
|
||||||
has_clk = infer._detect_has_clk(fp)
|
try:
|
||||||
min_parts = 5 if has_clk else 4
|
for fp in files:
|
||||||
kept = 0
|
has_clk = infer._detect_has_clk(fp)
|
||||||
with open(fp) as f:
|
min_parts = 5 if has_clk else 4
|
||||||
for line in f:
|
is_test = (Path(fp).name == test_csv.name)
|
||||||
line = line.strip()
|
kept = 0
|
||||||
if not line:
|
with open(fp) as f:
|
||||||
continue
|
for raw in f:
|
||||||
parts = line.split(",")
|
line = raw.strip()
|
||||||
if len(parts) < min_parts:
|
if not line:
|
||||||
continue
|
continue
|
||||||
userid = int(parts[1])
|
parts = line.split(",")
|
||||||
if userid not in test_users:
|
if len(parts) < min_parts:
|
||||||
continue
|
continue
|
||||||
logid = int(parts[0])
|
userid = int(parts[1])
|
||||||
adid = int(parts[2])
|
if userid not in test_users:
|
||||||
if has_clk:
|
continue
|
||||||
clk = int(parts[3])
|
if cf is not None and not is_test: # 只缓存历史行
|
||||||
timestamp = int(parts[4])
|
cf.write(raw if raw.endswith("\n") else raw + "\n")
|
||||||
fs = 5
|
logid = int(parts[0])
|
||||||
else:
|
adid = int(parts[2])
|
||||||
clk = 0
|
if has_clk:
|
||||||
timestamp = int(parts[3])
|
clk = int(parts[3])
|
||||||
fs = 4
|
timestamp = int(parts[4])
|
||||||
signs, slots = [], []
|
fs = 5
|
||||||
for pair in parts[fs:]:
|
else:
|
||||||
if ":" in pair:
|
clk = 0
|
||||||
s, sl = pair.split(":", 1)
|
timestamp = int(parts[3])
|
||||||
signs.append(int(s))
|
fs = 4
|
||||||
slots.append(int(sl))
|
signs, slots = [], []
|
||||||
item_dict[logid] = {
|
for pair in parts[fs:]:
|
||||||
"logid": logid, "userid": userid, "adid": adid,
|
if ":" in pair:
|
||||||
"clk": clk, "timestamp": timestamp,
|
s, sl = pair.split(":", 1)
|
||||||
"signs": np.array(signs, dtype=np.int64),
|
signs.append(int(s))
|
||||||
"slots": np.array(slots, dtype=np.int64),
|
slots.append(int(sl))
|
||||||
}
|
item_dict[logid] = {
|
||||||
user_logs[userid].append((timestamp, logid))
|
"logid": logid, "userid": userid, "adid": adid,
|
||||||
kept += 1
|
"clk": clk, "timestamp": timestamp,
|
||||||
print(f" {fp.name}: has_clk={has_clk}, kept={kept}")
|
"signs": np.array(signs, dtype=np.int64),
|
||||||
|
"slots": np.array(slots, dtype=np.int64),
|
||||||
|
}
|
||||||
|
user_logs[userid].append((timestamp, logid))
|
||||||
|
kept += 1
|
||||||
|
print(f" {Path(fp).name}: has_clk={has_clk}, kept={kept}")
|
||||||
|
finally:
|
||||||
|
if cf is not None:
|
||||||
|
cf.flush()
|
||||||
|
os.fsync(cf.fileno())
|
||||||
|
cf.close()
|
||||||
|
|
||||||
user_seq = {}
|
user_seq = {}
|
||||||
for u, logs in user_logs.items():
|
for u, logs in user_logs.items():
|
||||||
logs.sort(key=lambda x: x[0])
|
logs.sort(key=lambda x: x[0])
|
||||||
user_seq[u] = [lid for _, lid in logs]
|
user_seq[u] = [lid for _, lid in logs]
|
||||||
print(f"[BENCH] 过滤后:{len(item_dict)} 条记录,{len(user_seq)} 个用户")
|
print(f"[BENCH] 过滤后:{len(item_dict)} 条记录,{len(user_seq)} 个用户")
|
||||||
|
if cache_csv_path:
|
||||||
|
print(f"[BENCH] 已缓存历史行 -> {cache_csv_path}(下次快速读取)")
|
||||||
return item_dict, user_seq
|
return item_dict, user_seq
|
||||||
|
|
||||||
|
|
||||||
def _cache_path(cur):
|
|
||||||
return cur / "bench_filtered_cache.pkl"
|
|
||||||
|
|
||||||
|
|
||||||
def _build_filtered(ref):
|
|
||||||
test_csv = ref / "test.csv"
|
|
||||||
history = ref / "history"
|
|
||||||
test_users = _test_user_ids(test_csv)
|
|
||||||
return _load_filtered(history, test_csv, test_users)
|
|
||||||
|
|
||||||
|
|
||||||
def _load_cache(cache):
|
|
||||||
import pickle
|
|
||||||
with open(cache, "rb") as f:
|
|
||||||
d = pickle.load(f)
|
|
||||||
return d["item_dict"], d["user_seq"]
|
|
||||||
|
|
||||||
|
|
||||||
def _save_cache(cache, item_dict, user_seq):
|
|
||||||
"""原子写 + fsync + 写后校验;任何异常都不留毒文件。
|
|
||||||
|
|
||||||
用 pickle 而非 torch.save:AI Studio overlay 文件系统对 torch 的 zip/mmap
|
|
||||||
读取会间歇性报 [Errno 38]。pickle.dump 大对象较慢但顺序写更稳。
|
|
||||||
"""
|
|
||||||
import pickle
|
|
||||||
try:
|
|
||||||
with open(cache, "wb") as f:
|
|
||||||
pickle.dump({"item_dict": item_dict, "user_seq": user_seq}, f,
|
|
||||||
protocol=pickle.HIGHEST_PROTOCOL)
|
|
||||||
f.flush()
|
|
||||||
os.fsync(f.fileno())
|
|
||||||
print(f"[BENCH] 已缓存 -> {cache}(下次秒级读取;读不出会自动重建)")
|
|
||||||
except Exception as e:
|
|
||||||
print(f"[BENCH][WARN] 缓存写入失败({e}),本次不缓存(不影响结果)")
|
|
||||||
try:
|
|
||||||
os.remove(cache)
|
|
||||||
except OSError:
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
def _get_data(cur, ref, rebuild=False):
|
def _get_data(cur, ref, rebuild=False):
|
||||||
"""取过滤后的 (item_dict, user_seq),优先读磁盘缓存。"""
|
"""取过滤后的 (item_dict, user_seq),优先读 CSV 缓存。"""
|
||||||
cache = _cache_path(cur)
|
cache_csv = cur / "cache_filtered_history.csv"
|
||||||
if cache.exists() and not rebuild:
|
test_csv = ref / "test.csv"
|
||||||
print(f"[BENCH] 读取过滤缓存:{cache}")
|
if cache_csv.exists() and not rebuild:
|
||||||
|
print(f"[BENCH] 读取过滤缓存(CSV):{cache_csv}")
|
||||||
try:
|
try:
|
||||||
return _load_cache(cache)
|
return infer.load_sample_files([str(cache_csv), str(test_csv)])
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"[BENCH][WARN] 缓存读取失败({e}),重新构建")
|
print(f"[BENCH][WARN] 缓存读取失败({e}),重新构建")
|
||||||
item_dict, user_seq = _build_filtered(ref)
|
return _stream_build(ref, cache_csv_path=str(cache_csv))
|
||||||
_save_cache(cache, item_dict, user_seq)
|
|
||||||
return item_dict, user_seq
|
|
||||||
|
def run_diag(rebuild=False):
|
||||||
|
"""诊断:测试用户序列长度分布 + sign-id 是否超界(判断上下文与 modulo 的价值)。"""
|
||||||
|
cur = Path(__file__).parent
|
||||||
|
ref = cur / "dataset"
|
||||||
|
item_dict, user_seq = _get_data(cur, ref, rebuild=rebuild)
|
||||||
|
lens = np.array([len(v) for v in user_seq.values()]) if user_seq else np.array([0])
|
||||||
|
print(f"[DIAG] 测试用户数={len(user_seq)} 总记录数={len(item_dict)}")
|
||||||
|
print(f"[DIAG] 每用户序列长度 min/median/mean/max = "
|
||||||
|
f"{int(lens.min())}/{int(np.median(lens))}/{lens.mean():.1f}/{int(lens.max())}")
|
||||||
|
print(f"[DIAG] 序列长度>1 的用户占比 = {(lens > 1).mean():.1%}")
|
||||||
|
VOCAB = 5_000_000
|
||||||
|
mx, over, tot = 0, 0, 0
|
||||||
|
for rec in item_dict.values():
|
||||||
|
s = rec["signs"]
|
||||||
|
if s.size:
|
||||||
|
m = int(s.max())
|
||||||
|
if m > mx:
|
||||||
|
mx = m
|
||||||
|
over += int((s >= VOCAB).sum())
|
||||||
|
tot += int(s.size)
|
||||||
|
print(f"[DIAG] max_sign_id={mx} vocab={VOCAB} "
|
||||||
|
f"超界sign占比={over}/{tot}={(over / max(tot, 1)):.2%}")
|
||||||
|
|
||||||
|
|
||||||
def run_once(config_override=None, batch_size=50, max_batches=None,
|
def run_once(config_override=None, batch_size=50, max_batches=None,
|
||||||
@@ -174,7 +176,6 @@ def run_once(config_override=None, batch_size=50, max_batches=None,
|
|||||||
test_csv = ref / "test.csv"
|
test_csv = ref / "test.csv"
|
||||||
label_file = ref / "label_data.txt"
|
label_file = ref / "label_data.txt"
|
||||||
|
|
||||||
# ----- 取数据(过滤+缓存)-----
|
|
||||||
item_dict, user_seq = _get_data(cur, ref, rebuild=rebuild)
|
item_dict, user_seq = _get_data(cur, ref, rebuild=rebuild)
|
||||||
test_logids = infer.load_logids_from_file(test_csv)
|
test_logids = infer.load_logids_from_file(test_csv)
|
||||||
ds = infer.CTRTestSeqDataset(
|
ds = infer.CTRTestSeqDataset(
|
||||||
@@ -191,15 +192,12 @@ def run_once(config_override=None, batch_size=50, max_batches=None,
|
|||||||
if max_batches is not None and len(batches) >= max_batches:
|
if max_batches is not None and len(batches) >= max_batches:
|
||||||
break
|
break
|
||||||
|
|
||||||
# 释放构造期内存,降低推理峰值
|
|
||||||
del item_dict, user_seq, ds, loader
|
del item_dict, user_seq, ds, loader
|
||||||
import gc
|
import gc
|
||||||
gc.collect()
|
gc.collect()
|
||||||
|
|
||||||
# ----- 加载模型 -----
|
|
||||||
model, dev = infer.load_model(ckpt_path=None)
|
model, dev = infer.load_model(ckpt_path=None)
|
||||||
|
|
||||||
# ----- 推理 + 同步计时 -----
|
|
||||||
logid2p = {}
|
logid2p = {}
|
||||||
t_sum = 0.0
|
t_sum = 0.0
|
||||||
cuda = (dev.type == "cuda")
|
cuda = (dev.type == "cuda")
|
||||||
@@ -218,7 +216,6 @@ def run_once(config_override=None, batch_size=50, max_batches=None,
|
|||||||
for lid, p in zip(b["logid"][pm].cpu().tolist(), probs[pm].cpu().tolist()):
|
for lid, p in zip(b["logid"][pm].cpu().tolist(), probs[pm].cpu().tolist()):
|
||||||
logid2p[lid] = p
|
logid2p[lid] = p
|
||||||
|
|
||||||
# ----- 按 test.csv 顺序写 predict.txt 并打分 -----
|
|
||||||
order = [int(l.split(",")[0]) for l in open(test_csv) if l.strip()]
|
order = [int(l.split(",")[0]) for l in open(test_csv) if l.strip()]
|
||||||
missing = [lid for lid in order if lid not in logid2p]
|
missing = [lid for lid in order if lid not in logid2p]
|
||||||
if missing:
|
if missing:
|
||||||
@@ -238,54 +235,10 @@ def run_once(config_override=None, batch_size=50, max_batches=None,
|
|||||||
return res
|
return res
|
||||||
|
|
||||||
|
|
||||||
def run_diag(rebuild=False):
|
|
||||||
"""诊断:测试用户序列长度分布 + sign-id 是否超界(判断上下文与 modulo 的价值)。
|
|
||||||
|
|
||||||
先打印诊断,再写缓存——避免缓存写入卡住时看不到诊断结果。
|
|
||||||
"""
|
|
||||||
cur = Path(__file__).parent
|
|
||||||
ref = cur / "dataset"
|
|
||||||
cache = _cache_path(cur)
|
|
||||||
loaded = False
|
|
||||||
item_dict = user_seq = None
|
|
||||||
if cache.exists() and not rebuild:
|
|
||||||
print(f"[BENCH] 读取过滤缓存:{cache}")
|
|
||||||
try:
|
|
||||||
item_dict, user_seq = _load_cache(cache)
|
|
||||||
loaded = True
|
|
||||||
except Exception as e:
|
|
||||||
print(f"[BENCH][WARN] 缓存读取失败({e}),重新构建")
|
|
||||||
if not loaded:
|
|
||||||
item_dict, user_seq = _build_filtered(ref)
|
|
||||||
|
|
||||||
lens = np.array([len(v) for v in user_seq.values()]) if user_seq else np.array([0])
|
|
||||||
print(f"[DIAG] 测试用户数={len(user_seq)} 总记录数={len(item_dict)}")
|
|
||||||
print(f"[DIAG] 每用户序列长度 min/median/mean/max = "
|
|
||||||
f"{int(lens.min())}/{int(np.median(lens))}/{lens.mean():.1f}/{int(lens.max())}")
|
|
||||||
print(f"[DIAG] 序列长度>1 的用户占比 = {(lens > 1).mean():.1%} "
|
|
||||||
f"(占比低=大量测试样本没有历史上下文 → 生成式模型发挥不出来)")
|
|
||||||
VOCAB = 5_000_000
|
|
||||||
mx, over, tot = 0, 0, 0
|
|
||||||
for rec in item_dict.values():
|
|
||||||
s = rec["signs"]
|
|
||||||
if s.size:
|
|
||||||
m = int(s.max())
|
|
||||||
if m > mx:
|
|
||||||
mx = m
|
|
||||||
over += int((s >= VOCAB).sum())
|
|
||||||
tot += int(s.size)
|
|
||||||
print(f"[DIAG] max_sign_id={mx} vocab={VOCAB} "
|
|
||||||
f"超界sign占比={over}/{tot}={(over / max(tot, 1)):.2%} "
|
|
||||||
f"(占比高=clamp 在污染 embedding → modulo 可能找回 AUC)")
|
|
||||||
|
|
||||||
if not loaded:
|
|
||||||
_save_cache(_cache_path(cur), item_dict, user_seq)
|
|
||||||
|
|
||||||
|
|
||||||
def _parse_args():
|
def _parse_args():
|
||||||
import argparse
|
import argparse
|
||||||
ap = argparse.ArgumentParser(description="CTI 推理测量闭环(子进程跑:!python bench.py ...)")
|
ap = argparse.ArgumentParser(description="CTI 推理测量闭环(子进程跑:!python bench.py ...)")
|
||||||
ap.add_argument("--diag", action="store_true", help="只跑诊断(序列长度分布 + sign-id 超界比例),不推理")
|
ap.add_argument("--diag", action="store_true", help="只跑诊断,不推理")
|
||||||
ap.add_argument("--smoke", type=int, default=None, help="只跑前 N 个 batch(冒烟)")
|
ap.add_argument("--smoke", type=int, default=None, help="只跑前 N 个 batch(冒烟)")
|
||||||
ap.add_argument("--bs", type=int, default=50, help="batch_size(本地参考)")
|
ap.add_argument("--bs", type=int, default=50, help="batch_size(本地参考)")
|
||||||
ap.add_argument("--fp32", action="store_true", help="FP32 天花板 = 关 fp16 + 关 expert 合并")
|
ap.add_argument("--fp32", action="store_true", help="FP32 天花板 = 关 fp16 + 关 expert 合并")
|
||||||
|
|||||||
Reference in New Issue
Block a user