fix: 缓存原子写+fsync+校验,diag 先打印再缓存(防卡住看不到诊断)
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
+67
-17
@@ -104,29 +104,61 @@ def _load_filtered(history_dir, test_csv, test_users):
|
|||||||
return item_dict, user_seq
|
return item_dict, user_seq
|
||||||
|
|
||||||
|
|
||||||
def _get_data(cur, ref, rebuild=False):
|
def _cache_path(cur):
|
||||||
"""取过滤后的 (item_dict, user_seq),优先读磁盘缓存。
|
return cur / "bench_filtered_cache.pkl"
|
||||||
|
|
||||||
用 pickle 而非 torch.save/load:AI Studio overlay 文件系统对 torch 的
|
|
||||||
zip/mmap 读取会间歇性报 [Errno 38] Function not implemented。
|
def _build_filtered(ref):
|
||||||
"""
|
|
||||||
import pickle
|
|
||||||
cache = cur / "bench_filtered_cache.pkl"
|
|
||||||
test_csv = ref / "test.csv"
|
test_csv = ref / "test.csv"
|
||||||
history = ref / "history"
|
history = ref / "history"
|
||||||
if cache.exists() and not rebuild:
|
test_users = _test_user_ids(test_csv)
|
||||||
print(f"[BENCH] 读取过滤缓存:{cache}")
|
return _load_filtered(history, test_csv, test_users)
|
||||||
try:
|
|
||||||
|
|
||||||
|
def _load_cache(cache):
|
||||||
|
import pickle
|
||||||
with open(cache, "rb") as f:
|
with open(cache, "rb") as f:
|
||||||
d = pickle.load(f)
|
d = pickle.load(f)
|
||||||
return d["item_dict"], d["user_seq"]
|
return d["item_dict"], d["user_seq"]
|
||||||
|
|
||||||
|
|
||||||
|
def _save_cache(cache, item_dict, user_seq):
|
||||||
|
"""原子写 + fsync + 写后校验;任何异常都不留毒文件。
|
||||||
|
|
||||||
|
用 pickle 而非 torch.save:AI Studio overlay 文件系统对 torch 的 zip/mmap
|
||||||
|
读取会间歇性报 [Errno 38]。pickle.dump 大对象较慢但顺序写更稳。
|
||||||
|
"""
|
||||||
|
import pickle
|
||||||
|
tmp = str(cache) + ".tmp"
|
||||||
|
try:
|
||||||
|
with open(tmp, "wb") as f:
|
||||||
|
pickle.dump({"item_dict": item_dict, "user_seq": user_seq}, f,
|
||||||
|
protocol=pickle.HIGHEST_PROTOCOL)
|
||||||
|
f.flush()
|
||||||
|
os.fsync(f.fileno())
|
||||||
|
os.replace(tmp, cache)
|
||||||
|
_load_cache(cache) # 写后立即校验可读
|
||||||
|
print(f"[BENCH] 已缓存 -> {cache}")
|
||||||
|
except Exception as e:
|
||||||
|
print(f"[BENCH][WARN] 缓存写入失败({e}),本次不缓存(不影响结果)")
|
||||||
|
for p in (tmp, str(cache)):
|
||||||
|
try:
|
||||||
|
os.remove(p)
|
||||||
|
except OSError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def _get_data(cur, ref, rebuild=False):
|
||||||
|
"""取过滤后的 (item_dict, user_seq),优先读磁盘缓存。"""
|
||||||
|
cache = _cache_path(cur)
|
||||||
|
if cache.exists() and not rebuild:
|
||||||
|
print(f"[BENCH] 读取过滤缓存:{cache}")
|
||||||
|
try:
|
||||||
|
return _load_cache(cache)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"[BENCH][WARN] 缓存读取失败({e}),重新构建")
|
print(f"[BENCH][WARN] 缓存读取失败({e}),重新构建")
|
||||||
test_users = _test_user_ids(test_csv)
|
item_dict, user_seq = _build_filtered(ref)
|
||||||
item_dict, user_seq = _load_filtered(history, test_csv, test_users)
|
_save_cache(cache, item_dict, user_seq)
|
||||||
with open(cache, "wb") as f:
|
|
||||||
pickle.dump({"item_dict": item_dict, "user_seq": user_seq}, f, protocol=4)
|
|
||||||
print(f"[BENCH] 已缓存 -> {cache}")
|
|
||||||
return item_dict, user_seq
|
return item_dict, user_seq
|
||||||
|
|
||||||
|
|
||||||
@@ -211,10 +243,25 @@ def run_once(config_override=None, batch_size=50, max_batches=None,
|
|||||||
|
|
||||||
|
|
||||||
def run_diag(rebuild=False):
|
def run_diag(rebuild=False):
|
||||||
"""诊断:测试用户序列长度分布 + sign-id 是否超界(判断上下文与 modulo 的价值)。"""
|
"""诊断:测试用户序列长度分布 + sign-id 是否超界(判断上下文与 modulo 的价值)。
|
||||||
|
|
||||||
|
先打印诊断,再写缓存——避免缓存写入卡住时看不到诊断结果。
|
||||||
|
"""
|
||||||
cur = Path(__file__).parent
|
cur = Path(__file__).parent
|
||||||
ref = cur / "dataset"
|
ref = cur / "dataset"
|
||||||
item_dict, user_seq = _get_data(cur, ref, rebuild=rebuild)
|
cache = _cache_path(cur)
|
||||||
|
loaded = False
|
||||||
|
item_dict = user_seq = None
|
||||||
|
if cache.exists() and not rebuild:
|
||||||
|
print(f"[BENCH] 读取过滤缓存:{cache}")
|
||||||
|
try:
|
||||||
|
item_dict, user_seq = _load_cache(cache)
|
||||||
|
loaded = True
|
||||||
|
except Exception as e:
|
||||||
|
print(f"[BENCH][WARN] 缓存读取失败({e}),重新构建")
|
||||||
|
if not loaded:
|
||||||
|
item_dict, user_seq = _build_filtered(ref)
|
||||||
|
|
||||||
lens = np.array([len(v) for v in user_seq.values()]) if user_seq else np.array([0])
|
lens = np.array([len(v) for v in user_seq.values()]) if user_seq else np.array([0])
|
||||||
print(f"[DIAG] 测试用户数={len(user_seq)} 总记录数={len(item_dict)}")
|
print(f"[DIAG] 测试用户数={len(user_seq)} 总记录数={len(item_dict)}")
|
||||||
print(f"[DIAG] 每用户序列长度 min/median/mean/max = "
|
print(f"[DIAG] 每用户序列长度 min/median/mean/max = "
|
||||||
@@ -235,6 +282,9 @@ def run_diag(rebuild=False):
|
|||||||
f"超界sign占比={over}/{tot}={(over / max(tot, 1)):.2%} "
|
f"超界sign占比={over}/{tot}={(over / max(tot, 1)):.2%} "
|
||||||
f"(占比高=clamp 在污染 embedding → modulo 可能找回 AUC)")
|
f"(占比高=clamp 在污染 embedding → modulo 可能找回 AUC)")
|
||||||
|
|
||||||
|
if not loaded:
|
||||||
|
_save_cache(_cache_path(cur), item_dict, user_seq)
|
||||||
|
|
||||||
|
|
||||||
def _parse_args():
|
def _parse_args():
|
||||||
import argparse
|
import argparse
|
||||||
|
|||||||
Reference in New Issue
Block a user