Files
CTI-Inference-Opt/代码/code/bench.py
T
2026-06-14 21:38:50 +08:00

269 lines
11 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""本地测量闭环:设置 infer.CONFIG,跑推理,同步计时,打印 AUC/PCOC/延迟/总分。
不进提交包。**以子进程方式运行**(AI Studio 内核禁止 import torch):
%cd /home/aistudio/code
!python bench.py --smoke 50 # 冒烟:只跑前 50 batch
!python bench.py # 默认基线
!python bench.py --fp32 # FP32 天花板(Task 3)
!python bench.py --rebuild # 强制重建过滤缓存
关键设计——只保留“测试用户”的数据:
不同用户被因果 mask 完全隔离,非测试用户的前向输出不参与打分;过滤掉它们
对测试样本的 AUC/PCOC 没有任何影响,却能把数据量从 924 万条降到一小部分,
避免 CTRTestSeqDataset 构造时 OOM。过滤后的数据缓存到磁盘,后续秒级复用。
"""
import os
import sys
import time
from collections import defaultdict
from pathlib import Path
# The baseline installs its dependencies into a --target directory (not the
# default site-packages), so extend sys.path before the third-party imports.
# Only directories that exist and are not already listed are prepended.
for _p in ("/home/aistudio/external-libraries", "/home/aistudio/libraries",
           os.path.abspath("../libraries"), os.path.abspath("./libraries")):
    if os.path.isdir(_p) and _p not in sys.path:
        sys.path.insert(0, _p)
import numpy as np
import torch
from torch.utils.data import DataLoader
import infer # 同目录
def _test_user_ids(test_csv):
    """Collect the set of user ids that appear in ``test_csv``.

    Every line is stripped and split on commas; the second field is parsed
    as an integer user id. Lines with fewer than two fields (which includes
    blank lines) are skipped.

    Args:
        test_csv: Path of the comma-separated test file.

    Returns:
        A ``set`` of integer user ids.

    Raises:
        ValueError: If the second field of a kept line is not an integer.
    """
    with open(test_csv) as handle:
        # A blank line splits into a single empty field, so the length check
        # below also filters blank lines.
        rows = (raw.strip().split(",") for raw in handle)
        return {int(cols[1]) for cols in rows if len(cols) >= 2}
def _load_filtered(history_dir, test_csv, test_users):
    """Stream every data file and keep only records of the test users.

    The history ``*.csv`` files (sorted, if the directory exists) are read
    first, followed by ``test_csv``. Records are filtered line by line so the
    unfiltered data set is never held in memory. According to the original
    author the parsing mirrors ``infer.load_sample_files`` with the user
    filter added on top.

    Column layout used here: ``logid, userid, adid, [clk,] timestamp`` followed
    by any number of ``sign:slot`` pairs. Whether the ``clk`` column is
    present is decided per file by ``infer._detect_has_clk``; when it is
    absent ``clk`` is stored as 0.

    Args:
        history_dir: ``Path`` of the directory holding history CSV files.
        test_csv: ``Path`` of the test CSV file.
        test_users: Container of integer user ids to keep.

    Returns:
        ``(item_dict, user_seq)`` where ``item_dict`` maps logid to a record
        dict and ``user_seq`` maps userid to its logids ordered by timestamp.
    """
    history_files = sorted(history_dir.glob("*.csv")) if history_dir.exists() else []
    files = history_files + [test_csv]
    print(f"[BENCH] 流式过滤加载 {len(files)} 个文件(仅保留 {len(test_users)} 个测试用户)...")

    item_dict = {}
    user_logs = defaultdict(list)
    for fp in files:
        has_clk = infer._detect_has_clk(fp)
        # Index of the first sign:slot pair; it is also the minimum number of
        # columns a usable line must have.
        feat_start = 5 if has_clk else 4
        kept = 0
        with open(fp) as handle:
            for raw in handle:
                cols = raw.strip().split(",")
                # Blank lines yield a single empty column and are dropped here.
                if len(cols) < feat_start:
                    continue
                userid = int(cols[1])
                if userid not in test_users:
                    continue
                logid = int(cols[0])
                adid = int(cols[2])
                clk = int(cols[3]) if has_clk else 0
                # The timestamp is always the column right before the pairs.
                timestamp = int(cols[feat_start - 1])

                signs, slots = [], []
                for token in cols[feat_start:]:
                    sign_txt, sep, slot_txt = token.partition(":")
                    if sep:  # tokens without a colon are ignored
                        signs.append(int(sign_txt))
                        slots.append(int(slot_txt))

                item_dict[logid] = {
                    "logid": logid, "userid": userid, "adid": adid,
                    "clk": clk, "timestamp": timestamp,
                    "signs": np.array(signs, dtype=np.int64),
                    "slots": np.array(slots, dtype=np.int64),
                }
                user_logs[userid].append((timestamp, logid))
                kept += 1
        print(f" {fp.name}: has_clk={has_clk}, kept={kept}")

    # Stable sort on the timestamp only, so ties keep their read order.
    user_seq = {
        uid: [lid for _, lid in sorted(entries, key=lambda pair: pair[0])]
        for uid, entries in user_logs.items()
    }
    print(f"[BENCH] 过滤后:{len(item_dict)} 条记录,{len(user_seq)} 个用户")
    return item_dict, user_seq
def _get_data(cur, ref, rebuild=False):
    """Return the filtered ``(item_dict, user_seq)`` pair, preferring the disk cache.

    Args:
        cur: ``Path`` of the directory that holds ``bench_filtered_cache.pt``.
        ref: ``Path`` of the data set directory containing ``test.csv`` and
            the ``history`` sub-directory.
        rebuild: When true, ignore an existing cache and rebuild it.

    Returns:
        ``(item_dict, user_seq)`` as produced by ``_load_filtered``.
    """
    cache = cur / "bench_filtered_cache.pt"
    test_csv = ref / "test.csv"
    history = ref / "history"
    if cache.exists() and not rebuild:
        print(f"[BENCH] 读取过滤缓存:{cache}")
        # NOTE: weights_only=False unpickles arbitrary Python objects. That is
        # acceptable only because this cache is written locally by this
        # script; never point it at a file from an untrusted source.
        d = torch.load(cache, weights_only=False)
        return d["item_dict"], d["user_seq"]
    test_users = _test_user_ids(test_csv)
    item_dict, user_seq = _load_filtered(history, test_csv, test_users)
    # Write to a temporary sibling and rename it into place, so an interrupted
    # save can never leave a truncated cache that a later run would try to load.
    tmp = cache.with_name(cache.name + ".tmp")
    try:
        torch.save({"item_dict": item_dict, "user_seq": user_seq}, tmp)
        os.replace(tmp, cache)
    finally:
        # After a successful replace the temp file is gone; this only cleans
        # up after a failed save.
        if tmp.exists():
            tmp.unlink()
    print(f"[BENCH] 已缓存 -> {cache}")
    return item_dict, user_seq
# Sentinel that distinguishes "argument omitted" from an explicit ``None``.
_FEASIGN_UNSET = object()


def run_once(config_override=None, batch_size=50, max_batches=None,
             max_feasign_per_slot=_FEASIGN_UNSET, rebuild=False):
    """Run one local inference pass, write ``predict.txt`` and score it.

    Args:
        config_override: Dict merged into ``infer.CONFIG`` before the run
            (``sync_timing`` is always forced to ``True`` afterwards).
        batch_size: Batch size handed to the ``DataLoader``.
        max_batches: If not ``None``, stop after this many batches.
        max_feasign_per_slot: Forwarded unchanged to
            ``infer.CTRTestSeqDataset``. When omitted it defaults to
            ``{1: 2}``. An explicit ``None`` is passed through as-is (the CLI
            help describes that as "no truncation"); it used to be silently
            replaced by the default, which made ``--feasign-none`` a no-op.
        rebuild: Force a rebuild of the filtered data cache.

    Returns:
        The dict returned by ``infer._cal_score``; the keys ``auc``, ``pcoc``,
        ``latency`` and ``score_all`` are read from it for the summary line.
    """
    if config_override is None:
        config_override = {}
    if max_feasign_per_slot is _FEASIGN_UNSET:
        max_feasign_per_slot = {1: 2}
    infer.CONFIG.update(config_override)
    infer.CONFIG["sync_timing"] = True
    cur = Path(__file__).parent
    ref = cur / "dataset"
    test_csv = ref / "test.csv"
    label_file = ref / "label_data.txt"

    # ----- load data (filtered + cached) -----
    item_dict, user_seq = _get_data(cur, ref, rebuild=rebuild)
    test_logids = infer.load_logids_from_file(test_csv)
    ds = infer.CTRTestSeqDataset(
        test_logids_ordered=list(test_logids), item_dict=item_dict,
        user_seq=user_seq, max_feasign_per_slot=max_feasign_per_slot, max_ctx_len=None,
    )
    loader = DataLoader(
        ds, batch_size=batch_size, shuffle=False, num_workers=0,
        collate_fn=infer.make_collate_fn(ds.max_slot_id),
    )
    # Materialise every batch on the CPU up front so that batch construction
    # is not part of the timed section below.
    batches = []
    for b in loader:
        batches.append(infer.move_batch_to_device(b, torch.device("cpu")))
        if max_batches is not None and len(batches) >= max_batches:
            break
    # Drop construction-time objects to lower peak memory during inference.
    del item_dict, user_seq, ds, loader
    import gc
    gc.collect()

    # ----- load model -----
    model, dev = infer.load_model(ckpt_path=None)

    # ----- inference with synchronised timing -----
    logid2p = {}
    t_sum = 0.0
    cuda = (dev.type == "cuda")
    with torch.inference_mode():
        for b in batches:
            b = infer.move_batch_to_device(b, dev)
            pm = b["pred_mask"].bool()
            # Synchronise before and after so queued GPU work is attributed
            # to the right interval.
            if cuda:
                torch.cuda.synchronize()
            # perf_counter is monotonic and high resolution; time.time() can
            # jump when the system clock is adjusted.
            t0 = time.perf_counter()
            logits, _ = model(b)
            probs = torch.sigmoid(logits.squeeze(-1))
            if cuda:
                torch.cuda.synchronize()
            t_sum += time.perf_counter() - t0
            logid2p.update(zip(b["logid"][pm].cpu().tolist(), probs[pm].cpu().tolist()))

    # ----- write predict.txt in test.csv order and score -----
    # Use a context manager so the file handle is closed deterministically.
    with open(test_csv) as f:
        order = [int(line.split(",")[0]) for line in f if line.strip()]
    missing = [lid for lid in order if lid not in logid2p]
    if missing:
        print(f"[BENCH][WARN] {len(missing)} 个测试 logid 没预测到(前几个 {missing[:5]}")
    pred_path = cur / "predict.txt"
    with open(pred_path, "w") as f:
        # Log ids without a prediction (e.g. in a smoke run) are written as 0.0.
        for lid in order:
            f.write(f"{logid2p.get(lid, 0.0)}\n")
    res = infer._cal_score(pred_path, label_file, default_latency=t_sum)
    print(
        f"[BENCH] cfg={config_override} bs={batch_size}"
        f"{'' if max_batches is None else f' (first {max_batches} batches)'}"
        f" -> AUC={res['auc']:.5f} PCOC={res['pcoc']:.4f}"
        f" lat={res['latency']:.2f}s score={res['score_all']:.2f}"
    )
    return res
def run_diag(rebuild=False):
    """Print diagnostics about the filtered test-user data.

    Reports the distribution of per-user sequence lengths and how many sign
    ids are at or above a fixed vocabulary size of 5,000,000. Nothing is
    returned; all results are printed.

    Args:
        rebuild: Force a rebuild of the filtered data cache.
    """
    here = Path(__file__).parent
    item_dict, user_seq = _get_data(here, here / "dataset", rebuild=rebuild)

    # Fall back to a single zero so the min/median/max calls work on empty data.
    if user_seq:
        seq_lens = np.array([len(seq) for seq in user_seq.values()])
    else:
        seq_lens = np.array([0])
    shortest, longest = int(seq_lens.min()), int(seq_lens.max())
    median_len = int(np.median(seq_lens))
    multi_share = (seq_lens > 1).mean()
    print(f"[DIAG] 测试用户数={len(user_seq)} 总记录数={len(item_dict)}")
    print(f"[DIAG] 每用户序列长度 min/median/mean/max = "
          f"{shortest}/{median_len}/{seq_lens.mean():.1f}/{longest}")
    print(f"[DIAG] 序列长度>1 的用户占比 = {multi_share:.1%} "
          f"(占比低=大量测试样本没有历史上下文 → 生成式模型发挥不出来)")

    vocab_size = 5_000_000
    # Records without any sign contribute nothing to the statistics.
    sign_arrays = [rec["signs"] for rec in item_dict.values() if rec["signs"].size]
    # The leading 0 keeps the reported maximum from dropping below zero.
    max_sign = max([0] + [int(arr.max()) for arr in sign_arrays])
    out_of_range = sum(int((arr >= vocab_size).sum()) for arr in sign_arrays)
    total_signs = sum(int(arr.size) for arr in sign_arrays)
    over_share = out_of_range / max(total_signs, 1)
    print(f"[DIAG] max_sign_id={max_sign} vocab={vocab_size} "
          f"超界sign占比={out_of_range}/{total_signs}={over_share:.2%} "
          f"(占比高=clamp 在污染 embedding → modulo 可能找回 AUC")
def _parse_args():
    """Parse the command-line options of the benchmark script.

    Returns:
        An ``argparse.Namespace`` with the attributes ``diag``, ``smoke``,
        ``bs``, ``fp32``, ``no_fp16``, ``no_merge``, ``signid``, ``merge_th``,
        ``keep``, ``feasign_none`` and ``rebuild``.
    """
    import argparse

    parser = argparse.ArgumentParser(
        description="CTI 推理测量闭环(子进程跑:!python bench.py ...",
    )
    add = parser.add_argument

    # Mode and size of the run.
    add("--diag", action="store_true",
        help="只跑诊断(序列长度分布 + sign-id 超界比例),不推理")
    add("--smoke", type=int, default=None, help="只跑前 N 个 batch(冒烟)")
    add("--bs", type=int, default=50, help="batch_size(本地参考)")

    # Overrides that end up in infer.CONFIG.
    add("--fp32", action="store_true", help="FP32 天花板 = 关 fp16 + 关 expert 合并")
    add("--no-fp16", action="store_true", help="关闭半精度")
    add("--no-merge", action="store_true", help="关闭 expert 合并")
    add("--signid", choices=["clamp", "modulo"], default=None, help="sign-id 处理方式")
    add("--merge-th", type=float, default=None, help="expert 合并余弦阈值")
    add("--keep", type=str, default=None,
        help="逗号分隔的 keep_fp32_modules,如 linear,rep_encoder.input_norm")

    # Data handling.
    add("--feasign-none", action="store_true",
        help="不截断特征(max_feasign_per_slot=None")
    add("--rebuild", action="store_true", help="强制重建过滤缓存")

    return parser.parse_args()
if __name__ == "__main__":
    # Script entry point: turn the command-line flags into an infer.CONFIG
    # override dict and run either the diagnostics or one benchmark pass.
    args = _parse_args()
    if args.diag:
        run_diag(rebuild=args.rebuild)
        sys.exit(0)

    overrides = {}
    # --fp32 is shorthand for --no-fp16 plus --no-merge.
    if args.fp32 or args.no_fp16:
        overrides["fp16"] = False
    if args.fp32 or args.no_merge:
        overrides["expert_merge"] = False
    if args.signid:
        overrides["signid_mode"] = args.signid
    if args.merge_th is not None:
        overrides["merge_threshold"] = args.merge_th
    if args.keep is not None:
        # Empty items (e.g. from a trailing comma) are dropped.
        overrides["keep_fp32_modules"] = tuple(filter(None, args.keep.split(",")))

    feasign_cap = None if args.feasign_none else {1: 2}
    run_once(overrides, batch_size=args.bs, max_batches=args.smoke,
             max_feasign_per_slot=feasign_cap, rebuild=args.rebuild)