diff --git a/代码/code/bench.py b/代码/code/bench.py index 54c9da3..6ac3c9d 100644 --- a/代码/code/bench.py +++ b/代码/code/bench.py @@ -200,9 +200,36 @@ def run_once(config_override=None, batch_size=50, max_batches=None, return res +def run_diag(rebuild=False): + """诊断:测试用户序列长度分布 + sign-id 是否超界(判断上下文与 modulo 的价值)。""" + cur = Path(__file__).parent + ref = cur / "dataset" + item_dict, user_seq = _get_data(cur, ref, rebuild=rebuild) + lens = np.array([len(v) for v in user_seq.values()]) if user_seq else np.array([0]) + print(f"[DIAG] 测试用户数={len(user_seq)} 总记录数={len(item_dict)}") + print(f"[DIAG] 每用户序列长度 min/median/mean/max = " + f"{int(lens.min())}/{int(np.median(lens))}/{lens.mean():.1f}/{int(lens.max())}") + print(f"[DIAG] 序列长度>1 的用户占比 = {(lens > 1).mean():.1%} " + f"(占比低=大量测试样本没有历史上下文 → 生成式模型发挥不出来)") + VOCAB = 5_000_000 + mx, over, tot = 0, 0, 0 + for rec in item_dict.values(): + s = rec["signs"] + if s.size: + m = int(s.max()) + if m > mx: + mx = m + over += int((s >= VOCAB).sum()) + tot += int(s.size) + print(f"[DIAG] max_sign_id={mx} vocab={VOCAB} " + f"超界sign占比={over}/{tot}={(over / max(tot, 1)):.2%} " + f"(占比高=clamp 在污染 embedding → modulo 可能找回 AUC)") + + def _parse_args(): import argparse ap = argparse.ArgumentParser(description="CTI 推理测量闭环(子进程跑:!python bench.py ...)") + ap.add_argument("--diag", action="store_true", help="只跑诊断(序列长度分布 + sign-id 超界比例),不推理") ap.add_argument("--smoke", type=int, default=None, help="只跑前 N 个 batch(冒烟)") ap.add_argument("--bs", type=int, default=50, help="batch_size(本地参考)") ap.add_argument("--fp32", action="store_true", help="FP32 天花板 = 关 fp16 + 关 expert 合并") @@ -220,6 +247,9 @@ def _parse_args(): if __name__ == "__main__": a = _parse_args() + if a.diag: + run_diag(rebuild=a.rebuild) + sys.exit(0) cfg = {} if a.fp32: cfg["fp16"] = False