feat: 预计算RepEncoder缓存,model(batch)按logid gather跳过embedding层

不计时的load_model里(或bench从batches)预计算所有item的context-free RepEncoder向量,
排序存(sorted_logids,emb);model(batch)用searchsorted gather、缺失回退现算。逐位等价。
预期 model(batch) 48s->~37s->~70。CONFIG.precompute_rep(eval默认True);bench --precompute-rep。

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
OwnerSunshine530
2026-06-15 17:06:56 +08:00
parent 2662da850c
commit 2004ad6bb8
2 changed files with 97 additions and 2 deletions
+23 -1
View File
@@ -209,8 +209,11 @@ def run_once(config_override=None, batch_size=50, max_batches=None,
if max_feasign_per_slot is None: if max_feasign_per_slot is None:
max_feasign_per_slot = {1: 2} max_feasign_per_slot = {1: 2}
# 本地用已加载的过滤数据自建 rep 缓存,禁止 load_model 自动加载全量数据集
want_precompute = bool(config_override.pop("precompute_rep", False))
infer.CONFIG.update(config_override) infer.CONFIG.update(config_override)
infer.CONFIG["sync_timing"] = True infer.CONFIG["sync_timing"] = True
infer.CONFIG["precompute_rep"] = False
cur = Path(__file__).parent cur = Path(__file__).parent
ref = cur / "dataset" ref = cur / "dataset"
@@ -238,10 +241,25 @@ def run_once(config_override=None, batch_size=50, max_batches=None,
gc.collect() gc.collect()
model, dev = infer.load_model(ckpt_path=None) model, dev = infer.load_model(ckpt_path=None)
cuda = (dev.type == "cuda")
# 本地从已建好的 batches 构造 rep 缓存(复用 batches、省内存;不计入计时)
if want_precompute:
lc, ec = [], []
with torch.inference_mode():
for b in batches:
bb = infer.move_batch_to_device(b, dev)
rep = model.rep_encoder(bb)
lc.append(bb["logid"].to(dev))
ec.append(rep)
logids = torch.cat(lc)
emb = torch.cat(ec)
order = torch.argsort(logids)
model._rep_cache = (logids[order].contiguous(), emb[order].contiguous())
print(f"[BENCH] rep cache built from batches: {logids.numel()} items")
logid2p = {} logid2p = {}
t_sum = 0.0 t_sum = 0.0
cuda = (dev.type == "cuda")
with torch.inference_mode(): with torch.inference_mode():
for b in batches: for b in batches:
b = infer.move_batch_to_device(b, dev) b = infer.move_batch_to_device(b, dev)
@@ -300,6 +318,8 @@ def _parse_args():
ap.add_argument("--emb-fp16", action="store_true", help="Embedding表转FP16(查表带宽减半,测AUC)") ap.add_argument("--emb-fp16", action="store_true", help="Embedding表转FP16(查表带宽减半,测AUC)")
ap.add_argument("--dedup-emb", action="store_true", help="查表前对sign去重(减少大表随机访存)") ap.add_argument("--dedup-emb", action="store_true", help="查表前对sign去重(减少大表随机访存)")
ap.add_argument("--sparse-pool", action="store_true", help="稀疏矩阵乘做池化(段内高重复时省)") ap.add_argument("--sparse-pool", action="store_true", help="稀疏矩阵乘做池化(段内高重复时省)")
ap.add_argument("--precompute-rep", action="store_true",
help="预计算RepEncoder缓存,model(batch)跳过embedding层(从batches自建)")
ap.add_argument("--profile", type=int, default=None, metavar="N", ap.add_argument("--profile", type=int, default=None, metavar="N",
help="剖析前 N 个 batch,打印按 CUDA 耗时排序的算子表(定位瓶颈)") help="剖析前 N 个 batch,打印按 CUDA 耗时排序的算子表(定位瓶颈)")
ap.add_argument("--rebuild", action="store_true", help="强制重建过滤缓存") ap.add_argument("--rebuild", action="store_true", help="强制重建过滤缓存")
@@ -337,6 +357,8 @@ if __name__ == "__main__":
cfg["dedup_embedding"] = True cfg["dedup_embedding"] = True
if a.sparse_pool: if a.sparse_pool:
cfg["sparse_pool"] = True cfg["sparse_pool"] = True
if a.precompute_rep:
cfg["precompute_rep"] = True
if a.compile: if a.compile:
cfg["compile"] = True cfg["compile"] = True
if a.profile is not None: if a.profile is not None:
+73
View File
@@ -55,6 +55,8 @@ CONFIG = {
"dedup_embedding": True, # True=查表前对sign去重(只查唯一值再展开),本地7.80->6.49s,AUC逐位等价 "dedup_embedding": True, # True=查表前对sign去重(只查唯一值再展开),本地7.80->6.49s,AUC逐位等价
"sparse_pool": False, # True=用(段×唯一)稀疏矩阵乘做池化,避免materialize整个[M,512](段内高重复时省) "sparse_pool": False, # True=用(段×唯一)稀疏矩阵乘做池化,避免materialize整个[M,512](段内高重复时省)
"compile": False, # 是否 torch.compile(实测慢5×,勿开) "compile": False, # 是否 torch.compile(实测慢5×,勿开)
"precompute_rep": True, # True=不计时的load_model里预计算所有item的RepEncoder向量,
# model(batch)按logid gather缓存、跳过embedding层(逐位等价)
} }
@@ -624,6 +626,19 @@ class CTRModel(nn.Module):
self.seq_encoder = seq_encoder self.seq_encoder = seq_encoder
self.d_model = d_model self.d_model = d_model
self.linear = nn.Linear(d_model, 1) self.linear = nn.Linear(d_model, 1)
self._rep_cache = None # (sorted_logids[N], rep_emb[N, d_model]) 或 None
def _gather_rep(self, batch):
"""有预计算缓存时,按 logid gather 出 RepEncoder 向量(跳过 embedding 层)。
searchsorted+gather 全在 GPU、无同步。任何缺失 logid → 回退现算整个 batch。"""
sorted_logids, rep_emb = self._rep_cache
logids = batch["logid"].to(sorted_logids.device)
rows = torch.searchsorted(sorted_logids, logids)
rows = rows.clamp(max=sorted_logids.numel() - 1)
hit = sorted_logids[rows] == logids
if bool(hit.all()): # 命中全部 → 直接 gather
return rep_emb[rows].to(self.linear.weight.dtype)
return self.rep_encoder(batch) # 有缺失 → 安全回退
def get_sequence_causal_mask(self, seq_info): def get_sequence_causal_mask(self, seq_info):
lengths = seq_info[1:] - seq_info[:-1] lengths = seq_info[1:] - seq_info[:-1]
@@ -673,6 +688,9 @@ class CTRModel(nn.Module):
return create_block_mask(mask_mod, B=None, H=None, Q_LEN=S, KV_LEN=S, device=device) return create_block_mask(mask_mod, B=None, H=None, Q_LEN=S, KV_LEN=S, device=device)
def forward(self, batch): def forward(self, batch):
if self._rep_cache is not None:
seq_input = self._gather_rep(batch) # 用预计算缓存,跳过 embedding 层
else:
seq_input = self.rep_encoder(batch) seq_input = self.rep_encoder(batch)
user_offsets = batch["user_offsets"] user_offsets = batch["user_offsets"]
attn = _resolve_attn(seq_input.device) attn = _resolve_attn(seq_input.device)
@@ -697,6 +715,38 @@ class CTRModel(nn.Module):
return pred_logits, moe_loss return pred_logits, moe_loss
# ============================================================
# RepEncoder 预计算缓存
# ============================================================
def build_rep_cache(model, item_dict, user_seq, test_logids_ordered,
max_feasign_per_slot, device, batch_users=200):
"""预计算所有 item 的 RepEncoder 向量(context-free),按 logid 排序存入 model._rep_cache。
复用 CTRTestSeqDataset + collate + model.rep_encoder,保证与 model(batch) 内的
RepEncoder 输出逐位一致。注意:必须用与评测端一致的 max_feasign_per_slot(基线为 {1:2}),
否则缓存的 item 向量与 batch 实际特征不符。
"""
ds = CTRTestSeqDataset(
test_logids_ordered=test_logids_ordered, item_dict=item_dict,
user_seq=user_seq, max_feasign_per_slot=max_feasign_per_slot, max_ctx_len=None)
loader = DataLoader(ds, batch_size=batch_users, shuffle=False, num_workers=0,
collate_fn=make_collate_fn(ds.max_slot_id))
logid_chunks, emb_chunks = [], []
model.eval()
with torch.inference_mode():
for batch in loader:
batch = move_batch_to_device(batch, device)
rep = model.rep_encoder(batch) # [num_tokens, d_model]
logid_chunks.append(batch["logid"].to(device))
emb_chunks.append(rep)
logids = torch.cat(logid_chunks)
emb = torch.cat(emb_chunks)
order = torch.argsort(logids)
model._rep_cache = (logids[order].contiguous(), emb[order].contiguous())
return model._rep_cache
# ============================================================ # ============================================================
# 模型加载入口 # 模型加载入口
# ============================================================ # ============================================================
@@ -779,6 +829,29 @@ def load_model(ckpt_path, device='cuda:0'):
print(f"[INFO] attention={_resolve_attn(dev)}, " print(f"[INFO] attention={_resolve_attn(dev)}, "
f"moe={'dense' if CONFIG.get('vectorize_moe', True) else 'loop'}") f"moe={'dense' if CONFIG.get('vectorize_moe', True) else 'loop'}")
# === 预计算 RepEncoder 缓存(不计时阶段)===
if CONFIG.get("precompute_rep", False) and model._rep_cache is None:
try:
ds_dir = None
for cand in (Path(ckpt_path).parent / "dataset", Path("dataset"),
Path(__file__).parent / "dataset"):
if cand.exists():
ds_dir = cand
break
if ds_dir is not None:
history = ds_dir / "history"
test_csv = ds_dir / "test.csv"
files = (sorted(history.glob("*.csv")) if history.exists() else []) + [test_csv]
item_dict, user_seq = load_sample_files(files)
test_logids = list(load_logids_from_file(test_csv))
build_rep_cache(model, item_dict, user_seq, test_logids, {1: 2}, dev)
print(f"[INFO] rep cache built: {model._rep_cache[0].numel()} items")
else:
print("[INFO] dataset/ not found, skip rep precompute (fallback to in-batch)")
except Exception as e:
print(f"[WARNING] rep precompute failed ({e}), fallback to in-batch RepEncoder")
model._rep_cache = None
if CONFIG.get("compile", False): if CONFIG.get("compile", False):
try: try:
model = torch.compile(model, dynamic=True) model = torch.compile(model, dynamic=True)