import sys import os # 获取当前环境脚本所在目录或指定绝对路径 if os.path.exists("../libraries"): lib_path = os.path.abspath("../libraries") sys.path.append(lib_path) import math import argparse from pathlib import Path from collections import defaultdict import numpy as np import torch import torch.nn as nn import torch.nn.functional as F from torch.utils.data import Dataset, DataLoader from tqdm import tqdm # FlexAttention(块对角因果注意力,需 PyTorch 2.5+ 且 GPU 计算能力 >= 8.0 / Ampere) try: from torch.nn.attention.flex_attention import flex_attention, create_block_mask _HAS_FLEX = True except Exception: flex_attention = None create_block_mask = None _HAS_FLEX = False # ============================================================ # 实验配置开关板 # 提交时保持下面的默认值 = 当前最优行为;评测系统不碰它,按默认值跑。 # bench.py 会在 import 之后用 infer.CONFIG.update(...) 覆盖这些值。 # ============================================================ CONFIG = { "fp16": True, # True=半精度推理;False=FP32 参考跑(确立 AUC 天花板) "keep_fp32_modules": (), # fp16 下仍保留 FP32 的子模块名前缀,如 ("linear",) "expert_merge": True, # 是否做 expert 权重相似度合并 "merge_threshold": 0.90, # 合并的余弦相似度阈值 "signid_mode": "clamp", # "clamp" 或 "modulo":处理超界 sign id 的方式 "sync_timing": False, # bench 里设 True,做 torch.cuda.synchronize 真实计时 "filter_test_users": True, # 只处理含测试样本的用户(跳过会被丢弃的用户,省算力) # 实测:varlen 本地快(10.28s)但评测端慢(148s,嵌套张量构造开销随batch数放大)→已退回。 # sdpa 是评测端验证最快(89.96s/58.86)。flex/compile/小batch/varlen 在评测端都更差。 # attn: "chunked"(按用户分块SDPA,降O(S²),本地14.25->7.92s) / "sdpa"(稠密mask) / 其它对照 "attn": "chunked", "chunk_users": 4, # chunked 每块用户数(本地 4 最快 6.18s;再小收益递减) # 稠密MoE去掉了 model(batch) 内唯一的同步点(MoE循环的.nonzero())。若评测计时不 # synchronize,去掉同步点可能让被计时的 model(batch) 大幅缩短。本地force-sync看不出, # 须靠提交验证。AUC中性、MoE仅占2%算力故风险极低。 "vectorize_moe": True, # True=稠密向量化MoE(无同步点);False=原逐expert循环(.nonzero同步) "fuse_embedding": True, # True=28个slot的查表+池化融合为1次(减per-batch kernel启动) "syncfree_mask": True, # True=用searchsorted构造因果mask(无同步);False=repeat_interleave(同步) "emb_fp16": True, # True=Embedding表转FP16(查表带宽减半,实测AUC 0.75932≈无损) "dedup_embedding": True, # True=查表前对sign去重(只查唯一值再展开),本地7.80->6.49s,AUC逐位等价 "compile": False, # 是否 torch.compile(实测慢5×,勿开) } def _resolve_attn(device): """解析实际使用的注意力实现。flex 需 SM80+ 且可用,否则回退 sdpa。""" attn = CONFIG.get("attn", "sdpa") if attn == "flex": if not _HAS_FLEX: return "sdpa" if device is not None and device.type == "cuda": try: if torch.cuda.get_device_capability(device)[0] < 8: return "sdpa" except Exception: return "sdpa" return attn def _force_fp32_io(module): """让某个模块在 FP16 模型里以 FP32 计算:输入转 FP32、输出转回 FP16。 用于 keep_fp32_modules 指定的精度敏感层(如最终输出头、LayerNorm)。""" module.float() def _pre(m, args): return tuple( a.float() if torch.is_tensor(a) and a.is_floating_point() else a for a in args ) def _post(m, args, output): if torch.is_tensor(output) and output.is_floating_point(): return output.half() return output module.register_forward_pre_hook(_pre) module.register_forward_hook(_post) # ============================================================ # 数据加载(来自 train/dataset.py) # ============================================================ def _detect_has_clk(file_path): """检测 CSV 文件是否包含 clk 列(5列 vs 4列格式)。 5列格式: logid,userid,adid,clk,timestamp,sign:slot... 4列格式: logid,userid,adid,timestamp,sign:slot... 通过第5个字段是否包含 ':' 来判断:有 ':' 说明已经是 sign:slot,即无 clk 列。 """ with open(file_path, 'r') as f: for line in f: line = line.strip() if not line: continue parts = line.split(',') if len(parts) >= 5: return ':' not in parts[4] return False return False def load_sample_files(sample_files_list): """加载 CSV sample 文件,返回 item_dict 和 user_seq。 自动检测每个文件是 5列(含clk)还是 4列(无clk)格式。 """ sample_files = sorted([Path(f) for f in sample_files_list]) print(f'[INFO] loading {len(sample_files)} files: {[str(f) for f in sample_files]}') item_dict = {} user_logs = defaultdict(list) for sample_file in tqdm(sample_files, desc='Loading sample files'): has_clk = _detect_has_clk(sample_file) min_parts = 5 if has_clk else 4 print(f' {sample_file.name}: has_clk={has_clk}') with open(sample_file, 'r') as f: for line in f: line = line.strip() if not line: continue parts = line.split(',') if len(parts) < min_parts: continue logid = int(parts[0]) userid = int(parts[1]) adid = int(parts[2]) if has_clk: clk = int(parts[3]) timestamp = int(parts[4]) feat_start = 5 else: clk = 0 timestamp = int(parts[3]) feat_start = 4 signs = [] slots = [] for pair in parts[feat_start:]: if ':' in pair: s, sl = pair.split(':', 1) signs.append(int(s)) slots.append(int(sl)) item_dict[logid] = { 'logid': logid, 'userid': userid, 'adid': adid, 'clk': clk, 'timestamp': timestamp, 'signs': np.array(signs, dtype=np.int64), 'slots': np.array(slots, dtype=np.int64), } user_logs[userid].append((timestamp, logid)) user_seq = {} for userid, logs in user_logs.items(): logs.sort(key=lambda x: x[0]) user_seq[userid] = [logid for _, logid in logs] print(f'[INFO] loaded {len(item_dict)} records, {len(user_seq)} users') return item_dict, user_seq def load_logids_from_file(file_path): """快速读取一个 sample 文件中的所有 logid""" logids = set() with open(file_path, 'r') as f: for line in f: line = line.strip() if not line: continue comma = line.index(',') logids.add(int(line[:comma])) return logids class CTRTestSeqDataset(Dataset): """按用户组织的 CTR 测试数据集(对齐评测接口)""" def __init__(self, test_logids_ordered, item_dict, user_seq=None, max_feasign_per_slot=None, max_ctx_len=None): super().__init__() self.item_dict = item_dict self.user_seq = user_seq if user_seq else {} self.max_feasign_per_slot = max_feasign_per_slot self.max_ctx_len = max_ctx_len self.pred_logids = set(test_logids_ordered) if test_logids_ordered else set() # 只处理“含测试样本的用户”:其余用户的前向输出会被丢弃,跳过以省算力。 # 不同用户被因果 mask 完全隔离,过滤不改变任何测试样本的预测(AUC/PCOC 不变)。 keep_users = None if CONFIG.get("filter_test_users", True) and self.pred_logids: keep_users = {rec['userid'] for logid, rec in item_dict.items() if logid in self.pred_logids} self.user_items = defaultdict(list) max_sign = 0 for logid, rec in item_dict.items(): userid = rec['userid'] if keep_users is not None and userid not in keep_users: continue signs_list = rec['signs'].tolist() feasign = defaultdict(list) for slot, sign in zip(rec['slots'].tolist(), signs_list): feasign[slot].append(sign) if max_feasign_per_slot is not None: feasign = {slot: signs[:max_feasign_per_slot[slot]] if max_feasign_per_slot.get(slot, -1) != -1 else signs for slot, signs in feasign.items()} feasign = dict(feasign) label = rec['clk'] self.user_items[userid].append((logid, feasign, label)) if signs_list: m = max(signs_list) if m > max_sign: max_sign = m self.user_ids = sorted(self.user_items.keys()) self.num_users = len(self.user_ids) self.total_samples = sum(len(v) for v in self.user_items.values()) self.max_slot_id = 28 self.max_sign_id = max_sign def __len__(self): return self.num_users def __getitem__(self, index): userid = self.user_ids[index] items = self.user_items[userid] if self.user_seq and userid in self.user_seq: seq_order = {logid: i for i, logid in enumerate(self.user_seq[userid])} items.sort(key=lambda x: seq_order.get(x[0], x[0])) else: items.sort(key=lambda x: x[0]) feasigns = [] labels = [] logids = [] for logid, feasign, label in items: logids.append(logid) feasigns.append(feasign) labels.append(label) return { 'userid': userid, 'logids': logids, 'feasigns': feasigns, 'labels': labels, 'pred_mask': [1 if logid in self.pred_logids else 0 for logid in logids], } def make_collate_fn(max_slot_id): def collate_user_batch(batch): all_userids = [] all_logids = [] all_labels = [] all_pred_masks = [] all_feasigns = [] user_offsets = [0] for item in batch: for i, logid in enumerate(item['logids']): all_userids.append(item['userid']) all_logids.append(logid) all_labels.append(item['labels'][i]) all_pred_masks.append(item['pred_mask'][i]) all_feasigns.append(item['feasigns'][i]) user_offsets.append(len(all_labels)) slot_data = {} for slot in range(1, max_slot_id + 1): values = [] offsets = [0] for feasign in all_feasigns: if slot in feasign: values.extend(feasign[slot]) offsets.append(len(values)) slot_data[slot] = ( torch.tensor(values, dtype=torch.long), torch.tensor(offsets, dtype=torch.long), ) result = { 'userid': torch.tensor(all_userids, dtype=torch.long), 'logid': torch.tensor(all_logids, dtype=torch.long), 'label': torch.tensor(all_labels, dtype=torch.float32), 'pred_mask': torch.tensor(all_pred_masks, dtype=torch.bool), 'user_offsets': torch.tensor(user_offsets, dtype=torch.long), } result.update(slot_data) return result return collate_user_batch # ============================================================ # 模型定义(来自 main.py) # ============================================================ def move_batch_to_device(batch, device): if isinstance(batch, dict): return {k: move_batch_to_device(v, device) for k, v in batch.items()} elif isinstance(batch, (list, tuple)): return [move_batch_to_device(x, device) for x in batch] elif torch.is_tensor(batch): x = batch.to(device) # 浮点 tensor → FP16,整数 tensor 保持不变 if x.dtype == torch.float32: x = x.half() return x else: return batch def _rep_forward_perslot(enc, batch): """原始逐 slot 实现(保留作数值等价对照/回退)。""" pooled_embs = [] max_idx = enc.emb.num_embeddings - 1 target_dtype = enc.input_norm.weight.dtype for i in range(enc.slot_num): values, offsets = batch[i + 1] offsets = offsets.to(values.device) values = enc._signid(values, max_idx) sign_emb = enc.emb(values).to(target_dtype) res = torch.segment_reduce(sign_emb, reduce='sum', offsets=offsets, initial=0) pooled_embs.append(res) fused_embs = torch.cat(pooled_embs, dim=1) return enc.linear(enc.input_norm(fused_embs)) class RepEncoder(nn.Module): def __init__(self, vocab_size, emb_dim, padding_idx=0, slot_num=0, d_model=0): super().__init__() self.emb = nn.Embedding(num_embeddings=vocab_size, embedding_dim=emb_dim, padding_idx=padding_idx) self.emb_dim = emb_dim self.slot_num = slot_num self.input_norm = nn.LayerNorm(slot_num * emb_dim) self.linear = nn.Linear(in_features=slot_num * emb_dim, out_features=d_model) def _signid(self, values, max_idx): if CONFIG["signid_mode"] == "modulo": return values % self.emb.num_embeddings # 取模哈希(与训练一致时用) return values.clamp(0, max_idx) # 超界 sign id 截断 def forward(self, batch): if not CONFIG.get("fuse_embedding", True): return _rep_forward_perslot(self, batch) max_idx = self.emb.num_embeddings - 1 target_dtype = self.input_norm.weight.dtype N = batch[1][1].numel() - 1 # 样本数(slot1 的 offsets 段数) # 把 28 个 slot 的 values 拼成一条,offsets 平移拼成覆盖 28*N 段的单一 offsets parts, ends, base = [], [], 0 for i in range(self.slot_num): values, offsets = batch[i + 1] offsets = offsets.to(values.device) parts.append(values) ends.append(offsets[1:] + base) # 该 slot 各样本的段尾(平移 base) base += values.numel() # numel 读 shape,不触发同步 cat_values = self._signid(torch.cat(parts), max_idx) seg = torch.cat([torch.zeros(1, dtype=torch.long, device=cat_values.device), torch.cat(ends)]) # [28*N + 1] if CONFIG.get("dedup_embedding", False): # 去重:只对唯一 sign 查大表,再按逆索引展开(数学逐位等价,省随机访存) uniq, inv = torch.unique(cat_values, return_inverse=True) emb = self.emb(uniq).to(target_dtype)[inv] else: emb = self.emb(cat_values).to(target_dtype) pooled = torch.segment_reduce(emb, reduce='sum', offsets=seg, initial=0) # [28*N, emb] pooled = pooled.view(self.slot_num, N, self.emb_dim).permute(1, 0, 2).reshape( N, self.slot_num * self.emb_dim) return self.linear(self.input_norm(pooled)) def _varlen_attention(q, k, v, user_offsets): """嵌套张量变长 flash 注意力:每个用户当独立序列、is_causal 块对角因果。 一个内核处理一 batch 内所有用户,无稠密 mask、无 padding 浪费、开销低。 q,k,v: [1, H, S, Dh];user_offsets: [B+1](S 上的用户边界)。返回 [1, H, S, Dh]。 """ _, H, S, Dh = q.shape offs = user_offsets.to(torch.int64) # [1,H,S,Dh] -> [S,H,Dh] qv = q.squeeze(0).transpose(0, 1).contiguous() kv = k.squeeze(0).transpose(0, 1).contiguous() vv = v.squeeze(0).transpose(0, 1).contiguous() # 按用户边界做 jagged 嵌套张量:[B, ragged, H, Dh] -> [B, H, ragged, Dh] qn = torch.nested.nested_tensor_from_jagged(qv, offsets=offs).transpose(1, 2) kn = torch.nested.nested_tensor_from_jagged(kv, offsets=offs).transpose(1, 2) vn = torch.nested.nested_tensor_from_jagged(vv, offsets=offs).transpose(1, 2) out = F.scaled_dot_product_attention(qn, kn, vn, is_causal=True) # [B,H,ragged,Dh] out = out.transpose(1, 2).values() # [S, H, Dh] return out.transpose(0, 1).unsqueeze(0).contiguous() # [1, H, S, Dh] def scaled_dot_product(q, k, v, extension): """注意力分发: - chunks → 按用户分块的 SDPA(每块块内因果,降 O(S²),无嵌套开销)。 - varlen_offsets → 嵌套张量变长 flash(评测端慢,仅对照)。 - block_mask → FlexAttention 块对角因果。 - mask(默认) → 标准 SDPA 稠密 mask(数学等价、已验证最快)。 """ if extension is not None and extension.get("chunks") is not None: outs = [] for s0, s1, m in extension["chunks"]: outs.append(F.scaled_dot_product_attention( q[:, :, s0:s1], k[:, :, s0:s1], v[:, :, s0:s1], attn_mask=m, dropout_p=0.0, is_causal=False)) return torch.cat(outs, dim=2) if extension is not None and extension.get("varlen_offsets") is not None: return _varlen_attention(q, k, v, extension["varlen_offsets"]) if extension is not None and extension.get("block_mask") is not None: return flex_attention(q, k, v, block_mask=extension["block_mask"]) if extension is not None and "mask" in extension: attn_mask = extension["mask"].to(device=q.device) else: attn_mask = None return F.scaled_dot_product_attention( q, k, v, attn_mask=attn_mask, dropout_p=0.0, is_causal=False, ) class Expert(nn.Module): def __init__(self, d_model, dim_ff): super().__init__() self.fc1 = nn.Linear(d_model, dim_ff) self.fc2 = nn.Linear(dim_ff, d_model) def forward(self, x): return self.fc2(F.relu(self.fc1(x))) class TopKGate(nn.Module): def __init__(self, d_model, num_experts, k=2, noisy_gating=True): super().__init__() self.w_g = nn.Linear(d_model, num_experts) self.num_experts = num_experts self.k = k self.noisy_gating = noisy_gating def forward(self, x): # x: [B,S,D] logits = self.w_g(x) # [B,S,E] if self.noisy_gating and self.training: logits = logits + torch.randn_like(logits) * 0.1 probs = torch.softmax(logits, dim=-1) # [B,S,E] topk_score, topk_idx = torch.topk(probs, self.k, dim=-1) # [B,S,k] return topk_idx, topk_score, probs def _smoe_forward_loop(moe, x): """原始逐 expert 循环实现(保留作数值等价对照/回退)。""" B, S, D = x.shape topk_idx, topk_score, probs = moe.gate(x) out = torch.zeros_like(x) x_flat = x.reshape(-1, D) idx_flat = topk_idx.reshape(-1, moe.k) score_flat = topk_score.reshape(-1, moe.k) out_flat = out.reshape(-1, D) for i in range(moe.num_experts): mask = (idx_flat == i) token_idx, k_idx = mask.nonzero(as_tuple=True) if token_idx.numel() == 0: continue selected_x = x_flat[token_idx] expert_out = moe.experts[i](selected_x) weight = score_flat[token_idx, k_idx].unsqueeze(-1) out_flat[token_idx] += expert_out * weight importance = probs.sum(dim=(0, 1)) moe_loss = (importance.std() / (importance.mean() + 1e-6)) return out, moe_loss class SMoE(nn.Module): def __init__(self, d_model, dim_ff, num_experts, k=2): super().__init__() self.num_experts = num_experts self.k = k self.experts = nn.ModuleList([ Expert(d_model, dim_ff) for _ in range(num_experts) ]) self.gate = TopKGate(d_model, num_experts, k=k) self._stacked = False def _stack_weights(self): """把各 expert 的 fc1/fc2 权重堆叠成单一张量,供批量 matmul。 延迟到首次 forward 调用:此时已完成 expert 合并与 half()/to(device)。""" self.register_buffer("W1", torch.stack([e.fc1.weight for e in self.experts]).contiguous()) # [E,F,D] self.register_buffer("b1", torch.stack([e.fc1.bias for e in self.experts]).contiguous()) # [E,F] self.register_buffer("W2", torch.stack([e.fc2.weight for e in self.experts]).contiguous()) # [E,D,F] self.register_buffer("b2", torch.stack([e.fc2.bias for e in self.experts]).contiguous()) # [E,D] self._stacked = True def forward(self, x): # x: [B,S,D] if not CONFIG.get("vectorize_moe", True): return _smoe_forward_loop(self, x) if not self._stacked: self._stack_weights() B, S, D = x.shape topk_idx, topk_score, probs = self.gate(x) xf = x.reshape(-1, D) # [N, D] # 稠密计算所有 expert(GPU 友好、无 Python 循环/同步/gather-scatter): h = torch.einsum("nd,efd->enf", xf, self.W1) + self.b1.unsqueeze(1) # [E,N,F] h = F.relu(h) o = torch.einsum("enf,edf->end", h, self.W2) + self.b2.unsqueeze(1) # [E,N,D] # 按每个 token 的 top-k 选取并加权(与逐 expert 循环数学等价) o = o.permute(1, 0, 2) # [N, E, D] idx = topk_idx.reshape(-1, self.k) # [N, k] sc = topk_score.reshape(-1, self.k) # [N, k] sel = torch.gather(o, 1, idx.unsqueeze(-1).expand(-1, -1, D)) # [N, k, D] out = (sel * sc.unsqueeze(-1)).sum(dim=1).reshape(B, S, D) importance = probs.sum(dim=(0, 1)) # [E] moe_loss = (importance.std() / (importance.mean() + 1e-6)) return out, moe_loss class TransformerEncoder(nn.Module): def __init__(self, d_model, n_heads, num_layers, dim_ff, act="relu", attention_fn=scaled_dot_product): super().__init__() self.d_model = d_model self.n_heads = n_heads self.head_dim = d_model // n_heads self.num_layers = num_layers assert d_model % n_heads == 0 self.qkv_proj = nn.ModuleList([nn.Linear(d_model, 3 * d_model) for _ in range(num_layers)]) self.out_proj = nn.ModuleList([nn.Linear(d_model, d_model) for _ in range(num_layers)]) self.ffn1 = nn.ModuleList([nn.Linear(d_model, dim_ff) for _ in range(num_layers)]) self.ffn2 = nn.ModuleList([nn.Linear(dim_ff, d_model) for _ in range(num_layers)]) self.norm1 = nn.ModuleList([nn.LayerNorm(d_model) for _ in range(num_layers)]) self.norm2 = nn.ModuleList([nn.LayerNorm(d_model) for _ in range(num_layers)]) self.act = getattr(F, act) self.attention_fn = attention_fn self.moe = nn.ModuleList([ SMoE(d_model, dim_ff, num_experts=8, k=2) for _ in range(num_layers) ]) def forward(self, x, extension): x = x.unsqueeze(0) B, S, D = x.shape moe_loss_total = 0.0 for i in range(self.num_layers): residual = x x = self.norm1[i](x) qkv = self.qkv_proj[i](x) qkv = qkv.view(B, S, self.n_heads, 3 * self.head_dim) qkv = qkv.permute(0, 2, 1, 3) q, k, v = torch.split(qkv, self.head_dim, dim=-1) attn_out = self.attention_fn(q, k, v, extension) attn_out = attn_out.permute(0, 2, 1, 3).reshape(B, S, D) x = residual + self.out_proj[i](attn_out) residual = x x = self.norm2[i](x) moe_out, moe_loss = self.moe[i](x) x = residual + moe_out moe_loss_total = moe_loss_total + moe_loss return x, moe_loss_total class CTRModel(nn.Module): def __init__(self, rep_encoder, seq_encoder, d_model): super().__init__() self.rep_encoder = rep_encoder self.seq_encoder = seq_encoder self.d_model = d_model self.linear = nn.Linear(d_model, 1) def get_sequence_causal_mask(self, seq_info): lengths = seq_info[1:] - seq_info[:-1] lengths = lengths.view(-1) indices = torch.cumsum(torch.ones_like(lengths), dim=0) - 1 result = torch.repeat_interleave(indices, lengths) # repeats 是张量 → 同步 a = result.view(1, -1) - result.view(-1, 1) out_mask = torch.tril((a == 0).to(torch.int32)).bool() return out_mask def build_chunks(self, user_offsets, device): """把拼接序列按用户边界切成每块 ~chunk_users 个用户,返回 [(s0,s1,mask), ...]。 每块块内因果,注意力 O(块内S²) 远小于 O(总S²)。仅 1 次同步(读切分边界)。""" chunk_users = int(CONFIG.get("chunk_users", 16)) B = user_offsets.numel() - 1 # 用户数(读 shape,无同步) idx = list(range(0, B + 1, chunk_users)) if idx[-1] != B: idx.append(B) bounds = user_offsets[idx].tolist() # 1 次同步:取各块的 token 边界 chunks = [] for c in range(len(bounds) - 1): s0, s1 = bounds[c], bounds[c + 1] local_off = user_offsets[idx[c]:idx[c + 1] + 1] - s0 # 该块内的用户边界(GPU) m = self.causal_mask_syncfree(local_off, s1 - s0, device).unsqueeze(0).unsqueeze(0) chunks.append((s0, s1, m)) return chunks def causal_mask_syncfree(self, user_offsets, S, device): """与 get_sequence_causal_mask 等价,但用 searchsorted 求每个位置的用户号, 避免 repeat_interleave(张量repeats) 的隐式同步。""" pos = torch.arange(S, device=device) doc_id = torch.searchsorted(user_offsets[1:].contiguous(), pos, right=True) # [S],无同步 same = doc_id.view(-1, 1) == doc_id.view(1, -1) causal = pos.view(-1, 1) >= pos.view(1, -1) return same & causal def build_block_mask(self, user_offsets, S): """FlexAttention 块对角因果 mask:q 只能 attend 同一用户且 kv<=q 的位置。""" lengths = (user_offsets[1:] - user_offsets[:-1]).view(-1) device = user_offsets.device doc_id = torch.repeat_interleave( torch.arange(lengths.numel(), device=device), lengths) def mask_mod(b, h, q_idx, kv_idx): return (q_idx >= kv_idx) & (doc_id[q_idx] == doc_id[kv_idx]) return create_block_mask(mask_mod, B=None, H=None, Q_LEN=S, KV_LEN=S, device=device) def forward(self, batch): seq_input = self.rep_encoder(batch) user_offsets = batch["user_offsets"] attn = _resolve_attn(seq_input.device) if attn == "chunked": extension = {"chunks": self.build_chunks(user_offsets, seq_input.device)} elif attn == "varlen": extension = {"varlen_offsets": user_offsets} elif attn == "flex": S = seq_input.shape[0] # rep_encoder 输出 [S, D],S=总 token 数 extension = {"block_mask": self.build_block_mask(user_offsets, S)} else: if CONFIG.get("syncfree_mask", True): seq_mask = self.causal_mask_syncfree( user_offsets, seq_input.shape[0], seq_input.device) else: seq_mask = self.get_sequence_causal_mask(user_offsets) extension = {"mask": seq_mask.unsqueeze(0).unsqueeze(0)} encoder_output, moe_loss = self.seq_encoder(x=seq_input, extension=extension) encoder_output = encoder_output.squeeze(0) pred = self.linear(encoder_output) pred_logits = torch.clamp(pred, min=-15.0, max=15.0) return pred_logits, moe_loss # ============================================================ # 模型加载入口 # ============================================================ def load_model(ckpt_path, device='cuda:0'): """加载模型并返回,供 evaluation.py 调用。 Args: ckpt_path: checkpoint 文件路径(评测系统传入 Path 对象) device: 推理设备(默认 'cuda:0') Returns: (model, device) 元组 """ emb_dim = 512 slot_num = 28 vocab_size = 5000000 d_model = 512 n_heads = 8 num_layers = 8 dim_ff = 1024 rep_encoder = RepEncoder( vocab_size=vocab_size, emb_dim=emb_dim, padding_idx=0, slot_num=slot_num, d_model=d_model, ) seq_encoder = TransformerEncoder( d_model=d_model, n_heads=n_heads, num_layers=num_layers, dim_ff=dim_ff, act="relu", ) model = CTRModel(rep_encoder, seq_encoder, d_model=d_model) dev = torch.device(device if torch.cuda.is_available() else "cpu") # 加载 checkpoint # 若需要加载自定义修改的权重,请修改 479-488行逻辑,强制使用你文件夹中的权重 # 测评系统默认使用原始官方权重 if ckpt_path is None: ckpt_path = Path(__file__).parent / 'ckpt.pt' else: ckpt_path = Path(ckpt_path) if ckpt_path.exists(): ckpt = torch.load(ckpt_path, map_location='cpu', weights_only=False) model.load_state_dict(ckpt['model_state_dict']) print(f"[INFO] Loaded checkpoint from {ckpt_path} (epoch={ckpt.get('epoch', '?')})") if CONFIG["fp16"]: model = model.half() # 默认 Embedding 保留 FP32;emb_fp16=True 时保持 FP16(查表带宽减半) if not CONFIG.get("emb_fp16", False): model.rep_encoder.emb = model.rep_encoder.emb.to(torch.float32) # 额外保留 FP32 的精度敏感模块(输入/输出自动转换) for name, module in model.named_modules(): if name and any(name.startswith(p) for p in CONFIG["keep_fp32_modules"]): _force_fp32_io(module) emb_note = "emb=FP16" if CONFIG.get("emb_fp16", False) else "emb=FP32" print(f"[INFO] FP16 on; {emb_note}; extra FP32-kept: " f"{tuple(CONFIG['keep_fp32_modules'])}") else: model = model.float() print("[INFO] FP32 reference (no half)") # === 按 Expert 权重相似度合并冗余 expert === if CONFIG["expert_merge"]: _merge_experts(model, sim_threshold=CONFIG["merge_threshold"]) else: print("[INFO] expert_merge off") else: print(f"[WARNING] Checkpoint {ckpt_path} not found, using random weights") model.to(dev) model.eval() print(f"[INFO] attention={_resolve_attn(dev)}, " f"moe={'dense' if CONFIG.get('vectorize_moe', True) else 'loop'}") if CONFIG.get("compile", False): try: model = torch.compile(model, dynamic=True) print("[INFO] torch.compile enabled (dynamic=True)") except Exception as e: print(f"[WARNING] torch.compile failed ({e}), running eager") print(f"[INFO] Model ready. Device: {dev}") return model, dev def _merge_experts(model, sim_threshold=0.97): """按权重余弦相似度合并冗余 MoE expert。 合规:仅删除冗余部分,不改层数/维度/head/FFN channel。""" total_merged = 0 for layer_idx, moe in enumerate(model.seq_encoder.moe): num_exp = moe.num_experts if num_exp <= 1: continue # 1) 计算 8×8 成对相似度矩阵(fc1+fc2 平均余弦相似度) sim_matrix = torch.zeros(num_exp, num_exp) for i in range(num_exp): for j in range(i + 1, num_exp): w_i = torch.cat([ moe.experts[i].fc1.weight.data.flatten().float(), moe.experts[i].fc2.weight.data.flatten().float(), ]) w_j = torch.cat([ moe.experts[j].fc1.weight.data.flatten().float(), moe.experts[j].fc2.weight.data.flatten().float(), ]) sim = F.cosine_similarity(w_i.unsqueeze(0), w_j.unsqueeze(0)).item() sim_matrix[i, j] = sim sim_matrix[j, i] = sim # 2) 贪心聚类:从最高相似度 pair 开始,> threshold 则合并 parent = list(range(num_exp)) def find(x): while parent[x] != x: parent[x] = parent[parent[x]] x = parent[x] return x def union(a, b): ra, rb = find(a), find(b) if ra != rb: parent[rb] = ra # 按相似度降序遍历所有 pair pairs = [(sim_matrix[i, j].item(), i, j) for i in range(num_exp) for j in range(i + 1, num_exp)] pairs.sort(reverse=True) for sim_val, i, j in pairs: if sim_val > sim_threshold: union(i, j) # 3) 分组 clusters = {} for i in range(num_exp): root = find(i) clusters.setdefault(root, []).append(i) # 如果所有 cluster 都只有 1 个 expert,跳过 if all(len(c) == 1 for c in clusters.values()): continue # 4) 合并每个 cluster new_experts = [] new_gate_w = [] new_gate_b = [] for root, indices in clusters.items(): if len(indices) == 1: idx = indices[0] new_experts.append(moe.experts[idx]) new_gate_w.append(moe.gate.w_g.weight.data[idx].clone()) new_gate_b.append(moe.gate.w_g.bias.data[idx].clone()) else: # 平均权重 avg_fc1_w = sum(moe.experts[k].fc1.weight.data for k in indices) / len(indices) avg_fc1_b = sum(moe.experts[k].fc1.bias.data for k in indices) / len(indices) avg_fc2_w = sum(moe.experts[k].fc2.weight.data for k in indices) / len(indices) avg_fc2_b = sum(moe.experts[k].fc2.bias.data for k in indices) / len(indices) merged = Expert(moe.experts[0].fc1.in_features, moe.experts[0].fc1.out_features) merged.fc1.weight.data = avg_fc1_w merged.fc1.bias.data = avg_fc1_b merged.fc2.weight.data = avg_fc2_w merged.fc2.bias.data = avg_fc2_b new_experts.append(merged) # 平均 gate 权重 avg_gate_w = sum(moe.gate.w_g.weight.data[k].clone() for k in indices) / len(indices) avg_gate_b = sum(moe.gate.w_g.bias.data[k].clone() for k in indices) / len(indices) new_gate_w.append(avg_gate_w) new_gate_b.append(avg_gate_b) total_merged += len(indices) - 1 # 5) 更新 MoE 层 old_num = moe.num_experts new_num = len(new_experts) moe.experts = nn.ModuleList(new_experts) moe.num_experts = new_num new_k = min(moe.k, new_num) moe.k = new_k moe.gate.k = new_k # 替换 gate weight 和 bias moe.gate.w_g = nn.Linear(moe.gate.w_g.in_features, new_num).to( moe.gate.w_g.weight.device) moe.gate.w_g.weight.data = torch.stack(new_gate_w) moe.gate.w_g.bias.data = torch.stack(new_gate_b) moe.gate.num_experts = new_num print(f" Layer {layer_idx}: {old_num} → {new_num} experts " f"(merged {old_num - new_num}, k={moe.k})") if total_merged > 0: print(f"[INFO] Total merged experts: {total_merged}") else: print("[INFO] No experts merged (all below similarity threshold)") # ============================================================ # 打分工具(与 evaluation.py 保持一致) # ============================================================ def _read_predict(file_path): predictions = [] with open(file_path, 'r') as f: for line in f: line = line.strip() if line: predictions.append(float(line)) import numpy as np return np.array(predictions) def _read_label(file_path): labels = [] with open(file_path, 'r') as f: for line in f: line = line.strip() if line: parts = line.split(',') if len(parts) >= 4: labels.append(float(parts[3])) else: labels.append(float(line)) import numpy as np return np.array(labels) def _cal_score(predict_file, label_file, default_latency=0.0): import numpy as np from sklearn.metrics import roc_auc_score predictions = _read_predict(predict_file) labels = _read_label(label_file) unique_labels = np.unique(labels) if len(unique_labels) < 2: print('[WARNING] only one class present in labels, AUC is not defined, returning 0.5') auc = 0.5 else: auc = roc_auc_score(labels, predictions) mean_pred = np.mean(predictions) mean_label = np.mean(labels) if mean_label == 0: pcoc = 1.0 if mean_pred == 0 else float('inf') else: pcoc = float(mean_pred / mean_label) latency = default_latency base_latency = 300 score_latency = max(0.0, (base_latency - latency) / base_latency) if latency < base_latency else 0.0 if pcoc < 0.85 or pcoc > 1.15: score_model = 0.0 else: score_model = ((auc - 0.65) * 1000 + (0.15 - abs(pcoc - 1)) / 0.15 * 10) / 360 score_all = score_latency * 70 + score_model * 30 return { 'auc': auc, 'pcoc': pcoc, 'latency': latency, 'score_latency': score_latency, 'score_model': score_model, 'score_all': score_all, } # ============================================================ # main:直接运行 infer.py 进行测试 # ============================================================ def main(): import io import time import argparse parser = argparse.ArgumentParser() parser.add_argument('--ckpt', type=str, default=None, help='checkpoint 文件路径,默认使用同目录下的 ckpt.pt') args = parser.parse_args() cur_path = Path(__file__).parent.absolute() ref_dir = cur_path / 'dataset' history_dir = ref_dir / 'history' input_file = ref_dir / 'test.csv' output_file = Path('predict.txt') label_file = ref_dir / 'label_data.txt' # ----- 数据加载,优先从缓存读取 ----- MAX_SHARD_BYTES = 2 * 1024 * 1024 * 1024 # 2GB per shard batches_cache_dir = ref_dir / 'cached_batches' if batches_cache_dir.exists() and any(batches_cache_dir.glob('shard_*.pt')): print(f'[INFO] loading cached batch shards from {batches_cache_dir}') all_batches = [] shard_files = sorted(batches_cache_dir.glob('shard_*.pt'), key=lambda p: int(p.stem.split('_')[1])) for sf in shard_files: shard_batches = torch.load(sf, weights_only=False) all_batches.extend(shard_batches) print(f'[INFO] loaded {len(shard_batches)} batches from {sf.name}') print(f'[INFO] loaded {len(all_batches)} cached batches total from {len(shard_files)} shards') else: print('[INFO] start loading data from CSV') history_files = sorted(history_dir.glob('*.csv')) if history_dir.exists() else [] all_files = history_files + [input_file] item_dict, user_seq = load_sample_files(sample_files_list=all_files) test_pred_logids = load_logids_from_file(input_file) print(f'[INFO] Test pred logids count: {len(test_pred_logids)}') max_feasign_per_slot = {1: 2} test_dataset = CTRTestSeqDataset( test_logids_ordered=list(test_pred_logids), item_dict=item_dict, user_seq=user_seq, max_feasign_per_slot=max_feasign_per_slot, ) print(f'[INFO] num_users={test_dataset.num_users}, ' f'total_samples={test_dataset.total_samples}, ' f'pred_samples={len(test_pred_logids)}, ' f'max_sign_id={test_dataset.max_sign_id}') test_loader = DataLoader( test_dataset, batch_size=50, shuffle=False, num_workers=0, collate_fn=make_collate_fn(test_dataset.max_slot_id), ) # 收集 batches,预转 FP16 后按分片缓存 print('[INFO] collecting batches (pre-converting to FP16) and saving sharded cache...') all_batches = [] for batch in test_loader: batch = move_batch_to_device(batch, torch.device('cpu')) all_batches.append(batch) batches_cache_dir.mkdir(parents=True, exist_ok=True) shard_idx = 0 current_shard = [] current_size = 0 for batch in all_batches: buf = io.BytesIO() torch.save(batch, buf) batch_size_bytes = buf.tell() if current_shard and current_size + batch_size_bytes > MAX_SHARD_BYTES: shard_path = batches_cache_dir / f'shard_{shard_idx:04d}.pt' torch.save(current_shard, shard_path) print(f'[INFO] saved shard {shard_path.name}: {len(current_shard)} batches, ' f'~{current_size / 1024**3:.2f}GB') shard_idx += 1 current_shard = [] current_size = 0 current_shard.append(batch) current_size += batch_size_bytes if current_shard: shard_path = batches_cache_dir / f'shard_{shard_idx:04d}.pt' torch.save(current_shard, shard_path) print(f'[INFO] saved shard {shard_path.name}: {len(current_shard)} batches, ' f'~{current_size / 1024**3:.2f}GB') shard_idx += 1 print(f'[INFO] saved {len(all_batches)} batches to {shard_idx} shards in {batches_cache_dir}') print('[INFO] data loading done') # ----- 加载模型 ----- model, dev = load_model(ckpt_path=args.ckpt) # ----- 推理 ----- print('*' * 20 + ' start inference ' + '*' * 20) all_logids = [] all_probs = [] time_sum = 0.0 with torch.inference_mode(): for batch in tqdm(all_batches, desc="Inference"): batch = move_batch_to_device(batch, dev) pred_mask = batch["pred_mask"].bool() t_start = time.time() logits, moe_loss = model(batch) logits = logits.squeeze(-1) probs = torch.sigmoid(logits) time_sum += time.time() - t_start masked_logids = batch["logid"][pred_mask].cpu().tolist() masked_probs = probs[pred_mask].cpu().tolist() all_logids.extend(masked_logids) all_probs.extend(masked_probs) print(f'[INFO] inference time: {round(time_sum, 4)}s') print('*' * 20 + ' end inference ' + '*' * 20) # ----- 按 test.csv 顺序写预测文件 ----- logid_to_prob = dict(zip(all_logids, all_probs)) test_logids_in_order = [] with open(input_file, 'r') as f: for line in f: line = line.strip() if line: test_logids_in_order.append(int(line.split(',')[0])) output_file.parent.mkdir(parents=True, exist_ok=True) with open(output_file, 'w') as f: for logid in test_logids_in_order: f.write(f"{logid_to_prob[logid]}\n") print(f'[INFO] predictions written to {output_file}, total: {len(test_logids_in_order)}') # ----- 打分 ----- if label_file.exists(): result = _cal_score(output_file, label_file, default_latency=time_sum) print(f'[INFO] AUC: {result["auc"]:.6f}') print(f'[INFO] PCOC: {result["pcoc"]:.6f}') print(f'[INFO] Latency: {result["latency"]:.4f}s') print(f'[INFO] score_latency: {result["score_latency"]:.6f}') print(f'[INFO] score_model: {result["score_model"]:.6f}') print(f'[INFO] score_all: {result["score_all"]:.6f}') return result else: print(f'[WARNING] label file {label_file} not found, skipping scoring') return None if __name__ == '__main__': main()