2268fa6cf3
profile显示embedding查表现为头号瓶颈(32%)。torch.unique去重后只查唯一sign 再按逆索引展开,数学逐位等价(AUC不变),省最贵的大表随机gather。bench --dedup-emb。 Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
344 lines
14 KiB
Python
344 lines
14 KiB
Python
"""本地测量闭环:设置 infer.CONFIG,跑推理,同步计时,打印 AUC/PCOC/延迟/总分。

不进提交包。**以子进程方式运行**(AI Studio 内核禁止 import torch):

    %cd /home/aistudio/code
    !python bench.py --diag # 诊断:序列长度分布 + sign-id 超界比例
    !python bench.py --smoke 50 # 冒烟:只跑前 50 batch
    !python bench.py # 默认基线
    !python bench.py --fp32 # FP32 天花板
    !python bench.py --rebuild # 强制重建过滤缓存

只保留“测试用户”的数据:不同用户被因果 mask 完全隔离,非测试用户的前向输出
不参与打分;过滤掉它们对测试样本的 AUC/PCOC 没有任何影响,却能把数据量从
924 万条降到一小部分。

缓存用**文本 CSV**而非 pickle:容器 cgroup 内存有限,pickle.dump 大对象的 memo
会瞬间撑爆内存被静默 OOM-kill;逐行写 CSV 内存几乎不涨,再用 load_sample_files
读回,稳。
"""
|
||
import os
|
||
import sys
|
||
import time
|
||
from collections import defaultdict
|
||
from pathlib import Path
|
||
|
||
# The baseline installs its dependencies into a --target directory rather than
# the default site-packages, so these directories are put on sys.path before
# the imports that follow.
for _p in ("/home/aistudio/external-libraries", "/home/aistudio/libraries",
           os.path.abspath("../libraries"), os.path.abspath("./libraries")):
    if _p in sys.path or not os.path.isdir(_p):
        continue
    sys.path.insert(0, _p)
|
||
|
||
import numpy as np
|
||
import torch
|
||
from torch.utils.data import DataLoader
|
||
|
||
import infer # 同目录
|
||
|
||
|
||
def _test_user_ids(test_csv):
    """Collect the distinct user ids that appear in ``test_csv``.

    The user id is taken from the second comma-separated field of every
    line; blank lines and lines with fewer than two fields are ignored.

    Args:
        test_csv: Path of the CSV file to scan.

    Returns:
        A ``set`` of user ids as ``int``.

    Raises:
        ValueError: If the second field of a line is not an integer.
    """
    with open(test_csv) as handle:
        # A blank line strips to "" and splits into a single field, so the
        # length check below also discards it.
        rows = (text.strip().split(",") for text in handle)
        return {int(fields[1]) for fields in rows if len(fields) >= 2}
|
||
|
||
|
||
def _stream_build(ref, cache_csv_path=None):
    """Stream the dataset files and keep only rows that belong to test users.

    Reads every ``*.csv`` under ``ref / "history"`` (when that directory
    exists) followed by ``ref / "test.csv"``, line by line, and builds the
    in-memory structures from the rows whose user id occurs in ``test.csv``.

    When ``cache_csv_path`` is given, the kept *history* rows are also copied
    verbatim to that file as a low-memory text cache; rows of ``test.csv``
    are not cached. The cache is written to ``<cache_csv_path>.tmp`` and is
    moved into place only after every input file has been processed, so an
    interrupted build (exception or killed process) never leaves a truncated
    cache at ``cache_csv_path`` that a later run would load as complete.

    Args:
        ref: Dataset directory as a ``pathlib.Path``.
        cache_csv_path: Optional destination of the history-row cache.

    Returns:
        ``(item_dict, user_seq)``. ``item_dict`` maps each logid to a dict
        with keys ``logid``, ``userid``, ``adid``, ``clk``, ``timestamp``,
        ``signs`` and ``slots`` (the last two are ``int64`` arrays).
        ``user_seq`` maps each user id to its logids ordered by timestamp;
        ties keep the order in which the rows were read.

    Raises:
        ValueError: If a numeric field of a kept row is not an integer.
        OSError: If an input file or the cache cannot be read or written.
    """
    test_csv = ref / "test.csv"
    history = ref / "history"
    test_users = _test_user_ids(test_csv)
    files = (sorted(history.glob("*.csv")) if history.exists() else []) + [test_csv]
    print(f"[BENCH] 流式过滤加载 {len(files)} 个文件(仅保留 {len(test_users)} 个测试用户)...")

    item_dict = {}
    user_logs = defaultdict(list)
    # Write to a sibling temp file; it is renamed onto cache_csv_path only
    # after a complete, flushed build.
    tmp_path = f"{cache_csv_path}.tmp" if cache_csv_path else None
    cf = open(tmp_path, "w") if tmp_path else None
    try:
        for fp in files:
            has_clk = infer._detect_has_clk(fp)
            # Fixed columns are logid,userid,adid,[clk,]timestamp; the
            # "sign:slot" pairs start right after them. The same number is
            # the minimum field count of a usable row.
            first_feature = 5 if has_clk else 4
            is_test = (Path(fp).name == test_csv.name)
            kept = 0
            with open(fp) as f:
                for raw in f:
                    line = raw.strip()
                    if not line:
                        continue
                    parts = line.split(",")
                    if len(parts) < first_feature:
                        continue
                    userid = int(parts[1])
                    if userid not in test_users:
                        continue
                    if cf is not None and not is_test:  # cache history rows only
                        cf.write(raw if raw.endswith("\n") else raw + "\n")
                    logid = int(parts[0])
                    adid = int(parts[2])
                    if has_clk:
                        clk = int(parts[3])
                        timestamp = int(parts[4])
                    else:
                        clk = 0
                        timestamp = int(parts[3])
                    signs, slots = [], []
                    for pair in parts[first_feature:]:
                        if ":" in pair:
                            s, sl = pair.split(":", 1)
                            signs.append(int(s))
                            slots.append(int(sl))
                    item_dict[logid] = {
                        "logid": logid, "userid": userid, "adid": adid,
                        "clk": clk, "timestamp": timestamp,
                        "signs": np.array(signs, dtype=np.int64),
                        "slots": np.array(slots, dtype=np.int64),
                    }
                    user_logs[userid].append((timestamp, logid))
                    kept += 1
            print(f" {Path(fp).name}: has_clk={has_clk}, kept={kept}")
        if cf is not None:
            # Force the data to disk before the rename publishes the cache.
            cf.flush()
            os.fsync(cf.fileno())
    except BaseException:
        if cf is not None:
            cf.close()
            try:
                os.remove(tmp_path)
            except OSError:
                # Best-effort cleanup; the original error is what matters.
                pass
        raise
    if cf is not None:
        cf.close()
        os.replace(tmp_path, cache_csv_path)

    user_seq = {}
    for u, logs in user_logs.items():
        logs.sort(key=lambda x: x[0])
        user_seq[u] = [lid for _, lid in logs]
    print(f"[BENCH] 过滤后:{len(item_dict)} 条记录,{len(user_seq)} 个用户")
    if cache_csv_path:
        print(f"[BENCH] 已缓存历史行 -> {cache_csv_path}(下次快速读取)")
    return item_dict, user_seq
|
||
|
||
|
||
def _get_data(cur, ref, rebuild=False):
    """Return the filtered ``(item_dict, user_seq)``, using the CSV cache if possible.

    The cache lives at ``cur / "cache_filtered_history.csv"``. When it exists
    and ``rebuild`` is false, it is loaded together with ``ref / "test.csv"``
    through ``infer.load_sample_files``. If the cache is missing, ``rebuild``
    is true, or loading raises, the data is rebuilt with ``_stream_build``,
    which also rewrites the cache.

    Args:
        cur: Directory that holds the cache file (a ``pathlib.Path``).
        ref: Dataset directory (a ``pathlib.Path``).
        rebuild: Ignore an existing cache and rebuild it.

    Returns:
        Whatever ``infer.load_sample_files`` or ``_stream_build`` returns.
    """
    cache_csv = cur / "cache_filtered_history.csv"
    test_csv = ref / "test.csv"
    cache_arg = str(cache_csv)

    if rebuild or not cache_csv.exists():
        return _stream_build(ref, cache_csv_path=cache_arg)

    print(f"[BENCH] 读取过滤缓存(CSV):{cache_csv}")
    try:
        cached = infer.load_sample_files([cache_arg, str(test_csv)])
    except Exception as err:
        # Deliberate best-effort: any failure to read the cache falls back
        # to a full rebuild instead of aborting the run.
        print(f"[BENCH][WARN] 缓存读取失败({err}),重新构建")
        return _stream_build(ref, cache_csv_path=cache_arg)
    else:
        return cached
|
||
|
||
|
||
def run_diag(rebuild=False):
    """Print diagnostics about the filtered data set.

    Reports the number of test users and records, the per-user sequence
    length distribution, and how many sign ids are greater than or equal to
    a vocabulary size of 5,000,000.

    Args:
        rebuild: Forwarded to ``_get_data`` to force a cache rebuild.
    """
    base_dir = Path(__file__).parent
    records, sequences = _get_data(base_dir, base_dir / "dataset", rebuild=rebuild)

    if sequences:
        seq_lens = np.array([len(ids) for ids in sequences.values()])
    else:
        # Placeholder so the statistics below do not fail on empty data.
        seq_lens = np.array([0])

    print(f"[DIAG] 测试用户数={len(sequences)} 总记录数={len(records)}")
    shortest = int(seq_lens.min())
    middle = int(np.median(seq_lens))
    longest = int(seq_lens.max())
    print(f"[DIAG] 每用户序列长度 min/median/mean/max = "
          f"{shortest}/{middle}/{seq_lens.mean():.1f}/{longest}")
    print(f"[DIAG] 序列长度>1 的用户占比 = {(seq_lens > 1).mean():.1%}")

    vocab_size = 5_000_000
    largest = out_of_range = total = 0
    for record in records.values():
        sign_ids = record["signs"]
        if not sign_ids.size:
            # An empty array contributes nothing to any of the counters.
            continue
        largest = max(largest, int(sign_ids.max()))
        out_of_range += int((sign_ids >= vocab_size).sum())
        total += int(sign_ids.size)

    print(f"[DIAG] max_sign_id={largest} vocab={vocab_size} "
          f"超界sign占比={out_of_range}/{total}={(out_of_range / max(total, 1)):.2%}")
|
||
|
||
|
||
def run_profile(config_override=None, n=20, batch_size=50, rebuild=False):
    """Profile the model on the first batches and print a per-operator table.

    Applies ``config_override`` to ``infer.CONFIG``, collates batches on the
    CPU until at least ``n`` have been collected, loads the model, runs one
    warm-up forward pass, and then runs the collected batches under
    ``torch.profiler``. The table is sorted by total CUDA time when the
    model is on a CUDA device, otherwise by total CPU time, and is limited
    to 25 rows.

    Args:
        config_override: Optional dict merged into ``infer.CONFIG``.
        n: Number of batches to collect before stopping.
        batch_size: ``DataLoader`` batch size.
        rebuild: Forwarded to ``_get_data`` to force a cache rebuild.
    """
    import gc
    from torch.profiler import profile, ProfilerActivity

    infer.CONFIG.update({} if config_override is None else config_override)

    here = Path(__file__).parent
    data_dir = here / "dataset"
    records, sequences = _get_data(here, data_dir, rebuild=rebuild)
    test_logids = infer.load_logids_from_file(data_dir / "test.csv")
    dataset = infer.CTRTestSeqDataset(
        test_logids_ordered=list(test_logids),
        item_dict=records,
        user_seq=sequences,
        max_feasign_per_slot={1: 2},
        max_ctx_len=None,
    )
    loader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=0,
        collate_fn=infer.make_collate_fn(dataset.max_slot_id),
    )

    cpu_batches = []
    for batch in loader:
        cpu_batches.append(infer.move_batch_to_device(batch, torch.device("cpu")))
        if len(cpu_batches) >= n:
            break

    # Drop the raw data before the model is loaded.
    del records, sequences, dataset, loader
    gc.collect()

    model, device = infer.load_model(ckpt_path=None)
    on_gpu = device.type == "cuda"
    activities = [ProfilerActivity.CPU]
    if on_gpu:
        activities.append(ProfilerActivity.CUDA)

    with torch.inference_mode():
        # Warm-up pass so any first-call compilation happens before profiling.
        model(infer.move_batch_to_device(cpu_batches[0], device))
        if on_gpu:
            torch.cuda.synchronize()
        with profile(activities=activities) as prof:
            for batch in cpu_batches:
                model(infer.move_batch_to_device(batch, device))
            if on_gpu:
                # Wait for queued work before the profiler context closes.
                torch.cuda.synchronize()

    ordering = "cuda_time_total" if on_gpu else "cpu_time_total"
    print(prof.key_averages().table(sort_by=ordering, row_limit=25))
|
||
|
||
|
||
# Sentinel meaning "max_feasign_per_slot was not passed", so that an explicit
# None can be forwarded to the dataset instead of being replaced by the default.
_FEASIGN_UNSET = object()


def run_once(config_override=None, batch_size=50, max_batches=None,
             max_feasign_per_slot=_FEASIGN_UNSET, rebuild=False):
    """Run one local inference pass, score it, and return the score dict.

    Applies ``config_override`` to ``infer.CONFIG`` (and sets
    ``sync_timing``), collates the batches on the CPU, loads the model, times
    the forward pass plus sigmoid of every batch, writes one prediction per
    line of ``test.csv`` to ``predict.txt`` next to this file, and scores
    that file against ``dataset/label_data.txt`` with ``infer._cal_score``.

    Args:
        config_override: Optional dict merged into ``infer.CONFIG``.
        batch_size: ``DataLoader`` batch size.
        max_batches: Stop after this many batches; ``None`` runs all of them.
        max_feasign_per_slot: Forwarded to ``infer.CTRTestSeqDataset``.
            Defaults to ``{1: 2}`` when omitted. An explicit ``None`` is
            forwarded unchanged; previously it was silently replaced by
            ``{1: 2}``, which made the ``--feasign-none`` CLI flag a no-op.
            NOTE(review): assumes the dataset accepts ``None`` as "no
            truncation", as that flag's help text states — confirm in infer.
        rebuild: Forwarded to ``_get_data`` to force a cache rebuild.

    Returns:
        The dict returned by ``infer._cal_score``.
    """
    import gc

    if config_override is None:
        config_override = {}
    if max_feasign_per_slot is _FEASIGN_UNSET:
        max_feasign_per_slot = {1: 2}

    infer.CONFIG.update(config_override)
    infer.CONFIG["sync_timing"] = True

    cur = Path(__file__).parent
    ref = cur / "dataset"
    test_csv = ref / "test.csv"
    label_file = ref / "label_data.txt"

    item_dict, user_seq = _get_data(cur, ref, rebuild=rebuild)
    test_logids = infer.load_logids_from_file(test_csv)
    ds = infer.CTRTestSeqDataset(
        test_logids_ordered=list(test_logids), item_dict=item_dict,
        user_seq=user_seq, max_feasign_per_slot=max_feasign_per_slot, max_ctx_len=None,
    )
    loader = DataLoader(
        ds, batch_size=batch_size, shuffle=False, num_workers=0,
        collate_fn=infer.make_collate_fn(ds.max_slot_id),
    )
    batches = []
    for b in loader:
        batches.append(infer.move_batch_to_device(b, torch.device("cpu")))
        if max_batches is not None and len(batches) >= max_batches:
            break

    # Drop the raw data before the model is loaded.
    del item_dict, user_seq, ds, loader
    gc.collect()

    model, dev = infer.load_model(ckpt_path=None)

    logid2p = {}
    t_sum = 0.0
    cuda = (dev.type == "cuda")
    with torch.inference_mode():
        for b in batches:
            b = infer.move_batch_to_device(b, dev)
            pm = b["pred_mask"].bool()
            if cuda:
                # Exclude the host-to-device copy from the timed region.
                torch.cuda.synchronize()
            # perf_counter is monotonic, unlike time.time(), so system clock
            # adjustments cannot distort the measured latency.
            t0 = time.perf_counter()
            logits, _ = model(b)
            probs = torch.sigmoid(logits.squeeze(-1))
            if cuda:
                torch.cuda.synchronize()
            t_sum += time.perf_counter() - t0
            for lid, p in zip(b["logid"][pm].cpu().tolist(), probs[pm].cpu().tolist()):
                logid2p[lid] = p

    # Predictions are written in the row order of test.csv.
    with open(test_csv) as f:
        order = [int(row.split(",")[0]) for row in f if row.strip()]
    missing = [lid for lid in order if lid not in logid2p]
    if missing:
        print(f"[BENCH][WARN] {len(missing)} 个测试 logid 没预测到(前几个 {missing[:5]})")
    pred_path = cur / "predict.txt"
    with open(pred_path, "w") as f:
        for lid in order:
            # Logids without a prediction are written as 0.0.
            f.write(f"{logid2p.get(lid, 0.0)}\n")

    res = infer._cal_score(pred_path, label_file, default_latency=t_sum)
    print(
        f"[BENCH] cfg={config_override} bs={batch_size}"
        f"{'' if max_batches is None else f' (first {max_batches} batches)'}"
        f" -> AUC={res['auc']:.5f} PCOC={res['pcoc']:.4f}"
        f" lat={res['latency']:.2f}s score={res['score_all']:.2f}"
    )
    return res
|
||
|
||
|
||
def _parse_args():
    """Parse the benchmark's command-line options from ``sys.argv``.

    Returns:
        The ``argparse.Namespace`` holding every option defined below.
    """
    import argparse

    parser = argparse.ArgumentParser(
        description="CTI 推理测量闭环(子进程跑:!python bench.py ...)",
    )
    add = parser.add_argument

    add("--diag", action="store_true", help="只跑诊断,不推理")
    add("--smoke", type=int, default=None, help="只跑前 N 个 batch(冒烟)")
    add("--bs", type=int, default=50, help="batch_size(本地参考)")
    add("--fp32", action="store_true", help="FP32 天花板 = 关 fp16 + 关 expert 合并")
    add("--no-fp16", action="store_true", help="关闭半精度")
    add("--no-merge", action="store_true", help="关闭 expert 合并")
    add("--signid", choices=["clamp", "modulo"], default=None, help="sign-id 处理方式")
    add("--merge-th", type=float, default=None, help="expert 合并余弦阈值")
    add(
        "--keep", type=str, default=None,
        help="逗号分隔的 keep_fp32_modules,如 linear,rep_encoder.input_norm",
    )
    add(
        "--feasign-none", action="store_true",
        help="不截断特征(max_feasign_per_slot=None)",
    )
    add(
        "--attn", choices=["sdpa", "chunked", "flex", "varlen"], default=None,
        help="注意力:sdpa=稠密, chunked=按用户分块SDPA, flex/varlen=对照",
    )
    add("--chunk-users", type=int, default=None, help="chunked 模式每块用户数")
    add(
        "--moe", choices=["dense", "loop"], default=None,
        help="MoE实现:dense=向量化(新), loop=逐expert循环(原)",
    )
    add("--compile", action="store_true", help="开启 torch.compile")
    add("--emb-fp16", action="store_true", help="Embedding表转FP16(查表带宽减半,测AUC)")
    add("--dedup-emb", action="store_true", help="查表前对sign去重(减少大表随机访存)")
    add(
        "--profile", type=int, default=None, metavar="N",
        help="剖析前 N 个 batch,打印按 CUDA 耗时排序的算子表(定位瓶颈)",
    )
    add("--rebuild", action="store_true", help="强制重建过滤缓存")

    return parser.parse_args()
|
||
|
||
|
||
if __name__ == "__main__":
    # Command-line entry point: translate the parsed flags into an
    # infer.CONFIG override and dispatch to diag / profile / a scored run.
    args = _parse_args()

    if args.diag:
        run_diag(rebuild=args.rebuild)
        sys.exit(0)

    overrides = {}
    # --fp32 is shorthand for --no-fp16 plus --no-merge.
    if args.fp32 or args.no_fp16:
        overrides["fp16"] = False
    if args.fp32 or args.no_merge:
        overrides["expert_merge"] = False
    if args.signid:
        overrides["signid_mode"] = args.signid
    if args.merge_th is not None:
        overrides["merge_threshold"] = args.merge_th
    if args.keep is not None:
        # Empty entries (e.g. from a trailing comma) are dropped.
        overrides["keep_fp32_modules"] = tuple(
            name for name in args.keep.split(",") if name
        )
    if args.attn is not None:
        overrides["attn"] = args.attn
    if args.chunk_users is not None:
        overrides["chunk_users"] = args.chunk_users
    if args.moe is not None:
        overrides["vectorize_moe"] = (args.moe == "dense")
    if args.emb_fp16:
        overrides["emb_fp16"] = True
    if args.dedup_emb:
        overrides["dedup_embedding"] = True
    if args.compile:
        overrides["compile"] = True

    if args.profile is not None:
        run_profile(overrides, n=args.profile, batch_size=args.bs, rebuild=args.rebuild)
        sys.exit(0)

    feasign_limit = None if args.feasign_none else {1: 2}
    run_once(
        overrides,
        batch_size=args.bs,
        max_batches=args.smoke,
        max_feasign_per_slot=feasign_limit,
        rebuild=args.rebuild,
    )
|