Files
CTI-Inference-Opt/代码/code/bench.py
T
2026-06-14 19:46:21 +08:00

120 lines
4.2 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
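"""Local measurement loop: set infer.CONFIG, run inference, time it with
device synchronization, and print AUC / PCOC / latency / total score.

Not part of the submission package. Run it in an AI Studio notebook (with
dataset/ and ckpt.pt available):

    %cd /home/aistudio/code
    !python bench.py  # baseline with the default config

Or sweep configurations one by one from a notebook cell:

    import bench
    bench.run_once({"fp16": False, "expert_merge": False})  # FP32 reference run
    bench.run_once({"signid_mode": "modulo"})  # modulo vs clamp
"""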
import os
import sys
import time
from pathlib import Path
# The baseline installs its dependencies into a --target directory (not the
# default site-packages). That directory has to be on sys.path before the
# imports below run inside the kernel, otherwise `import torch` fails with
# ModuleNotFoundError.
for _p in (
    "/home/aistudio/external-libraries",
    "/home/aistudio/libraries",
    os.path.abspath("../libraries"),
    os.path.abspath("./libraries"),
):
    # Skip candidates that do not exist or are already registered.
    if not os.path.isdir(_p) or _p in sys.path:
        continue
    # Prepend so these packages take precedence over the default locations.
    sys.path.insert(0, _p)
import torch
from torch.utils.data import DataLoader
import infer # 同目录
def run_once(config_override=None, batch_size=50, max_batches=None, max_feasign_per_slot=None):
    """Run one local inference pass, write predict.txt and score it.

    The overrides are merged into ``infer.CONFIG`` only for the duration of
    this call; the previous CONFIG contents are restored afterwards so that
    several configurations can be scanned back to back without one run's
    settings leaking into the next.

    Args:
        config_override: dict merged into infer.CONFIG for this run
            (e.g. {"fp16": False}). None means no overrides.
        batch_size: batch size of the local DataLoader (local reference
            only; the official evaluator may use its own setting).
        max_batches: only run the first N batches (quick smoke test);
            None runs the full data set. On a partial run, log ids without
            a prediction are written as the placeholder 0.5, so the printed
            metrics are indicative only.
        max_feasign_per_slot: truncation dict passed to
            infer.CTRTestSeqDataset. None is replaced by the baseline
            default {1: 2}.

    Returns:
        The dict returned by infer._cal_score (this function reads its
        'auc', 'pcoc', 'latency' and 'score_all' keys for the summary line).

    Raises:
        KeyError: on a full run (max_batches is None) when a log id listed
            in test.csv received no prediction.
    """
    if config_override is None:
        config_override = {}
    if max_feasign_per_slot is None:
        max_feasign_per_slot = {1: 2}

    # Snapshot the global config so this run's overrides can be undone.
    saved_config = dict(infer.CONFIG)
    infer.CONFIG.update(config_override)
    # NOTE(review): presumably tells infer to synchronize the device when it
    # times itself -- confirm against infer's use of this key.
    infer.CONFIG["sync_timing"] = True

    try:
        cur = Path(__file__).parent
        ref = cur / "dataset"
        history = ref / "history"
        test_csv = ref / "test.csv"
        label_file = ref / "label_data.txt"

        # ----- load data -----
        # History files (if the directory exists) come first, test.csv last.
        files = (sorted(history.glob("*.csv")) if history.exists() else []) + [test_csv]
        item_dict, user_seq = infer.load_sample_files(files)
        test_logids = infer.load_logids_from_file(test_csv)
        ds = infer.CTRTestSeqDataset(
            test_logids_ordered=list(test_logids),
            item_dict=item_dict,
            user_seq=user_seq,
            max_feasign_per_slot=max_feasign_per_slot,
            max_ctx_len=None,
        )
        loader = DataLoader(
            ds, batch_size=batch_size, shuffle=False, num_workers=0,
            collate_fn=infer.make_collate_fn(ds.max_slot_id),
        )
        # Materialize the batches up front so that data loading and collation
        # are not part of the timed section below.
        batches = []
        for b in loader:
            batches.append(infer.move_batch_to_device(b, torch.device("cpu")))
            if max_batches is not None and len(batches) >= max_batches:
                break

        # ----- load model -----
        model, dev = infer.load_model(ckpt_path=None)

        # ----- inference with synchronized timing -----
        logid2p = {}
        t_sum = 0.0
        cuda = (dev.type == "cuda")
        with torch.inference_mode():
            for b in batches:
                b = infer.move_batch_to_device(b, dev)
                pm = b["pred_mask"].bool()
                # Drain pending GPU work so the timer covers only the forward
                # pass and sigmoid, not the preceding host-to-device copy.
                if cuda:
                    torch.cuda.synchronize()
                # perf_counter is monotonic and high resolution, unlike the
                # wall clock, which can jump.
                t0 = time.perf_counter()
                logits, _ = model(b)
                probs = torch.sigmoid(logits.squeeze(-1))
                # Wait for the asynchronous kernels to finish before stopping
                # the timer.
                if cuda:
                    torch.cuda.synchronize()
                t_sum += time.perf_counter() - t0
                # Keep only the positions flagged by pred_mask.
                for lid, p in zip(b["logid"][pm].cpu().tolist(), probs[pm].cpu().tolist()):
                    logid2p[lid] = p

        # ----- write predict.txt in test.csv order and score it -----
        with open(test_csv) as f:
            order = [int(line.split(",")[0]) for line in f if line.strip()]

        missing = [lid for lid in order if lid not in logid2p]
        if missing and max_batches is None:
            # A full run must yield a prediction for every test log id.
            raise KeyError(missing[0])
        if missing:
            # Partial (smoke) run: keep predict.txt the same length as
            # test.csv by filling the gaps with a neutral placeholder.
            print(
                f"[BENCH] partial run: {len(missing)}/{len(order)} logids have no"
                f" prediction; writing 0.5 placeholders, metrics are indicative only"
            )

        pred_path = cur / "predict.txt"
        with open(pred_path, "w") as f:
            f.writelines(f"{logid2p.get(lid, 0.5)}\n" for lid in order)

        res = infer._cal_score(pred_path, label_file, default_latency=t_sum)
        print(
            f"[BENCH] cfg={config_override} bs={batch_size}"
            f"{'' if max_batches is None else f' (first {max_batches} batches)'}"
            f" -> AUC={res['auc']:.5f} PCOC={res['pcoc']:.4f}"
            f" lat={res['latency']:.2f}s score={res['score_all']:.2f}"
        )
        return res
    finally:
        # Restore the exact pre-call config, dropping any keys this run added.
        infer.CONFIG.clear()
        infer.CONFIG.update(saved_config)
if __name__ == "__main__":
    # Script entry point: benchmark the default configuration (no overrides).
    run_once(config_override={})