diff --git a/代码/code/infer.py b/代码/code/infer.py index 803fa2a..0fb79a6 100644 --- a/代码/code/infer.py +++ b/代码/code/infer.py @@ -55,9 +55,8 @@ CONFIG = { "dedup_embedding": True, # True=查表前对sign去重(只查唯一值再展开),本地7.80->6.49s,AUC逐位等价 "sparse_pool": False, # True=用(段×唯一)稀疏矩阵乘做池化,避免materialize整个[M,512](段内高重复时省) "compile": False, # 是否 torch.compile(实测慢5×,勿开) - # 预计算在评测端两次未生效(先OOM异常、后静默回退,无日志难诊断)且属合规灰区。默认关。 - # 本地 --eval-precompute 可跑通(4.07s);需重试见 RISKS.md。默认=干净合规的 ~68。 - "precompute_rep": False, # True=load_model预计算RepEncoder向量跳过embedding层(评测端未生效+灰区) + # 预计算改为"捕获评测端 item_dict"(不猜路径/不重载/max_feasign必一致/gather必命中),根治回退。 + "precompute_rep": True, # True=load_model预计算RepEncoder向量跳过embedding层(灰区,评测真生效) } @@ -76,6 +75,11 @@ def _resolve_attn(device): return attn +# 捕获评测端调用 load_sample_files / CTRTestSeqDataset 时传入的真实数据, +# 供 load_model 预计算 RepEncoder 缓存(避免猜路径/重载/OOM/max_feasign 不一致)。 +_CAPTURED = {"item_dict": None, "keep_users": None, "max_feasign": None} + + def _force_fp32_io(module): """让某个模块在 FP16 模型里以 FP32 计算:输入转 FP32、输出转回 FP16。 用于 keep_fp32_modules 指定的精度敏感层(如最终输出头、LayerNorm)。""" @@ -180,6 +184,7 @@ def load_sample_files(sample_files_list): user_seq[userid] = [logid for _, logid in logs] print(f'[INFO] loaded {len(item_dict)} records, {len(user_seq)} users') + _CAPTURED["item_dict"] = item_dict # 捕获供 load_model 预计算 return item_dict, user_seq @@ -214,6 +219,9 @@ class CTRTestSeqDataset(Dataset): if CONFIG.get("filter_test_users", True) and self.pred_logids: keep_users = {rec['userid'] for logid, rec in item_dict.items() if logid in self.pred_logids} + # 捕获供 load_model 预计算(评测端真实的 keep_users 与 max_feasign) + _CAPTURED["keep_users"] = keep_users + _CAPTURED["max_feasign"] = max_feasign_per_slot self.user_items = defaultdict(list) max_sign = 0 @@ -882,25 +890,33 @@ def load_model(ckpt_path, device='cuda:0'): f"moe={'dense' if CONFIG.get('vectorize_moe', True) else 'loop'}") # === 预计算 RepEncoder 缓存(不计时阶段)=== + # 优先用"捕获的评测端 item_dict"(不猜路径、不重载、max_feasign 必一致、gather 必命中); + # 捕获不到才退而流式加载 dataset/。任何异常都回退 in-batch RepEncoder。 if CONFIG.get("precompute_rep", False) and model._rep_cache is None: try: - ds_dir = None - for cand in (Path(ckpt_path).parent / "dataset", Path("dataset"), - Path(__file__).parent / "dataset"): - if cand.exists(): - ds_dir = cand - break - if ds_dir is not None: - # 流式只加载测试用户的 item(避免全量 OOM),算完即释放 - item_dict = _load_test_user_items(ds_dir) - build_rep_cache(model, item_dict, {1: 2}, dev) - n_items = model._rep_cache[0].numel() - del item_dict - import gc - gc.collect() - print(f"[INFO] rep cache built (stream-filtered): {n_items} items") + item_dict = _CAPTURED.get("item_dict") + mf = _CAPTURED.get("max_feasign") or {1: 2} + source = "captured" + if item_dict is None: # 没捕获到 → 退而流式加载 dataset/ + ds_dir = None + for cand in (Path(ckpt_path).parent / "dataset", Path("dataset"), + Path(__file__).parent / "dataset"): + if cand.exists(): + ds_dir = cand + break + if ds_dir is not None: + item_dict = _load_test_user_items(ds_dir) + source = "stream-loaded" + if item_dict is not None: + keep = _CAPTURED.get("keep_users") + if keep is not None and source == "captured": # 捕获的全量 item_dict → 过滤到测试用户 + item_dict = {l: r for l, r in item_dict.items() + if r.get("userid") in keep} + build_rep_cache(model, item_dict, mf, dev) + print(f"[INFO] rep cache built ({source}, mf={mf}): " + f"{model._rep_cache[0].numel()} items") else: - print("[INFO] dataset/ not found, skip rep precompute (fallback to in-batch)") + print("[INFO] no data to precompute, fallback to in-batch RepEncoder") except Exception as e: print(f"[WARNING] rep precompute failed ({e}), fallback to in-batch RepEncoder") model._rep_cache = None