perf: CTRTestSeqDataset 只枚举含测试样本的用户(跳过前向输出会被丢弃的用户)
提交版当前会枚举全部约 40770 个用户,其中约 87% 的用户没有测试样本,其前向输出会被直接丢弃,属于无效计算(86.5s 的耗时由此而来)。不同用户之间被因果 mask 完全隔离,因此过滤掉这些用户不会改变任何测试样本的预测(AUC/PCOC 不变)。预计延迟 86.5s → 约 15s,得分 58.86 → 约 75。可通过 CONFIG.filter_test_users 关闭该过滤。注意:启用过滤后,total_samples 与 max_sign_id 只统计被保留用户的样本。

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
+19
-7
@@ -30,6 +30,7 @@ CONFIG = {
|
|||||||
"merge_threshold": 0.90, # 合并的余弦相似度阈值
|
"merge_threshold": 0.90, # 合并的余弦相似度阈值
|
||||||
"signid_mode": "clamp", # "clamp" 或 "modulo":处理超界 sign id 的方式
|
"signid_mode": "clamp", # "clamp" 或 "modulo":处理超界 sign id 的方式
|
||||||
"sync_timing": False, # bench 里设 True,做 torch.cuda.synchronize 真实计时
|
"sync_timing": False, # bench 里设 True,做 torch.cuda.synchronize 真实计时
|
||||||
|
"filter_test_users": True, # 只处理含测试样本的用户(跳过会被丢弃的用户,省算力)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@@ -165,11 +166,22 @@ class CTRTestSeqDataset(Dataset):
|
|||||||
self.max_ctx_len = max_ctx_len
|
self.max_ctx_len = max_ctx_len
|
||||||
self.pred_logids = set(test_logids_ordered) if test_logids_ordered else set()
|
self.pred_logids = set(test_logids_ordered) if test_logids_ordered else set()
|
||||||
|
|
||||||
|
# 只处理“含测试样本的用户”:其余用户的前向输出会被丢弃,跳过以省算力。
|
||||||
|
# 不同用户被因果 mask 完全隔离,过滤不改变任何测试样本的预测(AUC/PCOC 不变)。
|
||||||
|
keep_users = None
|
||||||
|
if CONFIG.get("filter_test_users", True) and self.pred_logids:
|
||||||
|
keep_users = {rec['userid'] for logid, rec in item_dict.items()
|
||||||
|
if logid in self.pred_logids}
|
||||||
|
|
||||||
self.user_items = defaultdict(list)
|
self.user_items = defaultdict(list)
|
||||||
|
max_sign = 0
|
||||||
for logid, rec in item_dict.items():
|
for logid, rec in item_dict.items():
|
||||||
userid = rec['userid']
|
userid = rec['userid']
|
||||||
|
if keep_users is not None and userid not in keep_users:
|
||||||
|
continue
|
||||||
|
signs_list = rec['signs'].tolist()
|
||||||
feasign = defaultdict(list)
|
feasign = defaultdict(list)
|
||||||
for slot, sign in zip(rec['slots'].tolist(), rec['signs'].tolist()):
|
for slot, sign in zip(rec['slots'].tolist(), signs_list):
|
||||||
feasign[slot].append(sign)
|
feasign[slot].append(sign)
|
||||||
if max_feasign_per_slot is not None:
|
if max_feasign_per_slot is not None:
|
||||||
feasign = {slot: signs[:max_feasign_per_slot[slot]]
|
feasign = {slot: signs[:max_feasign_per_slot[slot]]
|
||||||
@@ -178,16 +190,16 @@ class CTRTestSeqDataset(Dataset):
|
|||||||
feasign = dict(feasign)
|
feasign = dict(feasign)
|
||||||
label = rec['clk']
|
label = rec['clk']
|
||||||
self.user_items[userid].append((logid, feasign, label))
|
self.user_items[userid].append((logid, feasign, label))
|
||||||
|
if signs_list:
|
||||||
|
m = max(signs_list)
|
||||||
|
if m > max_sign:
|
||||||
|
max_sign = m
|
||||||
|
|
||||||
self.user_ids = sorted(self.user_items.keys())
|
self.user_ids = sorted(self.user_items.keys())
|
||||||
self.num_users = len(self.user_ids)
|
self.num_users = len(self.user_ids)
|
||||||
self.total_samples = len(item_dict)
|
self.total_samples = sum(len(v) for v in self.user_items.values())
|
||||||
|
|
||||||
all_signs = set()
|
|
||||||
for rec in item_dict.values():
|
|
||||||
all_signs.update(rec['signs'].tolist())
|
|
||||||
self.max_slot_id = 28
|
self.max_slot_id = 28
|
||||||
self.max_sign_id = max(all_signs) if all_signs else 0
|
self.max_sign_id = max_sign
|
||||||
|
|
||||||
def __len__(self):
|
def __len__(self):
|
||||||
return self.num_users
|
return self.num_users
|
||||||
|
|||||||
Reference in New Issue
Block a user