From a7234e577a7158cb7c0a0094e8156cc82432eb43 Mon Sep 17 00:00:00 2001 From: OwnerSunshine530 Date: Sun, 14 Jun 2026 22:21:11 +0800 Subject: [PATCH] =?UTF-8?q?perf:=20CTRTestSeqDataset=20=E5=8F=AA=E6=9E=9A?= =?UTF-8?q?=E4=B8=BE=E5=90=AB=E6=B5=8B=E8=AF=95=E6=A0=B7=E6=9C=AC=E7=9A=84?= =?UTF-8?q?=E7=94=A8=E6=88=B7=EF=BC=88=E8=B7=B3=E8=BF=87=E4=BC=9A=E8=A2=AB?= =?UTF-8?q?=E4=B8=A2=E5=BC=83=E7=9A=84=E7=94=A8=E6=88=B7=EF=BC=89?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 提交版当前枚举全部 ~40770 用户,其中 ~87% 没有测试样本、前向输出被丢弃, 白算(86.5s 由此而来)。因果mask隔离用户,过滤不改变测试样本预测(AUC/PCOC不变), 预计延迟 86.5s→~15s,得分 58.86→~75。CONFIG.filter_test_users 可关。 注意:开启过滤后,total_samples 与 max_sign_id 只按保留用户的样本统计(此前分别为 len(item_dict) 和全部记录的最大 sign), 若下游依赖这两个值,需确认结果不受影响。 Co-Authored-By: Claude Opus 4.8 --- 代码/code/infer.py | 26 +++++++++++++++++++------- 1 file changed, 19 insertions(+), 7 deletions(-) diff --git a/代码/code/infer.py b/代码/code/infer.py index 7d3d131..ccd3d45 100644 --- a/代码/code/infer.py +++ b/代码/code/infer.py @@ -30,6 +30,7 @@ CONFIG = { "merge_threshold": 0.90, # 合并的余弦相似度阈值 "signid_mode": "clamp", # "clamp" 或 "modulo":处理超界 sign id 的方式 "sync_timing": False, # bench 里设 True,做 torch.cuda.synchronize 真实计时 + "filter_test_users": True, # 只处理含测试样本的用户(跳过会被丢弃的用户,省算力) } @@ -165,11 +166,22 @@ class CTRTestSeqDataset(Dataset): self.max_ctx_len = max_ctx_len self.pred_logids = set(test_logids_ordered) if test_logids_ordered else set() + # 只处理“含测试样本的用户”:其余用户的前向输出会被丢弃,跳过以省算力。 + # 不同用户被因果 mask 完全隔离,过滤不改变任何测试样本的预测(AUC/PCOC 不变)。 + keep_users = None + if CONFIG.get("filter_test_users", True) and self.pred_logids: + keep_users = {rec['userid'] for logid, rec in item_dict.items() + if logid in self.pred_logids} + self.user_items = defaultdict(list) + max_sign = 0 for logid, rec in item_dict.items(): userid = rec['userid'] + if keep_users is not None and userid not in keep_users: + continue + signs_list = rec['signs'].tolist() feasign = defaultdict(list) - for slot, sign in zip(rec['slots'].tolist(), rec['signs'].tolist()): + for slot, sign in zip(rec['slots'].tolist(), 
signs_list): feasign[slot].append(sign) if max_feasign_per_slot is not None: feasign = {slot: signs[:max_feasign_per_slot[slot]] @@ -178,16 +190,16 @@ class CTRTestSeqDataset(Dataset): feasign = dict(feasign) label = rec['clk'] self.user_items[userid].append((logid, feasign, label)) + if signs_list: + m = max(signs_list) + if m > max_sign: + max_sign = m self.user_ids = sorted(self.user_items.keys()) self.num_users = len(self.user_ids) - self.total_samples = len(item_dict) - - all_signs = set() - for rec in item_dict.values(): - all_signs.update(rec['signs'].tolist()) + self.total_samples = sum(len(v) for v in self.user_items.values()) self.max_slot_id = 28 - self.max_sign_id = max(all_signs) if all_signs else 0 + self.max_sign_id = max_sign def __len__(self): return self.num_users