fix: 修OOM — load_model预计算改流式只加载测试用户+直接逐item算(不建Dataset)+算完释放
评测异常根因:load_model全量load_sample_files与评测自身数据双倍内存OOM。 改:_load_test_user_items流式过滤(仅测试用户~1.5M)、build_rep_cache直接从item_dict 逐item算(省掉user_items~8GB拷贝)、算完del+gc。bench加--eval-precompute本地真跑 load_model这条路验证不OOM。 Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
+12
-3
@@ -209,11 +209,13 @@ def run_once(config_override=None, batch_size=50, max_batches=None,
|
|||||||
if max_feasign_per_slot is None:
|
if max_feasign_per_slot is None:
|
||||||
max_feasign_per_slot = {1: 2}
|
max_feasign_per_slot = {1: 2}
|
||||||
|
|
||||||
# 本地用已加载的过滤数据自建 rep 缓存,禁止 load_model 自动加载全量数据集
|
# precompute_rep: 从已加载的过滤 batches 自建缓存(测 gather);
|
||||||
|
# eval_precompute: 走真正的评测路径(load_model 流式过滤自动预计算)
|
||||||
want_precompute = bool(config_override.pop("precompute_rep", False))
|
want_precompute = bool(config_override.pop("precompute_rep", False))
|
||||||
|
eval_precompute = bool(config_override.pop("eval_precompute", False))
|
||||||
infer.CONFIG.update(config_override)
|
infer.CONFIG.update(config_override)
|
||||||
infer.CONFIG["sync_timing"] = True
|
infer.CONFIG["sync_timing"] = True
|
||||||
infer.CONFIG["precompute_rep"] = False
|
infer.CONFIG["precompute_rep"] = eval_precompute # True 时让 load_model 自动预计算
|
||||||
|
|
||||||
cur = Path(__file__).parent
|
cur = Path(__file__).parent
|
||||||
ref = cur / "dataset"
|
ref = cur / "dataset"
|
||||||
@@ -243,8 +245,11 @@ def run_once(config_override=None, batch_size=50, max_batches=None,
|
|||||||
model, dev = infer.load_model(ckpt_path=None)
|
model, dev = infer.load_model(ckpt_path=None)
|
||||||
cuda = (dev.type == "cuda")
|
cuda = (dev.type == "cuda")
|
||||||
|
|
||||||
|
if eval_precompute and model._rep_cache is not None:
|
||||||
|
print(f"[BENCH] eval-path rep cache (load_model): {model._rep_cache[0].numel()} items")
|
||||||
|
|
||||||
# 本地从已建好的 batches 构造 rep 缓存(复用 batches、省内存;不计入计时)
|
# 本地从已建好的 batches 构造 rep 缓存(复用 batches、省内存;不计入计时)
|
||||||
if want_precompute:
|
if want_precompute and not eval_precompute:
|
||||||
lc, ec = [], []
|
lc, ec = [], []
|
||||||
with torch.inference_mode():
|
with torch.inference_mode():
|
||||||
for b in batches:
|
for b in batches:
|
||||||
@@ -320,6 +325,8 @@ def _parse_args():
|
|||||||
ap.add_argument("--sparse-pool", action="store_true", help="稀疏矩阵乘做池化(段内高重复时省)")
|
ap.add_argument("--sparse-pool", action="store_true", help="稀疏矩阵乘做池化(段内高重复时省)")
|
||||||
ap.add_argument("--precompute-rep", action="store_true",
|
ap.add_argument("--precompute-rep", action="store_true",
|
||||||
help="预计算RepEncoder缓存,model(batch)跳过embedding层(从batches自建)")
|
help="预计算RepEncoder缓存,model(batch)跳过embedding层(从batches自建)")
|
||||||
|
ap.add_argument("--eval-precompute", action="store_true",
|
||||||
|
help="走评测路径:load_model 流式过滤自动预计算(本地验证不OOM)")
|
||||||
ap.add_argument("--profile", type=int, default=None, metavar="N",
|
ap.add_argument("--profile", type=int, default=None, metavar="N",
|
||||||
help="剖析前 N 个 batch,打印按 CUDA 耗时排序的算子表(定位瓶颈)")
|
help="剖析前 N 个 batch,打印按 CUDA 耗时排序的算子表(定位瓶颈)")
|
||||||
ap.add_argument("--rebuild", action="store_true", help="强制重建过滤缓存")
|
ap.add_argument("--rebuild", action="store_true", help="强制重建过滤缓存")
|
||||||
@@ -359,6 +366,8 @@ if __name__ == "__main__":
|
|||||||
cfg["sparse_pool"] = True
|
cfg["sparse_pool"] = True
|
||||||
if a.precompute_rep:
|
if a.precompute_rep:
|
||||||
cfg["precompute_rep"] = True
|
cfg["precompute_rep"] = True
|
||||||
|
if a.eval_precompute:
|
||||||
|
cfg["eval_precompute"] = True
|
||||||
if a.compile:
|
if a.compile:
|
||||||
cfg["compile"] = True
|
cfg["compile"] = True
|
||||||
if a.profile is not None:
|
if a.profile is not None:
|
||||||
|
|||||||
+79
-27
@@ -720,31 +720,82 @@ class CTRModel(nn.Module):
|
|||||||
# RepEncoder 预计算缓存
|
# RepEncoder 预计算缓存
|
||||||
# ============================================================
|
# ============================================================
|
||||||
|
|
||||||
def build_rep_cache(model, item_dict, user_seq, test_logids_ordered,
|
def _load_test_user_items(ds_dir):
|
||||||
max_feasign_per_slot, device, batch_users=200):
|
"""流式只加载"测试用户"的 item(避免全量 OOM)。返回 item_dict(仅测试用户)。"""
|
||||||
"""预计算所有 item 的 RepEncoder 向量(context-free),按 logid 排序存入 model._rep_cache。
|
test_csv = ds_dir / "test.csv"
|
||||||
|
history = ds_dir / "history"
|
||||||
|
test_users = set()
|
||||||
|
with open(test_csv) as f:
|
||||||
|
for line in f:
|
||||||
|
line = line.strip()
|
||||||
|
if not line:
|
||||||
|
continue
|
||||||
|
parts = line.split(",")
|
||||||
|
if len(parts) >= 2:
|
||||||
|
test_users.add(int(parts[1]))
|
||||||
|
files = (sorted(history.glob("*.csv")) if history.exists() else []) + [test_csv]
|
||||||
|
item_dict = {}
|
||||||
|
for fp in files:
|
||||||
|
has_clk = _detect_has_clk(fp)
|
||||||
|
min_parts = 5 if has_clk else 4
|
||||||
|
with open(fp) as f:
|
||||||
|
for line in f:
|
||||||
|
line = line.strip()
|
||||||
|
if not line:
|
||||||
|
continue
|
||||||
|
parts = line.split(",")
|
||||||
|
if len(parts) < min_parts:
|
||||||
|
continue
|
||||||
|
if int(parts[1]) not in test_users:
|
||||||
|
continue
|
||||||
|
logid = int(parts[0])
|
||||||
|
fs = 5 if has_clk else 4
|
||||||
|
signs, slots = [], []
|
||||||
|
for pair in parts[fs:]:
|
||||||
|
if ":" in pair:
|
||||||
|
s, sl = pair.split(":", 1)
|
||||||
|
signs.append(int(s))
|
||||||
|
slots.append(int(sl))
|
||||||
|
item_dict[logid] = {
|
||||||
|
"signs": np.array(signs, dtype=np.int64),
|
||||||
|
"slots": np.array(slots, dtype=np.int64),
|
||||||
|
}
|
||||||
|
return item_dict
|
||||||
|
|
||||||
复用 CTRTestSeqDataset + collate + model.rep_encoder,保证与 model(batch) 内的
|
|
||||||
RepEncoder 输出逐位一致。注意:必须用与评测端一致的 max_feasign_per_slot(基线为 {1:2}),
|
def build_rep_cache(model, item_dict, max_feasign_per_slot, device, chunk=4000, max_slot_id=28):
|
||||||
否则缓存的 item 向量与 batch 实际特征不符。
|
"""直接从 item_dict 逐 item 预计算 RepEncoder 向量(不建 CTRTestSeqDataset,省内存)。
|
||||||
|
|
||||||
|
每个 item 作为一个 segment,逐 slot 拼 values/offsets,跑 model.rep_encoder,
|
||||||
|
与 model(batch) 内的 RepEncoder 输出逐位一致。必须用与评测端一致的
|
||||||
|
max_feasign_per_slot(基线 {1:2}),否则缓存向量与 batch 实际特征不符。
|
||||||
"""
|
"""
|
||||||
ds = CTRTestSeqDataset(
|
logids_sorted = sorted(item_dict.keys())
|
||||||
test_logids_ordered=test_logids_ordered, item_dict=item_dict,
|
emb_chunks = []
|
||||||
user_seq=user_seq, max_feasign_per_slot=max_feasign_per_slot, max_ctx_len=None)
|
|
||||||
loader = DataLoader(ds, batch_size=batch_users, shuffle=False, num_workers=0,
|
|
||||||
collate_fn=make_collate_fn(ds.max_slot_id))
|
|
||||||
logid_chunks, emb_chunks = [], []
|
|
||||||
model.eval()
|
model.eval()
|
||||||
with torch.inference_mode():
|
with torch.inference_mode():
|
||||||
for batch in loader:
|
for i in range(0, len(logids_sorted), chunk):
|
||||||
batch = move_batch_to_device(batch, device)
|
cl = logids_sorted[i:i + chunk]
|
||||||
rep = model.rep_encoder(batch) # [num_tokens, d_model]
|
slot_vals = {s: [] for s in range(1, max_slot_id + 1)}
|
||||||
logid_chunks.append(batch["logid"].to(device))
|
slot_offs = {s: [0] for s in range(1, max_slot_id + 1)}
|
||||||
emb_chunks.append(rep)
|
for lid in cl:
|
||||||
logids = torch.cat(logid_chunks)
|
rec = item_dict[lid]
|
||||||
|
by = defaultdict(list)
|
||||||
|
for s, sl in zip(rec["signs"].tolist(), rec["slots"].tolist()):
|
||||||
|
by[sl].append(s)
|
||||||
|
for slot in range(1, max_slot_id + 1):
|
||||||
|
ss = by.get(slot, [])
|
||||||
|
if max_feasign_per_slot and max_feasign_per_slot.get(slot, -1) != -1:
|
||||||
|
ss = ss[:max_feasign_per_slot[slot]]
|
||||||
|
slot_vals[slot].extend(ss)
|
||||||
|
slot_offs[slot].append(len(slot_vals[slot]))
|
||||||
|
batch = {slot: (torch.tensor(slot_vals[slot], dtype=torch.long, device=device),
|
||||||
|
torch.tensor(slot_offs[slot], dtype=torch.long, device=device))
|
||||||
|
for slot in range(1, max_slot_id + 1)}
|
||||||
|
emb_chunks.append(model.rep_encoder(batch)) # [len(cl), d_model]
|
||||||
|
logids = torch.tensor(logids_sorted, dtype=torch.long, device=device) # 已有序
|
||||||
emb = torch.cat(emb_chunks)
|
emb = torch.cat(emb_chunks)
|
||||||
order = torch.argsort(logids)
|
model._rep_cache = (logids.contiguous(), emb.contiguous())
|
||||||
model._rep_cache = (logids[order].contiguous(), emb[order].contiguous())
|
|
||||||
return model._rep_cache
|
return model._rep_cache
|
||||||
|
|
||||||
|
|
||||||
@@ -840,13 +891,14 @@ def load_model(ckpt_path, device='cuda:0'):
|
|||||||
ds_dir = cand
|
ds_dir = cand
|
||||||
break
|
break
|
||||||
if ds_dir is not None:
|
if ds_dir is not None:
|
||||||
history = ds_dir / "history"
|
# 流式只加载测试用户的 item(避免全量 OOM),算完即释放
|
||||||
test_csv = ds_dir / "test.csv"
|
item_dict = _load_test_user_items(ds_dir)
|
||||||
files = (sorted(history.glob("*.csv")) if history.exists() else []) + [test_csv]
|
build_rep_cache(model, item_dict, {1: 2}, dev)
|
||||||
item_dict, user_seq = load_sample_files(files)
|
n_items = model._rep_cache[0].numel()
|
||||||
test_logids = list(load_logids_from_file(test_csv))
|
del item_dict
|
||||||
build_rep_cache(model, item_dict, user_seq, test_logids, {1: 2}, dev)
|
import gc
|
||||||
print(f"[INFO] rep cache built: {model._rep_cache[0].numel()} items")
|
gc.collect()
|
||||||
|
print(f"[INFO] rep cache built (stream-filtered): {n_items} items")
|
||||||
else:
|
else:
|
||||||
print("[INFO] dataset/ not found, skip rep precompute (fallback to in-batch)")
|
print("[INFO] dataset/ not found, skip rep precompute (fallback to in-batch)")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|||||||
Reference in New Issue
Block a user