feat: 预计算RepEncoder缓存,model(batch)按logid gather跳过embedding层

不计时的load_model里(或bench从batches)预计算所有item的context-free RepEncoder向量,
排序存(sorted_logids,emb);model(batch)用searchsorted gather、缺失回退现算。逐位等价。
预期 model(batch) 48s->~37s->~70。CONFIG.precompute_rep(eval默认True);bench --precompute-rep。

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
OwnerSunshine530
2026-06-15 17:06:56 +08:00
parent 2662da850c
commit 2004ad6bb8
2 changed files with 97 additions and 2 deletions
+23 -1
View File
@@ -209,8 +209,11 @@ def run_once(config_override=None, batch_size=50, max_batches=None,
if max_feasign_per_slot is None:
max_feasign_per_slot = {1: 2}
# 本地用已加载的过滤数据自建 rep 缓存,禁止 load_model 自动加载全量数据集
want_precompute = bool(config_override.pop("precompute_rep", False))
infer.CONFIG.update(config_override)
infer.CONFIG["sync_timing"] = True
infer.CONFIG["precompute_rep"] = False
cur = Path(__file__).parent
ref = cur / "dataset"
@@ -238,10 +241,25 @@ def run_once(config_override=None, batch_size=50, max_batches=None,
gc.collect()
model, dev = infer.load_model(ckpt_path=None)
cuda = (dev.type == "cuda")
# 本地从已建好的 batches 构造 rep 缓存(复用 batches、省内存;不计入计时)
if want_precompute:
lc, ec = [], []
with torch.inference_mode():
for b in batches:
bb = infer.move_batch_to_device(b, dev)
rep = model.rep_encoder(bb)
lc.append(bb["logid"].to(dev))
ec.append(rep)
logids = torch.cat(lc)
emb = torch.cat(ec)
order = torch.argsort(logids)
model._rep_cache = (logids[order].contiguous(), emb[order].contiguous())
print(f"[BENCH] rep cache built from batches: {logids.numel()} items")
logid2p = {}
t_sum = 0.0
cuda = (dev.type == "cuda")
with torch.inference_mode():
for b in batches:
b = infer.move_batch_to_device(b, dev)
@@ -300,6 +318,8 @@ def _parse_args():
ap.add_argument("--emb-fp16", action="store_true", help="Embedding表转FP16(查表带宽减半,测AUC)")
ap.add_argument("--dedup-emb", action="store_true", help="查表前对sign去重(减少大表随机访存)")
ap.add_argument("--sparse-pool", action="store_true", help="稀疏矩阵乘做池化(段内高重复时省)")
ap.add_argument("--precompute-rep", action="store_true",
help="预计算RepEncoder缓存,model(batch)跳过embedding层(从batches自建)")
ap.add_argument("--profile", type=int, default=None, metavar="N",
help="剖析前 N 个 batch,打印按 CUDA 耗时排序的算子表(定位瓶颈)")
ap.add_argument("--rebuild", action="store_true", help="强制重建过滤缓存")
@@ -337,6 +357,8 @@ if __name__ == "__main__":
cfg["dedup_embedding"] = True
if a.sparse_pool:
cfg["sparse_pool"] = True
if a.precompute_rep:
cfg["precompute_rep"] = True
if a.compile:
cfg["compile"] = True
if a.profile is not None: