diff --git a/代码/code/bench.py b/代码/code/bench.py index 950da96..201c70b 100644 --- a/代码/code/bench.py +++ b/代码/code/bench.py @@ -104,29 +104,61 @@ def _load_filtered(history_dir, test_csv, test_users): return item_dict, user_seq -def _get_data(cur, ref, rebuild=False): - """取过滤后的 (item_dict, user_seq),优先读磁盘缓存。 +def _cache_path(cur): + return cur / "bench_filtered_cache.pkl" - 用 pickle 而非 torch.save/load:AI Studio overlay 文件系统对 torch 的 - zip/mmap 读取会间歇性报 [Errno 38] Function not implemented。 - """ - import pickle - cache = cur / "bench_filtered_cache.pkl" + +def _build_filtered(ref): test_csv = ref / "test.csv" history = ref / "history" + test_users = _test_user_ids(test_csv) + return _load_filtered(history, test_csv, test_users) + + +def _load_cache(cache): + import pickle + with open(cache, "rb") as f: + d = pickle.load(f) + return d["item_dict"], d["user_seq"] + + +def _save_cache(cache, item_dict, user_seq): + """原子写 + fsync + 写后校验;任何异常都不留毒文件。 + + 用 pickle 而非 torch.save:AI Studio overlay 文件系统对 torch 的 zip/mmap + 读取会间歇性报 [Errno 38]。pickle.dump 大对象较慢但顺序写更稳。 + """ + import pickle + tmp = str(cache) + ".tmp" + try: + with open(tmp, "wb") as f: + pickle.dump({"item_dict": item_dict, "user_seq": user_seq}, f, + protocol=pickle.HIGHEST_PROTOCOL) + f.flush() + os.fsync(f.fileno()) + os.replace(tmp, cache) + _load_cache(cache) # 写后立即校验可读 + print(f"[BENCH] 已缓存 -> {cache}") + except Exception as e: + print(f"[BENCH][WARN] 缓存写入失败({e}),本次不缓存(不影响结果)") + for p in (tmp, str(cache)): + try: + os.remove(p) + except OSError: + pass + + +def _get_data(cur, ref, rebuild=False): + """取过滤后的 (item_dict, user_seq),优先读磁盘缓存。""" + cache = _cache_path(cur) if cache.exists() and not rebuild: print(f"[BENCH] 读取过滤缓存:{cache}") try: - with open(cache, "rb") as f: - d = pickle.load(f) - return d["item_dict"], d["user_seq"] + return _load_cache(cache) except Exception as e: print(f"[BENCH][WARN] 缓存读取失败({e}),重新构建") - test_users = _test_user_ids(test_csv) - item_dict, user_seq = _load_filtered(history, test_csv, test_users) - with open(cache, "wb") as f: - pickle.dump({"item_dict": item_dict, "user_seq": user_seq}, f, protocol=4) - print(f"[BENCH] 已缓存 -> {cache}") + item_dict, user_seq = _build_filtered(ref) + _save_cache(cache, item_dict, user_seq) return item_dict, user_seq @@ -211,10 +243,25 @@ def run_once(config_override=None, batch_size=50, max_batches=None, def run_diag(rebuild=False): - """诊断:测试用户序列长度分布 + sign-id 是否超界(判断上下文与 modulo 的价值)。""" + """诊断:测试用户序列长度分布 + sign-id 是否超界(判断上下文与 modulo 的价值)。 + + 先打印诊断,再写缓存——避免缓存写入卡住时看不到诊断结果。 + """ cur = Path(__file__).parent ref = cur / "dataset" - item_dict, user_seq = _get_data(cur, ref, rebuild=rebuild) + cache = _cache_path(cur) + loaded = False + item_dict = user_seq = None + if cache.exists() and not rebuild: + print(f"[BENCH] 读取过滤缓存:{cache}") + try: + item_dict, user_seq = _load_cache(cache) + loaded = True + except Exception as e: + print(f"[BENCH][WARN] 缓存读取失败({e}),重新构建") + if not loaded: + item_dict, user_seq = _build_filtered(ref) + lens = np.array([len(v) for v in user_seq.values()]) if user_seq else np.array([0]) print(f"[DIAG] 测试用户数={len(user_seq)} 总记录数={len(item_dict)}") print(f"[DIAG] 每用户序列长度 min/median/mean/max = " @@ -235,6 +282,9 @@ def run_diag(rebuild=False): f"超界sign占比={over}/{tot}={(over / max(tot, 1)):.2%} " f"(占比高=clamp 在污染 embedding → modulo 可能找回 AUC)") + if not loaded: + _save_cache(_cache_path(cur), item_dict, user_seq) + def _parse_args(): import argparse