import sys
import os

# Make the optional "../libraries" directory importable. The path is resolved
# relative to the current working directory, not to this script's location.
if os.path.exists("../libraries"):
    lib_path = os.path.abspath("../libraries")
    sys.path.append(lib_path)

import math
import argparse
from pathlib import Path
from collections import defaultdict

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

try:
    from tqdm import tqdm
except ImportError:
    # Progress bars are purely cosmetic; fall back to a transparent iterator
    # so the module still imports where tqdm is not installed.
    def tqdm(iterable, **_kwargs):
        return iterable


# ============================================================
# Data loading (originally from train/dataset.py)
# ============================================================

def _detect_has_clk(file_path):
    """Return True when a sample CSV carries a ``clk`` column.

    Two layouts are supported:

    * 5-column: ``logid,userid,adid,clk,timestamp,sign:slot...``
    * 4-column: ``logid,userid,adid,timestamp,sign:slot...``

    Only the first non-empty line is inspected. If its fifth field contains
    ``':'`` it is already a ``sign:slot`` pair, so there is no ``clk`` column.
    Lines with fewer than five fields, and empty files, yield False.
    """
    with open(file_path, 'r') as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            parts = line.split(',')
            return len(parts) >= 5 and ':' not in parts[4]
    return False


def load_sample_files(sample_files_list):
    """Load CSV sample files into per-record and per-user structures.

    Each file's layout (with or without ``clk``) is detected independently.
    Records without a ``clk`` column get ``clk = 0``. Lines with too few
    fields are skipped; feature fields without ``':'`` are ignored.

    Args:
        sample_files_list: iterable of file paths (str or Path).

    Returns:
        ``(item_dict, user_seq)`` where ``item_dict`` maps logid to a record
        dict (``logid``, ``userid``, ``adid``, ``clk``, ``timestamp``,
        ``signs``, ``slots`` — the last two as int64 numpy arrays) and
        ``user_seq`` maps userid to that user's logids sorted by timestamp.
    """
    sample_files = sorted([Path(f) for f in sample_files_list])
    print(f'[INFO] loading {len(sample_files)} files: {[str(f) for f in sample_files]}')

    item_dict = {}
    user_logs = defaultdict(list)

    for sample_file in tqdm(sample_files, desc='Loading sample files'):
        has_clk = _detect_has_clk(sample_file)
        min_parts = 5 if has_clk else 4
        print(f' {sample_file.name}: has_clk={has_clk}')
        with open(sample_file, 'r') as f:
            for line in f:
                line = line.strip()
                if not line:
                    continue
                parts = line.split(',')
                if len(parts) < min_parts:
                    continue

                logid = int(parts[0])
                userid = int(parts[1])
                adid = int(parts[2])
                if has_clk:
                    clk = int(parts[3])
                    timestamp = int(parts[4])
                    feat_start = 5
                else:
                    clk = 0
                    timestamp = int(parts[3])
                    feat_start = 4

                signs = []
                slots = []
                for pair in parts[feat_start:]:
                    if ':' in pair:
                        s, sl = pair.split(':', 1)
                        signs.append(int(s))
                        slots.append(int(sl))

                # A logid seen again overwrites the earlier record, but both
                # occurrences stay in the user's log list (as in the original).
                item_dict[logid] = {
                    'logid': logid,
                    'userid': userid,
                    'adid': adid,
                    'clk': clk,
                    'timestamp': timestamp,
                    'signs': np.array(signs, dtype=np.int64),
                    'slots': np.array(slots, dtype=np.int64),
                }
                user_logs[userid].append((timestamp, logid))

    user_seq = {}
    for userid, logs in user_logs.items():
        # Stable sort on timestamp only, so ties keep file order.
        logs.sort(key=lambda x: x[0])
        user_seq[userid] = [logid for _, logid in logs]

    print(f'[INFO] loaded {len(item_dict)} records, {len(user_seq)} users')
    return item_dict, user_seq


def load_logids_from_file(file_path):
    """Return the set of logids (first CSV field) found in a sample file.

    Blank lines are skipped. A line consisting of a single field (no comma)
    is accepted as a bare logid instead of raising ``ValueError``.
    """
    logids = set()
    with open(file_path, 'r') as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            logids.add(int(line.partition(',')[0]))
    return logids


class CTRTestSeqDataset(Dataset):
    """CTR test dataset organised per user (one dataset item == one user).

    Args:
        test_logids_ordered: logids that must be predicted; everything else in
            ``item_dict`` is context only (``pred_mask == 0``).
        item_dict: logid -> record dict as produced by ``load_sample_files``.
        user_seq: optional userid -> ordered logid list used to order each
            user's records.
        max_feasign_per_slot: optional ``{slot: limit}``; a slot's sign list is
            truncated to ``limit`` entries. Missing slots or ``-1`` mean "no
            limit".
        max_ctx_len: stored on the instance but not used by this class.
        max_slot_id: highest slot id exposed as ``self.max_slot_id``
            (defaults to 28, the value that used to be hard-coded).
    """

    def __init__(self, test_logids_ordered, item_dict, user_seq=None,
                 max_feasign_per_slot=None, max_ctx_len=None, max_slot_id=28):
        super().__init__()
        self.item_dict = item_dict
        self.user_seq = user_seq if user_seq else {}
        self.max_feasign_per_slot = max_feasign_per_slot
        self.max_ctx_len = max_ctx_len
        self.pred_logids = set(test_logids_ordered) if test_logids_ordered else set()

        self.user_items = defaultdict(list)
        for logid, rec in item_dict.items():
            grouped = defaultdict(list)
            for slot, sign in zip(rec['slots'].tolist(), rec['signs'].tolist()):
                grouped[slot].append(sign)

            feasign = {}
            for slot, signs in grouped.items():
                limit = -1
                if max_feasign_per_slot is not None:
                    limit = max_feasign_per_slot.get(slot, -1)
                feasign[slot] = signs[:limit] if limit != -1 else signs

            self.user_items[rec['userid']].append((logid, feasign, rec['clk']))

        self.user_ids = sorted(self.user_items.keys())
        self.num_users = len(self.user_ids)
        self.total_samples = len(item_dict)

        self.max_slot_id = max_slot_id
        # Streaming max: avoids materialising a set of every sign id.
        self.max_sign_id = max(
            (int(rec['signs'].max()) for rec in item_dict.values() if rec['signs'].size),
            default=0,
        )

    def __len__(self):
        return self.num_users

    def __getitem__(self, index):
        """Return one user's records, ordered, as parallel lists."""
        userid = self.user_ids[index]
        items = self.user_items[userid]

        if self.user_seq and userid in self.user_seq:
            seq_order = {logid: i for i, logid in enumerate(self.user_seq[userid])}

            def sort_key(entry):
                # Known logids follow the user sequence; unknown ones go last,
                # ordered by logid. A tuple key keeps sequence positions and
                # raw logids from being compared against each other.
                logid = entry[0]
                if logid in seq_order:
                    return (0, seq_order[logid])
                return (1, logid)

            ordered = sorted(items, key=sort_key)
        else:
            ordered = sorted(items, key=lambda x: x[0])

        logids = [logid for logid, _, _ in ordered]
        return {
            'userid': userid,
            'logids': logids,
            'feasigns': [feasign for _, feasign, _ in ordered],
            'labels': [label for _, _, label in ordered],
            'pred_mask': [1 if logid in self.pred_logids else 0 for logid in logids],
        }


def make_collate_fn(max_slot_id):
    """Build a DataLoader ``collate_fn`` that flattens users into one sequence.

    The returned batch dict holds flat per-record tensors (``userid``,
    ``logid``, ``label``, ``pred_mask``), ``user_offsets`` marking user
    boundaries in the flat sequence, and for every slot ``1..max_slot_id`` an
    integer key mapping to ``(values, offsets)``: the concatenated sign ids and
    the per-record segment boundaries into them.
    """

    def collate_user_batch(batch):
        all_userids = []
        all_logids = []
        all_labels = []
        all_pred_masks = []
        all_feasigns = []
        user_offsets = [0]

        for item in batch:
            for i, logid in enumerate(item['logids']):
                all_userids.append(item['userid'])
                all_logids.append(logid)
                all_labels.append(item['labels'][i])
                all_pred_masks.append(item['pred_mask'][i])
                all_feasigns.append(item['feasigns'][i])
            user_offsets.append(len(all_labels))

        slot_data = {}
        for slot in range(1, max_slot_id + 1):
            values = []
            offsets = [0]
            for feasign in all_feasigns:
                if slot in feasign:
                    values.extend(feasign[slot])
                # One boundary per record, even when the record lacks the slot.
                offsets.append(len(values))
            slot_data[slot] = (
                torch.tensor(values, dtype=torch.long),
                torch.tensor(offsets, dtype=torch.long),
            )

        result = {
            'userid': torch.tensor(all_userids, dtype=torch.long),
            'logid': torch.tensor(all_logids, dtype=torch.long),
            'label': torch.tensor(all_labels, dtype=torch.float32),
            'pred_mask': torch.tensor(all_pred_masks, dtype=torch.bool),
            'user_offsets': torch.tensor(user_offsets, dtype=torch.long),
        }
        result.update(slot_data)
        return result

    return collate_user_batch


# ============================================================
# Model definition (originally from main.py)
# ============================================================

def move_batch_to_device(batch, device):
    """Recursively move a batch to ``device``, casting float32 tensors to FP16.

    Dicts are rebuilt as dicts; lists AND tuples are rebuilt as lists. Integer
    and bool tensors keep their dtype. Non-tensor leaves are returned as-is.
    """
    if isinstance(batch, dict):
        return {k: move_batch_to_device(v, device) for k, v in batch.items()}
    elif isinstance(batch, (list, tuple)):
        return [move_batch_to_device(x, device) for x in batch]
    elif torch.is_tensor(batch):
        x = batch.to(device)
        # Floating-point tensors -> FP16; integer tensors stay untouched.
        if x.dtype == torch.float32:
            x = x.half()
        return x
    else:
        return batch


class RepEncoder(nn.Module):
    """Embed every slot's sign ids, sum-pool per record and project to d_model."""

    def __init__(self, vocab_size, emb_dim, padding_idx=0, slot_num=0, d_model=0):
        super().__init__()
        self.emb = nn.Embedding(num_embeddings=vocab_size, embedding_dim=emb_dim,
                                padding_idx=padding_idx)
        self.emb_dim = emb_dim
        self.slot_num = slot_num
        self.input_norm = nn.LayerNorm(slot_num * emb_dim)
        self.linear = nn.Linear(in_features=slot_num * emb_dim, out_features=d_model)

    def forward(self, batch):
        """Encode a collated batch; ``batch[slot]`` must be ``(values, offsets)``."""
        pooled_embs = []
        max_idx = self.emb.num_embeddings - 1
        # dtype of the downstream layers (torch.float16 after FP16 conversion).
        target_dtype = self.input_norm.weight.dtype
        for i in range(self.slot_num):
            values, offsets = batch[i + 1]
            offsets = offsets.to(values.device)
            # Clamp sign ids beyond vocab_size to avoid out-of-range lookups.
            values = values.clamp(0, max_idx)
            sign_emb = self.emb(values).to(target_dtype)
            res = torch.segment_reduce(sign_emb, reduce='sum', offsets=offsets, initial=0)
            pooled_embs.append(res)
        fused_embs = torch.cat(pooled_embs, dim=1)
        norm_emb = self.input_norm(fused_embs)
        rep_emb = self.linear(norm_emb)
        return rep_emb


def scaled_dot_product(q, k, v, extension):
    """Attention via PyTorch SDPA, which picks a fused backend when available.

    ``extension`` may be ``None`` or a dict; when it has a ``"mask"`` entry
    that tensor is moved to ``q``'s device and passed as ``attn_mask``.
    """
    if extension is not None and "mask" in extension:
        attn_mask = extension["mask"].to(device=q.device)
    else:
        attn_mask = None
    return F.scaled_dot_product_attention(
        q, k, v,
        attn_mask=attn_mask,
        dropout_p=0.0,
        is_causal=False,
    )


class Expert(nn.Module):
    """Two-layer ReLU feed-forward expert."""

    def __init__(self, d_model, dim_ff):
        super().__init__()
        self.fc1 = nn.Linear(d_model, dim_ff)
        self.fc2 = nn.Linear(dim_ff, d_model)

    def forward(self, x):
        return self.fc2(F.relu(self.fc1(x)))


class TopKGate(nn.Module):
    """Softmax router returning the top-k experts for every token."""

    def __init__(self, d_model, num_experts, k=2, noisy_gating=True):
        super().__init__()
        self.w_g = nn.Linear(d_model, num_experts)
        self.num_experts = num_experts
        self.k = k
        self.noisy_gating = noisy_gating

    def forward(self, x):
        # x: [B, S, D]
        logits = self.w_g(x)  # [B, S, E]
        if self.noisy_gating and self.training:
            # Gaussian noise is only injected in training mode.
            logits = logits + torch.randn_like(logits) * 0.1
        probs = torch.softmax(logits, dim=-1)  # [B, S, E]
        topk_score, topk_idx = torch.topk(probs, self.k, dim=-1)  # [B, S, k]
        return topk_idx, topk_score, probs


class SMoE(nn.Module):
    """Sparse mixture of experts: each token is processed by its top-k experts."""

    def __init__(self, d_model, dim_ff, num_experts, k=2):
        super().__init__()
        self.num_experts = num_experts
        self.k = k
        self.experts = nn.ModuleList([
            Expert(d_model, dim_ff) for _ in range(num_experts)
        ])
        self.gate = TopKGate(d_model, num_experts, k=k)

    def forward(self, x):
        """Return ``(out, moe_loss)`` for ``x`` of shape [B, S, D].

        ``out`` is the gate-score-weighted sum of the selected experts'
        outputs; ``moe_loss`` is the coefficient of variation of the summed
        routing probabilities per expert.
        """
        B, S, D = x.shape
        topk_idx, topk_score, probs = self.gate(x)

        x_flat = x.reshape(-1, D)                    # [B*S, D]
        idx_flat = topk_idx.reshape(-1, self.k)      # [B*S, k]
        score_flat = topk_score.reshape(-1, self.k)  # [B*S, k]
        # Accumulate into an explicitly flat buffer. The previous
        # ``zeros_like(x).reshape(-1, D)`` silently produced a *copy* for
        # non-contiguous ``x``, leaving the returned tensor all zeros.
        out_flat = x.new_zeros(B * S, D)

        for i in range(self.num_experts):
            # Tokens routed to expert i (each token hits an expert at most
            # once, so the indexed accumulation has no duplicate rows).
            mask = (idx_flat == i)  # [B*S, k]
            token_idx, k_idx = mask.nonzero(as_tuple=True)
            if token_idx.numel() == 0:
                continue
            selected_x = x_flat[token_idx]            # [N, D]
            expert_out = self.experts[i](selected_x)  # [N, D]
            weight = score_flat[token_idx, k_idx].unsqueeze(-1)
            out_flat[token_idx] += expert_out * weight

        importance = probs.sum(dim=(0, 1))  # [E]
        moe_loss = (importance.std() / (importance.mean() + 1e-6))
        return out_flat.view(B, S, D), moe_loss


class TransformerEncoder(nn.Module):
    """Pre-norm transformer stack whose feed-forward block is an SMoE layer.

    ``ffn1``/``ffn2`` and ``act`` are created but not used by ``forward``;
    they are kept so the module's parameter names stay unchanged.
    """

    def __init__(self, d_model, n_heads, num_layers, dim_ff, act="relu",
                 attention_fn=scaled_dot_product):
        super().__init__()
        self.d_model = d_model
        self.n_heads = n_heads
        self.head_dim = d_model // n_heads
        self.num_layers = num_layers
        assert d_model % n_heads == 0

        self.qkv_proj = nn.ModuleList([nn.Linear(d_model, 3 * d_model) for _ in range(num_layers)])
        self.out_proj = nn.ModuleList([nn.Linear(d_model, d_model) for _ in range(num_layers)])
        self.ffn1 = nn.ModuleList([nn.Linear(d_model, dim_ff) for _ in range(num_layers)])
        self.ffn2 = nn.ModuleList([nn.Linear(dim_ff, d_model) for _ in range(num_layers)])
        self.norm1 = nn.ModuleList([nn.LayerNorm(d_model) for _ in range(num_layers)])
        self.norm2 = nn.ModuleList([nn.LayerNorm(d_model) for _ in range(num_layers)])
        self.act = getattr(F, act)
        self.attention_fn = attention_fn

        self.moe = nn.ModuleList([
            SMoE(d_model, dim_ff, num_experts=8, k=2)
            for _ in range(num_layers)
        ])

    def forward(self, x, extension):
        """Encode ``x`` of shape [S, D]; return ``([1, S, D], summed moe loss)``."""
        x = x.unsqueeze(0)  # the whole flattened batch is one sequence
        B, S, D = x.shape
        moe_loss_total = 0.0

        for i in range(self.num_layers):
            residual = x
            x = self.norm1[i](x)
            qkv = self.qkv_proj[i](x)
            # Each head owns a contiguous [q | k | v] chunk of the projection.
            qkv = qkv.view(B, S, self.n_heads, 3 * self.head_dim)
            qkv = qkv.permute(0, 2, 1, 3)
            q, k, v = torch.split(qkv, self.head_dim, dim=-1)
            attn_out = self.attention_fn(q, k, v, extension)
            attn_out = attn_out.permute(0, 2, 1, 3).reshape(B, S, D)
            x = residual + self.out_proj[i](attn_out)

            residual = x
            x = self.norm2[i](x)
            moe_out, moe_loss = self.moe[i](x)
            x = residual + moe_out
            moe_loss_total = moe_loss_total + moe_loss

        return x, moe_loss_total


class CTRModel(nn.Module):
    """Record encoder + sequence encoder + linear head producing CTR logits."""

    def __init__(self, rep_encoder, seq_encoder, d_model):
        super().__init__()
        self.rep_encoder = rep_encoder
        self.seq_encoder = seq_encoder
        self.d_model = d_model
        self.linear = nn.Linear(d_model, 1)

    def get_sequence_causal_mask(self, seq_info):
        """Build a block-diagonal causal mask from user offsets.

        ``seq_info`` holds the cumulative user boundaries. Entry ``[i, j]`` of
        the returned bool matrix is True when records i and j belong to the
        same user and ``j <= i``.
        """
        lengths = seq_info[1:] - seq_info[:-1]
        lengths = lengths.view(-1)
        indices = torch.cumsum(torch.ones_like(lengths), dim=0) - 1
        # Per-record user index within the batch.
        result = torch.repeat_interleave(indices, lengths)
        a = result.view(1, -1) - result.view(-1, 1)
        out_mask = torch.tril((a == 0).to(torch.int32)).bool()
        return out_mask

    def forward(self, batch):
        """Return ``(logits [N, 1], moe_loss)``; logits are clamped to [-15, 15]."""
        seq_input = self.rep_encoder(batch)
        seq_mask = self.get_sequence_causal_mask(batch["user_offsets"])
        encoder_output, moe_loss = self.seq_encoder(
            x=seq_input,
            extension={"mask": seq_mask.unsqueeze(0).unsqueeze(0)},
        )
        encoder_output = encoder_output.squeeze(0)
        pred = self.linear(encoder_output)
        pred_logits = torch.clamp(pred, min=-15.0, max=15.0)
        return pred_logits, moe_loss


# ============================================================
# Model loading entry point
# ============================================================

def _make_sparse_forward(sparse_w, bias):
    """Return a replacement ``forward`` computing ``x @ sparse_w.T (+ bias)``."""
    def forward(x):
        out = torch.matmul(x, sparse_w.t())
        return out + bias if bias is not None else out
    return forward


def _apply_expert_sparsity(model):
    """Apply 2:4 magnitude pruning to every expert ``fc1``/``fc2`` layer.

    In each group of four consecutive weights along the input dimension the
    two largest-magnitude entries are kept. All sparse weights are prepared
    first and the layers are patched only afterwards, so a failure part-way
    through leaves the model fully dense rather than half converted.

    Returns:
        Number of Linear layers that were patched.
    """
    prepared = []
    for layer in model.seq_encoder.moe:
        for expert in layer.experts:
            for attr in ['fc1', 'fc2']:
                linear = getattr(expert, attr)
                w = linear.weight.detach()
                if w.shape[-1] % 4 != 0:
                    raise ValueError(
                        f'in_features={w.shape[-1]} is not a multiple of 4')
                w_flat = w.reshape(-1, 4)
                _, top_idx = torch.topk(w_flat.abs(), k=2, dim=1)
                mask = torch.zeros_like(w_flat)
                mask.scatter_(1, top_idx, 1.0)
                pruned = (w_flat * mask).reshape(w.shape)
                sparse_w = torch.sparse.to_sparse_semi_structured(pruned)
                prepared.append((linear, sparse_w))

    for linear, sparse_w in prepared:
        # Instance-level override; the dense nn.Parameter stays in place.
        linear.forward = _make_sparse_forward(sparse_w, linear.bias)
    return len(prepared)


def load_model(ckpt_path, device='cuda:0'):
    """Build the CTR model, load its checkpoint and prepare it for inference.

    Intended to be called by evaluation.py.

    Args:
        ckpt_path: checkpoint path (str or Path). ``None`` selects ``ckpt.pt``
            next to this file.
        device: requested inference device (default ``'cuda:0'``); CPU is used
            when CUDA is unavailable.

    Returns:
        ``(model, device)`` tuple; the model is in eval mode.
    """
    emb_dim = 512
    slot_num = 28
    vocab_size = 5000000
    d_model = 512
    n_heads = 8
    num_layers = 8
    dim_ff = 1024

    rep_encoder = RepEncoder(
        vocab_size=vocab_size,
        emb_dim=emb_dim,
        padding_idx=0,
        slot_num=slot_num,
        d_model=d_model,
    )
    seq_encoder = TransformerEncoder(
        d_model=d_model,
        n_heads=n_heads,
        num_layers=num_layers,
        dim_ff=dim_ff,
        act="relu",
    )
    model = CTRModel(rep_encoder, seq_encoder, d_model=d_model)
    dev = torch.device(device if torch.cuda.is_available() else "cpu")

    # Checkpoint selection. To force custom weights from your own folder,
    # change the path logic directly below; the evaluation system passes the
    # original official weights by default.
    if ckpt_path is None:
        ckpt_path = Path(__file__).parent / 'ckpt.pt'
    else:
        ckpt_path = Path(ckpt_path)

    if ckpt_path.exists():
        # NOTE(security): weights_only=False unpickles arbitrary objects; only
        # load checkpoints from a trusted source.
        ckpt = torch.load(ckpt_path, map_location='cpu', weights_only=False)
        model.load_state_dict(ckpt['model_state_dict'])
        print(f"[INFO] Loaded checkpoint from {ckpt_path} (epoch={ckpt.get('epoch', '?')})")
        # FP16 conversion: all parameters to half precision, except the
        # embedding table which stays in FP32.
        model = model.half()
        model.rep_encoder.emb = model.rep_encoder.emb.to(torch.float32)
        print("[INFO] Model converted to FP16 (embedding kept in FP32)")
    else:
        print(f"[WARNING] Checkpoint {ckpt_path} not found, using random weights")

    model.to(dev)

    # 2:4 semi-structured sparsity on the expert FFN weights only; attention
    # and gate weights are left dense. Best effort: unsupported hardware,
    # dtype or shapes fall back to the dense model.
    try:
        sp_count = _apply_expert_sparsity(model)
        print(f"[INFO] 2:4 sparsity applied to {sp_count} Expert Linear layers")
    except Exception as e:
        print(f"[WARNING] 2:4 sparsity failed ({e}), keeping dense weights")

    model.eval()
    print(f"[INFO] Model ready. Device: {dev}")
    return model, dev


# ============================================================
# Scoring helpers (kept consistent with evaluation.py)
# ============================================================

def _read_predict(file_path):
    """Read one float prediction per non-blank line into a numpy array."""
    predictions = []
    with open(file_path, 'r') as f:
        for line in f:
            line = line.strip()
            if line:
                predictions.append(float(line))
    return np.array(predictions)


def _read_label(file_path):
    """Read labels into a numpy array.

    A line with at least four comma-separated fields contributes its fourth
    field; any other non-blank line is parsed as a bare float.
    """
    labels = []
    with open(file_path, 'r') as f:
        for line in f:
            line = line.strip()
            if line:
                parts = line.split(',')
                if len(parts) >= 4:
                    labels.append(float(parts[3]))
                else:
                    labels.append(float(line))
    return np.array(labels)


def _cal_score(predict_file, label_file, default_latency=0.0):
    """Compute AUC, PCOC and the combined score for a prediction file.

    Args:
        predict_file: file with one prediction per line.
        label_file: label file understood by ``_read_label``.
        default_latency: latency used for the latency score, in the same unit
            as the 300 baseline (``main`` passes accumulated seconds).

    Returns:
        Dict with ``auc``, ``pcoc``, ``latency``, ``score_latency``,
        ``score_model`` and ``score_all``. AUC falls back to 0.5 when the
        labels contain a single class.
    """
    predictions = _read_predict(predict_file)
    labels = _read_label(label_file)

    unique_labels = np.unique(labels)
    if len(unique_labels) < 2:
        print('[WARNING] only one class present in labels, AUC is not defined, returning 0.5')
        auc = 0.5
    else:
        # Imported lazily so the degenerate branch works without scikit-learn.
        from sklearn.metrics import roc_auc_score
        auc = roc_auc_score(labels, predictions)

    mean_pred = np.mean(predictions)
    mean_label = np.mean(labels)
    if mean_label == 0:
        pcoc = 1.0 if mean_pred == 0 else float('inf')
    else:
        pcoc = float(mean_pred / mean_label)

    latency = default_latency
    base_latency = 300
    score_latency = max(0.0, (base_latency - latency) / base_latency) if latency < base_latency else 0.0

    # Calibration outside [0.85, 1.15] zeroes the model score.
    if pcoc < 0.85 or pcoc > 1.15:
        score_model = 0.0
    else:
        score_model = ((auc - 0.65) * 1000 + (0.15 - abs(pcoc - 1)) / 0.15 * 10) / 360

    score_all = score_latency * 70 + score_model * 30

    return {
        'auc': auc,
        'pcoc': pcoc,
        'latency': latency,
        'score_latency': score_latency,
        'score_model': score_model,
        'score_all': score_all,
    }


# ============================================================
# main: run infer.py directly for a local test
# ============================================================

def _load_cached_batches(cache_dir):
    """Load all ``shard_*.pt`` files from ``cache_dir`` in shard-index order."""
    print(f'[INFO] loading cached batch shards from {cache_dir}')
    all_batches = []
    shard_files = sorted(cache_dir.glob('shard_*.pt'),
                         key=lambda p: int(p.stem.split('_')[1]))
    for sf in shard_files:
        # NOTE(security): pickle-based load; the cache directory must be trusted.
        shard_batches = torch.load(sf, weights_only=False)
        all_batches.extend(shard_batches)
        print(f'[INFO] loaded {len(shard_batches)} batches from {sf.name}')
    print(f'[INFO] loaded {len(all_batches)} cached batches total from {len(shard_files)} shards')
    return all_batches


def _build_batches_from_csv(history_dir, input_file):
    """Parse history + test CSVs and collate them into CPU batches (FP16 floats)."""
    print('[INFO] start loading data from CSV')
    history_files = sorted(history_dir.glob('*.csv')) if history_dir.exists() else []
    all_files = history_files + [input_file]
    item_dict, user_seq = load_sample_files(sample_files_list=all_files)

    test_pred_logids = load_logids_from_file(input_file)
    print(f'[INFO] Test pred logids count: {len(test_pred_logids)}')

    max_feasign_per_slot = {1: 2}
    test_dataset = CTRTestSeqDataset(
        test_logids_ordered=list(test_pred_logids),
        item_dict=item_dict,
        user_seq=user_seq,
        max_feasign_per_slot=max_feasign_per_slot,
    )
    print(f'[INFO] num_users={test_dataset.num_users}, '
          f'total_samples={test_dataset.total_samples}, '
          f'pred_samples={len(test_pred_logids)}, '
          f'max_sign_id={test_dataset.max_sign_id}')

    test_loader = DataLoader(
        test_dataset,
        batch_size=50,
        shuffle=False,
        num_workers=0,
        collate_fn=make_collate_fn(test_dataset.max_slot_id),
    )

    print('[INFO] collecting batches (pre-converting to FP16) and saving sharded cache...')
    cpu = torch.device('cpu')
    return [move_batch_to_device(batch, cpu) for batch in test_loader]


def _save_batches_sharded(all_batches, cache_dir, max_shard_bytes):
    """Write batches to ``shard_NNNN.pt`` files of at most ``max_shard_bytes``.

    A batch's size is measured by serialising it to memory. A shard always
    holds at least one batch, even if that batch alone exceeds the limit.
    """
    import io

    cache_dir.mkdir(parents=True, exist_ok=True)
    shard_idx = 0
    current_shard = []
    current_size = 0

    def flush():
        shard_path = cache_dir / f'shard_{shard_idx:04d}.pt'
        torch.save(current_shard, shard_path)
        print(f'[INFO] saved shard {shard_path.name}: {len(current_shard)} batches, '
              f'~{current_size / 1024**3:.2f}GB')

    for batch in all_batches:
        buf = io.BytesIO()
        torch.save(batch, buf)
        batch_size_bytes = buf.tell()

        if current_shard and current_size + batch_size_bytes > max_shard_bytes:
            flush()
            shard_idx += 1
            current_shard = []
            current_size = 0

        current_shard.append(batch)
        current_size += batch_size_bytes

    if current_shard:
        flush()
        shard_idx += 1

    print(f'[INFO] saved {len(all_batches)} batches to {shard_idx} shards in {cache_dir}')


def main():
    """Run inference on ``dataset/test.csv`` and optionally score the result.

    Batches are read from ``dataset/cached_batches`` when shards exist there;
    otherwise they are built from the CSV files and cached. Predictions are
    written to ``predict.txt`` in the order of ``test.csv``.

    Returns:
        The score dict from ``_cal_score`` when the label file exists,
        otherwise ``None``.
    """
    import time
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('--ckpt', type=str, default=None,
                        help='checkpoint 文件路径,默认使用同目录下的 ckpt.pt')
    args = parser.parse_args()

    cur_path = Path(__file__).parent.absolute()
    ref_dir = cur_path / 'dataset'
    history_dir = ref_dir / 'history'
    input_file = ref_dir / 'test.csv'
    output_file = Path('predict.txt')
    label_file = ref_dir / 'label_data.txt'

    # ----- Data loading, preferring the on-disk cache -----
    MAX_SHARD_BYTES = 2 * 1024 * 1024 * 1024  # 2GB per shard
    batches_cache_dir = ref_dir / 'cached_batches'

    if batches_cache_dir.exists() and any(batches_cache_dir.glob('shard_*.pt')):
        all_batches = _load_cached_batches(batches_cache_dir)
    else:
        all_batches = _build_batches_from_csv(history_dir, input_file)
        _save_batches_sharded(all_batches, batches_cache_dir, MAX_SHARD_BYTES)

    print('[INFO] data loading done')

    # ----- Load model -----
    model, dev = load_model(ckpt_path=args.ckpt)

    # ----- Inference -----
    print('*' * 20 + ' start inference ' + '*' * 20)
    all_logids = []
    all_probs = []
    time_sum = 0.0
    # CUDA kernels launch asynchronously: without synchronising, the timer
    # would only measure launch overhead instead of the actual compute.
    sync_cuda = dev.type == 'cuda'

    with torch.inference_mode():
        for batch in tqdm(all_batches, desc="Inference"):
            batch = move_batch_to_device(batch, dev)
            pred_mask = batch["pred_mask"].bool()

            if sync_cuda:
                torch.cuda.synchronize(dev)  # keep the H2D copy out of the timing
            t_start = time.perf_counter()
            logits, moe_loss = model(batch)
            logits = logits.squeeze(-1)
            probs = torch.sigmoid(logits)
            if sync_cuda:
                torch.cuda.synchronize(dev)
            time_sum += time.perf_counter() - t_start

            masked_logids = batch["logid"][pred_mask].cpu().tolist()
            masked_probs = probs[pred_mask].cpu().tolist()
            all_logids.extend(masked_logids)
            all_probs.extend(masked_probs)

    print(f'[INFO] inference time: {round(time_sum, 4)}s')
    print('*' * 20 + ' end inference ' + '*' * 20)

    # ----- Write predictions in test.csv order -----
    logid_to_prob = dict(zip(all_logids, all_probs))
    test_logids_in_order = []
    with open(input_file, 'r') as f:
        for line in f:
            line = line.strip()
            if line:
                test_logids_in_order.append(int(line.split(',')[0]))

    # Fail before touching the output file instead of leaving it half written.
    missing = [logid for logid in test_logids_in_order if logid not in logid_to_prob]
    if missing:
        raise KeyError(
            f'{len(missing)} logids from {input_file} have no prediction '
            f'(first: {missing[0]}); the batch cache in {batches_cache_dir} may be stale')

    output_file.parent.mkdir(parents=True, exist_ok=True)
    with open(output_file, 'w') as f:
        for logid in test_logids_in_order:
            f.write(f"{logid_to_prob[logid]}\n")
    print(f'[INFO] predictions written to {output_file}, total: {len(test_logids_in_order)}')

    # ----- Scoring -----
    if label_file.exists():
        result = _cal_score(output_file, label_file, default_latency=time_sum)
        print(f'[INFO] AUC: {result["auc"]:.6f}')
        print(f'[INFO] PCOC: {result["pcoc"]:.6f}')
        print(f'[INFO] Latency: {result["latency"]:.4f}s')
        print(f'[INFO] score_latency: {result["score_latency"]:.6f}')
        print(f'[INFO] score_model: {result["score_model"]:.6f}')
        print(f'[INFO] score_all: {result["score_all"]:.6f}')
        return result
    else:
        print(f'[WARNING] label file {label_file} not found, skipping scoring')
        return None


if __name__ == '__main__':
    main()