Files
CTI-Inference-Opt/代码/code/bench.py
T
OwnerSunshine530 6625666010 feat: sparse_pool 选项 — (段×唯一)稀疏矩阵乘做池化,避免materialize[M,emb]
针对 profile 的 dedup展开(15%)+segment_reduce(6.6%)。段内高重复(slot19)塌缩
为单个带权项。CONFIG.sparse_pool;bench --sparse-pool;等价测试已加。默认关,待验证。

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
2026-06-15 15:15:13 +08:00

347 lines
14 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""Local measurement loop: set infer.CONFIG, run inference with synchronized timing,
and print AUC / PCOC / latency / total score.

Not part of the submission package. **Run it as a subprocess** (the AI Studio
kernel forbids `import torch`):

    %cd /home/aistudio/code
    !python bench.py --diag        # diagnostics: sequence-length distribution + share of out-of-range sign ids
    !python bench.py --smoke 50    # smoke test: only the first 50 batches
    !python bench.py               # default baseline
    !python bench.py --fp32        # FP32 ceiling
    !python bench.py --rebuild     # force-rebuild the filtered cache

Only data belonging to "test users" is kept: different users are fully isolated by
the causal mask, and the forward outputs of non-test users are never scored, so
filtering them out has no effect on the test samples' AUC/PCOC while shrinking the
data from 9.24 million rows to a small fraction.

The cache is a **plain-text CSV** rather than a pickle: the container's cgroup
memory is limited, and pickle.dump's memo for a large object can spike memory and
get the process silently OOM-killed. Writing CSV line by line barely raises memory,
and reading it back with load_sample_files is reliable.
"""
import os
import sys
import time
from collections import defaultdict
from pathlib import Path
# The baseline installs its dependencies with pip --target into these directories
# (not the default site-packages), so each one that exists is prepended to sys.path
# before the third-party imports that follow.
for _p in (
    "/home/aistudio/external-libraries",
    "/home/aistudio/libraries",
    os.path.abspath("../libraries"),
    os.path.abspath("./libraries"),
):
    if not os.path.isdir(_p) or _p in sys.path:
        continue  # missing directory, or already importable from it
    sys.path.insert(0, _p)
import numpy as np
import torch
from torch.utils.data import DataLoader
import infer # 同目录
def _test_user_ids(test_csv):
"""从 test.csv 读出所有测试用户 id(第 2 列 userid)。"""
users = set()
with open(test_csv) as f:
for line in f:
line = line.strip()
if not line:
continue
parts = line.split(",")
if len(parts) >= 2:
users.add(int(parts[1]))
return users
def _parse_record(parts, userid, has_clk):
    """Turn one comma-split data row into the per-log record dict.

    Column layout read here: ``logid, userid, adid, [clk,] timestamp, sign:slot, ...``
    where the ``clk`` column exists only when *has_clk* is true (otherwise clk is 0).
    Feature tokens without a ``:`` are skipped. *userid* is passed in already parsed
    because the caller needs it for filtering before the rest of the row is parsed.
    """
    logid = int(parts[0])
    adid = int(parts[2])
    if has_clk:
        clk = int(parts[3])
        timestamp = int(parts[4])
        first_feature = 5
    else:
        clk = 0
        timestamp = int(parts[3])
        first_feature = 4
    signs, slots = [], []
    for pair in parts[first_feature:]:
        if ":" in pair:
            s, sl = pair.split(":", 1)
            signs.append(int(s))
            slots.append(int(sl))
    return {
        "logid": logid, "userid": userid, "adid": adid,
        "clk": clk, "timestamp": timestamp,
        "signs": np.array(signs, dtype=np.int64),
        "slots": np.array(slots, dtype=np.int64),
    }


def _stream_build(ref, cache_csv_path=None):
    """Stream-filter the dataset down to test users and build ``(item_dict, user_seq)``.

    Reads ``ref/history/*.csv`` (sorted, if the directory exists) followed by
    ``ref/test.csv``, keeping only rows whose userid appears in test.csv.

    Args:
        ref: dataset directory (a ``Path``).
        cache_csv_path: if given, every kept *history* row is copied verbatim into
            this CSV (low-memory text cache; test.csv is read directly on reload and
            is not cached). The file is written to ``<cache_csv_path>.tmp`` and only
            renamed into place after the whole build succeeds, so an interrupted or
            failed build never leaves a truncated cache that would later be trusted.

    Returns:
        ``item_dict``: logid -> record dict; ``user_seq``: userid -> list of logids
        sorted by timestamp.

    NOTE(review): rows from all history files go into one cache file; if history
    files differ in clk layout, reading the cache back as a single file may
    misparse - confirm against infer.load_sample_files.
    """
    test_csv = ref / "test.csv"
    history = ref / "history"
    test_users = _test_user_ids(test_csv)
    files = (sorted(history.glob("*.csv")) if history.exists() else []) + [test_csv]
    print(f"[BENCH] 流式过滤加载 {len(files)} 个文件(仅保留 {len(test_users)} 个测试用户)...")
    item_dict = {}
    user_logs = defaultdict(list)
    tmp_path = f"{cache_csv_path}.tmp" if cache_csv_path else None
    cf = open(tmp_path, "w") if tmp_path else None
    completed = False
    try:
        for fp in files:
            has_clk = infer._detect_has_clk(fp)
            min_parts = 5 if has_clk else 4
            # Compare full paths: a history file that happens to be named test.csv
            # must still be treated (and cached) as history.
            is_test = (fp == test_csv)
            kept = 0
            with open(fp) as f:
                for raw in f:
                    line = raw.strip()
                    if not line:
                        continue
                    parts = line.split(",")
                    if len(parts) < min_parts:
                        continue
                    userid = int(parts[1])
                    if userid not in test_users:
                        continue
                    if cf is not None and not is_test:  # only history rows are cached
                        cf.write(raw if raw.endswith("\n") else raw + "\n")
                    rec = _parse_record(parts, userid, has_clk)
                    item_dict[rec["logid"]] = rec
                    user_logs[userid].append((rec["timestamp"], rec["logid"]))
                    kept += 1
            print(f" {Path(fp).name}: has_clk={has_clk}, kept={kept}")
        completed = True
    finally:
        if cf is not None:
            if completed:
                cf.flush()
                os.fsync(cf.fileno())
                cf.close()
                # Atomic swap: readers see either the old cache or the complete new one.
                os.replace(tmp_path, cache_csv_path)
            else:
                cf.close()
                try:
                    os.remove(tmp_path)
                except OSError:
                    pass  # best-effort cleanup; the real cache path was never touched
    user_seq = {}
    for u, logs in user_logs.items():
        logs.sort(key=lambda x: x[0])
        user_seq[u] = [lid for _, lid in logs]
    print(f"[BENCH] 过滤后:{len(item_dict)} 条记录,{len(user_seq)} 个用户")
    if cache_csv_path:
        print(f"[BENCH] 已缓存历史行 -> {cache_csv_path}(下次快速读取)")
    return item_dict, user_seq
def _get_data(cur, ref, rebuild=False):
    """Return the filtered ``(item_dict, user_seq)``, preferring the CSV cache.

    The cache ``cur/cache_filtered_history.csv`` holds only history rows; test.csv
    is always read from *ref*. The cache is used whenever it exists and *rebuild*
    is false. It is not compared against test.csv, so pass ``rebuild=True`` after
    the dataset changes. Any exception while loading the cache is reported and
    falls back to a full rebuild, which rewrites the cache.
    """
    cache_csv = cur / "cache_filtered_history.csv"
    use_cache = cache_csv.exists() and not rebuild
    if use_cache:
        print(f"[BENCH] 读取过滤缓存(CSV){cache_csv}")
        sources = [str(cache_csv), str(ref / "test.csv")]
        try:
            return infer.load_sample_files(sources)
        except Exception as e:
            print(f"[BENCH][WARN] 缓存读取失败({e}),重新构建")
    return _stream_build(ref, cache_csv_path=str(cache_csv))
def run_diag(rebuild=False):
    """Print diagnostics for the filtered test-user data.

    Reports per-user sequence-length statistics (how much context the model sees)
    and how many sign ids are >= ``VOCAB`` (whether clamp vs. modulo sign-id
    handling matters). Nothing is returned; results are printed.
    """
    here = Path(__file__).parent
    item_dict, user_seq = _get_data(here, here / "dataset", rebuild=rebuild)

    seq_lens = [len(seq) for seq in user_seq.values()] if user_seq else [0]
    lens = np.array(seq_lens)
    print(f"[DIAG] 测试用户数={len(user_seq)} 总记录数={len(item_dict)}")
    print(f"[DIAG] 每用户序列长度 min/median/mean/max = "
          f"{int(lens.min())}/{int(np.median(lens))}/{lens.mean():.1f}/{int(lens.max())}")
    print(f"[DIAG] 序列长度>1 的用户占比 = {(lens > 1).mean():.1%}")

    # presumably the sign-embedding table size - TODO confirm against infer
    VOCAB = 5_000_000
    max_id = 0
    n_over = 0
    n_total = 0
    for rec in item_dict.values():
        signs = rec["signs"]
        if not signs.size:
            continue
        max_id = max(max_id, int(signs.max()))
        n_over += int((signs >= VOCAB).sum())
        n_total += int(signs.size)
    print(f"[DIAG] max_sign_id={max_id} vocab={VOCAB} "
          f"超界sign占比={n_over}/{n_total}={(n_over / max(n_total, 1)):.2%}")
def run_profile(config_override=None, n=20, batch_size=50, rebuild=False):
    """Profile the first *n* batches with ``torch.profiler`` and print the op table.

    The table is sorted by CUDA time on a CUDA device (CPU time otherwise) and
    limited to 25 rows. One unprofiled warm-up forward runs first, so one-off work
    (e.g. a first torch.compile) is not attributed to the profiled batches.
    Host-to-device copies of each batch happen inside the profiled region.

    Args:
        config_override: keys merged into the global ``infer.CONFIG``.
        n: number of batches to profile; values < 1 profile nothing.
        batch_size: DataLoader batch size.
        rebuild: force rebuilding the filtered-history cache.
    """
    import gc
    from itertools import islice

    if config_override is None:
        config_override = {}
    infer.CONFIG.update(config_override)
    cur = Path(__file__).parent
    ref = cur / "dataset"
    item_dict, user_seq = _get_data(cur, ref, rebuild=rebuild)
    test_logids = infer.load_logids_from_file(ref / "test.csv")
    ds = infer.CTRTestSeqDataset(
        test_logids_ordered=list(test_logids), item_dict=item_dict,
        user_seq=user_seq, max_feasign_per_slot={1: 2}, max_ctx_len=None)
    loader = DataLoader(ds, batch_size=batch_size, shuffle=False, num_workers=0,
                        collate_fn=infer.make_collate_fn(ds.max_slot_id))
    cpu = torch.device("cpu")
    # islice stops after exactly n batches without collating an extra one.
    batches = [infer.move_batch_to_device(b, cpu) for b in islice(loader, max(n, 0))]
    # Drop the host-side data before loading the model (container memory is limited).
    del item_dict, user_seq, ds, loader
    gc.collect()
    if not batches:
        print(f"[BENCH][WARN] no batch to profile (n={n}), skipping")
        return

    model, dev = infer.load_model(ckpt_path=None)
    cuda = (dev.type == "cuda")
    from torch.profiler import profile, ProfilerActivity
    acts = [ProfilerActivity.CPU] + ([ProfilerActivity.CUDA] if cuda else [])
    with torch.inference_mode():
        warm = infer.move_batch_to_device(batches[0], dev)  # warm-up (first-call compilation)
        model(warm)
        if cuda:
            torch.cuda.synchronize()
        with profile(activities=acts) as prof:
            for b in batches:
                b = infer.move_batch_to_device(b, dev)
                model(b)
            if cuda:
                # Wait for queued kernels before the profiler context closes.
                torch.cuda.synchronize()
    sort_key = "cuda_time_total" if cuda else "cpu_time_total"
    print(prof.key_averages().table(sort_by=sort_key, row_limit=25))
# Sentinel default for run_once's max_feasign_per_slot: lets an explicit None
# (bench.py --feasign-none, "no truncation") be told apart from "argument omitted".
_DEFAULT_FEASIGN = object()


def run_once(config_override=None, batch_size=50, max_batches=None,
             max_feasign_per_slot=_DEFAULT_FEASIGN, rebuild=False):
    """Run local inference once and score it; return the dict from ``infer._cal_score``.

    Args:
        config_override: keys merged into the global ``infer.CONFIG``;
            ``sync_timing`` is additionally forced to True.
        batch_size: DataLoader batch size.
        max_batches: if given, only the first N batches are run (smoke test).
        max_feasign_per_slot: forwarded to ``infer.CTRTestSeqDataset``. When omitted
            it is ``{1: 2}``; an explicit ``None`` is passed through unchanged
            (presumably "no truncation" in the dataset - confirm in infer).
        rebuild: force rebuilding the filtered-history cache.

    Side effects: writes ``predict.txt`` next to this file, one probability per
    test.csv row in file order; logids that were not predicted get 0.0.
    The summed per-batch forward+sigmoid wall time (device-synchronized on CUDA,
    excluding host-to-device copies) is passed to ``_cal_score`` as
    ``default_latency``.
    """
    if config_override is None:
        config_override = {}
    if max_feasign_per_slot is _DEFAULT_FEASIGN:
        max_feasign_per_slot = {1: 2}
    infer.CONFIG.update(config_override)
    infer.CONFIG["sync_timing"] = True
    cur = Path(__file__).parent
    ref = cur / "dataset"
    test_csv = ref / "test.csv"
    label_file = ref / "label_data.txt"
    item_dict, user_seq = _get_data(cur, ref, rebuild=rebuild)
    test_logids = infer.load_logids_from_file(test_csv)
    ds = infer.CTRTestSeqDataset(
        test_logids_ordered=list(test_logids), item_dict=item_dict,
        user_seq=user_seq, max_feasign_per_slot=max_feasign_per_slot, max_ctx_len=None,
    )
    loader = DataLoader(
        ds, batch_size=batch_size, shuffle=False, num_workers=0,
        collate_fn=infer.make_collate_fn(ds.max_slot_id),
    )
    batches = []
    for b in loader:
        batches.append(infer.move_batch_to_device(b, torch.device("cpu")))
        if max_batches is not None and len(batches) >= max_batches:
            break
    # Drop the host-side data before loading the model (container memory is limited).
    del item_dict, user_seq, ds, loader
    import gc
    gc.collect()
    model, dev = infer.load_model(ckpt_path=None)
    logid2p = {}
    t_sum = 0.0
    cuda = (dev.type == "cuda")
    with torch.inference_mode():
        for b in batches:
            b = infer.move_batch_to_device(b, dev)
            pm = b["pred_mask"].bool()
            if cuda:
                torch.cuda.synchronize()
            # perf_counter: monotonic, high-resolution clock meant for durations.
            t0 = time.perf_counter()
            logits, _ = model(b)
            probs = torch.sigmoid(logits.squeeze(-1))
            if cuda:
                torch.cuda.synchronize()
            t_sum += time.perf_counter() - t0
            for lid, p in zip(b["logid"][pm].cpu().tolist(), probs[pm].cpu().tolist()):
                logid2p[lid] = p
    # predict.txt must follow test.csv row order.
    with open(test_csv) as f:
        order = [int(line.split(",")[0]) for line in f if line.strip()]
    missing = [lid for lid in order if lid not in logid2p]
    if missing:
        print(f"[BENCH][WARN] {len(missing)} 个测试 logid 没预测到(前几个 {missing[:5]}")
    pred_path = cur / "predict.txt"
    with open(pred_path, "w") as f:
        for lid in order:
            f.write(f"{logid2p.get(lid, 0.0)}\n")
    res = infer._cal_score(pred_path, label_file, default_latency=t_sum)
    print(
        f"[BENCH] cfg={config_override} bs={batch_size}"
        f"{'' if max_batches is None else f' (first {max_batches} batches)'}"
        f" -> AUC={res['auc']:.5f} PCOC={res['pcoc']:.4f}"
        f" lat={res['latency']:.2f}s score={res['score_all']:.2f}"
    )
    return res
def _parse_args():
import argparse
ap = argparse.ArgumentParser(description="CTI 推理测量闭环(子进程跑:!python bench.py ...")
ap.add_argument("--diag", action="store_true", help="只跑诊断,不推理")
ap.add_argument("--smoke", type=int, default=None, help="只跑前 N 个 batch(冒烟)")
ap.add_argument("--bs", type=int, default=50, help="batch_size(本地参考)")
ap.add_argument("--fp32", action="store_true", help="FP32 天花板 = 关 fp16 + 关 expert 合并")
ap.add_argument("--no-fp16", action="store_true", help="关闭半精度")
ap.add_argument("--no-merge", action="store_true", help="关闭 expert 合并")
ap.add_argument("--signid", choices=["clamp", "modulo"], default=None, help="sign-id 处理方式")
ap.add_argument("--merge-th", type=float, default=None, help="expert 合并余弦阈值")
ap.add_argument("--keep", type=str, default=None,
help="逗号分隔的 keep_fp32_modules,如 linear,rep_encoder.input_norm")
ap.add_argument("--feasign-none", action="store_true",
help="不截断特征(max_feasign_per_slot=None")
ap.add_argument("--attn", choices=["sdpa", "chunked", "flex", "varlen"], default=None,
help="注意力:sdpa=稠密, chunked=按用户分块SDPA, flex/varlen=对照")
ap.add_argument("--chunk-users", type=int, default=None, help="chunked 模式每块用户数")
ap.add_argument("--moe", choices=["dense", "loop"], default=None,
help="MoE实现:dense=向量化(新), loop=逐expert循环(原)")
ap.add_argument("--compile", action="store_true", help="开启 torch.compile")
ap.add_argument("--emb-fp16", action="store_true", help="Embedding表转FP16(查表带宽减半,测AUC)")
ap.add_argument("--dedup-emb", action="store_true", help="查表前对sign去重(减少大表随机访存)")
ap.add_argument("--sparse-pool", action="store_true", help="稀疏矩阵乘做池化(段内高重复时省)")
ap.add_argument("--profile", type=int, default=None, metavar="N",
help="剖析前 N 个 batch,打印按 CUDA 耗时排序的算子表(定位瓶颈)")
ap.add_argument("--rebuild", action="store_true", help="强制重建过滤缓存")
return ap.parse_args()
if __name__ == "__main__":
    args = _parse_args()
    if args.diag:
        run_diag(rebuild=args.rebuild)
        sys.exit(0)

    # Translate CLI flags into infer.CONFIG overrides. Insertion order is kept
    # stable because run_once prints this dict as part of its summary line.
    cfg = {}
    if args.fp32:
        # FP32 ceiling: no half precision and no expert merging.
        cfg.update(fp16=False, expert_merge=False)
    if args.no_fp16:
        cfg["fp16"] = False
    if args.no_merge:
        cfg["expert_merge"] = False
    if args.signid:
        cfg["signid_mode"] = args.signid
    if args.merge_th is not None:
        cfg["merge_threshold"] = args.merge_th
    if args.keep is not None:
        cfg["keep_fp32_modules"] = tuple(filter(None, args.keep.split(",")))
    for key, value in (("attn", args.attn), ("chunk_users", args.chunk_users)):
        if value is not None:
            cfg[key] = value
    if args.moe is not None:
        cfg["vectorize_moe"] = args.moe == "dense"
    for enabled, key in (
        (args.emb_fp16, "emb_fp16"),
        (args.dedup_emb, "dedup_embedding"),
        (args.sparse_pool, "sparse_pool"),
        (args.compile, "compile"),
    ):
        if enabled:
            cfg[key] = True

    if args.profile is not None:
        run_profile(cfg, n=args.profile, batch_size=args.bs, rebuild=args.rebuild)
        sys.exit(0)
    # --feasign-none passes max_feasign_per_slot=None; otherwise {1: 2}.
    mf = None if args.feasign_none else {1: 2}
    run_once(cfg, batch_size=args.bs, max_batches=args.smoke,
             max_feasign_per_slot=mf, rebuild=args.rebuild)