feat/auc-recovery-plan #1

Merged
Serendipity merged 20 commits from feat/auc-recovery-plan into main 2026-06-15 12:33:32 +08:00
Showing only changes of commit e7b542a389 - Show all commits
+67 -17
View File
@@ -104,29 +104,61 @@ def _load_filtered(history_dir, test_csv, test_users):
return item_dict, user_seq
def _get_data(cur, ref, rebuild=False):
"""取过滤后的 (item_dict, user_seq),优先读磁盘缓存。
def _cache_path(cur):
    """Return the location of the filtered-data pickle cache under ``cur``."""
    cache_name = "bench_filtered_cache.pkl"
    return cur / cache_name
用 pickle 而非 torch.save/load:AI Studio overlay 文件系统对 torch 的
zip/mmap 读取会间歇性报 [Errno 38] Function not implemented。
"""
import pickle
cache = cur / "bench_filtered_cache.pkl"
def _build_filtered(ref):
    """Build the filtered data from the raw dataset, bypassing any cache.

    Args:
        ref: Dataset root; ``test.csv`` and ``history`` are resolved under it.

    Returns:
        The result of ``_load_filtered`` restricted to the test users, which
        callers unpack as ``(item_dict, user_seq)``.
    """
    test_csv = ref / "test.csv"
    history = ref / "history"
    # The user ids found in test.csv decide which history records are kept.
    test_users = _test_user_ids(test_csv)
    return _load_filtered(history, test_csv, test_users)
def _load_cache(cache):
    """Read the pickled cache file and return ``(item_dict, user_seq)``.

    Any error from opening, unpickling, or a missing key propagates to the
    caller, which decides whether to fall back to a rebuild.
    """
    import pickle

    with open(cache, "rb") as fh:
        payload = pickle.load(fh)
    item_dict = payload["item_dict"]
    user_seq = payload["user_seq"]
    return item_dict, user_seq
def _save_cache(cache, item_dict, user_seq):
    """Write the cache atomically, verify it, and never leave a bad file behind.

    The payload is pickled to ``<cache>.tmp``, flushed and fsynced, moved over
    ``cache`` with ``os.replace``, and then read back with ``_load_cache``.
    If any step fails, both the temp file and the cache file are removed.

    pickle is used instead of torch.save because, per the original author's
    note, the AI Studio overlay filesystem intermittently raises [Errno 38]
    on torch's zip/mmap reads; dumping a large object with pickle is slower
    but writes sequentially.

    Args:
        cache: Destination path of the cache file.
        item_dict: Stored under the ``"item_dict"`` key.
        user_seq: Stored under the ``"user_seq"`` key.

    Returns:
        None. Ordinary failures are reported with a warning print and are not
        raised. Non-``Exception`` interruptions such as ``KeyboardInterrupt``
        propagate, but only after the partial files have been cleaned up.
    """
    import pickle

    tmp = str(cache) + ".tmp"
    committed = False
    try:
        with open(tmp, "wb") as f:
            pickle.dump({"item_dict": item_dict, "user_seq": user_seq}, f,
                        protocol=pickle.HIGHEST_PROTOCOL)
            # Push the bytes to disk before the rename makes them visible.
            f.flush()
            os.fsync(f.fileno())
        os.replace(tmp, cache)
        _load_cache(cache)  # read back immediately to prove the file is usable
        print(f"[BENCH] 已缓存 -> {cache}")
        committed = True
    except Exception as e:
        print(f"[BENCH][WARN] 缓存写入失败({e}),本次不缓存(不影响结果)")
    finally:
        # Cleanup lives in ``finally`` so that an interrupt during the slow
        # dump or the read-back also removes the temp and unverified files.
        if not committed:
            for p in (tmp, str(cache)):
                try:
                    os.remove(p)
                except OSError:
                    # Best effort: the file may not exist at this stage.
                    pass
def _get_data(cur, ref, rebuild=False):
    """Return the filtered ``(item_dict, user_seq)``, preferring the disk cache.

    Args:
        cur: Directory that holds the cache file (see ``_cache_path``).
        ref: Dataset root passed to ``_build_filtered`` when a rebuild is needed.
        rebuild: When true, ignore an existing cache and rebuild from the dataset.

    Returns:
        ``(item_dict, user_seq)``, either loaded from the cache or freshly built.
    """
    cache = _cache_path(cur)
    if cache.exists() and not rebuild:
        print(f"[BENCH] 读取过滤缓存:{cache}")
        try:
            return _load_cache(cache)
        except Exception as e:
            # Any unreadable cache is treated as a miss: fall through and rebuild.
            print(f"[BENCH][WARN] 缓存读取失败({e}),重新构建")
    item_dict, user_seq = _build_filtered(ref)
    # All cache writes go through _save_cache (atomic write plus read-back check).
    _save_cache(cache, item_dict, user_seq)
    return item_dict, user_seq
@@ -211,10 +243,25 @@ def run_once(config_override=None, batch_size=50, max_batches=None,
def run_diag(rebuild=False):
"""诊断:测试用户序列长度分布 + sign-id 是否超界(判断上下文与 modulo 的价值)。"""
"""诊断:测试用户序列长度分布 + sign-id 是否超界(判断上下文与 modulo 的价值)。
先打印诊断,再写缓存——避免缓存写入卡住时看不到诊断结果。
"""
cur = Path(__file__).parent
ref = cur / "dataset"
item_dict, user_seq = _get_data(cur, ref, rebuild=rebuild)
cache = _cache_path(cur)
loaded = False
item_dict = user_seq = None
if cache.exists() and not rebuild:
print(f"[BENCH] 读取过滤缓存:{cache}")
try:
item_dict, user_seq = _load_cache(cache)
loaded = True
except Exception as e:
print(f"[BENCH][WARN] 缓存读取失败({e}),重新构建")
if not loaded:
item_dict, user_seq = _build_filtered(ref)
lens = np.array([len(v) for v in user_seq.values()]) if user_seq else np.array([0])
print(f"[DIAG] 测试用户数={len(user_seq)} 总记录数={len(item_dict)}")
print(f"[DIAG] 每用户序列长度 min/median/mean/max = "
@@ -235,6 +282,9 @@ def run_diag(rebuild=False):
f"超界sign占比={over}/{tot}={(over / max(tot, 1)):.2%} "
f"(占比高=clamp 在污染 embedding → modulo 可能找回 AUC")
if not loaded:
_save_cache(_cache_path(cur), item_dict, user_seq)
def _parse_args():
import argparse