chore: 初始化 CTI 推理优化项目
- baseline infer.py + requirements.txt + build_env.sh - GRAB / HSTU 两篇核心论文 - 比赛规则和提交接口说明 - 项目 CLAUDE.md
This commit is contained in:
@@ -0,0 +1,4 @@
|
||||
#!/bin/bash
|
||||
|
||||
|
||||
echo "build env succeess"
|
||||
@@ -0,0 +1,728 @@
|
||||
import sys
|
||||
import os
|
||||
|
||||
# 获取当前环境脚本所在目录或指定绝对路径
|
||||
if os.path.exists("../libraries"):
|
||||
lib_path = os.path.abspath("../libraries")
|
||||
sys.path.append(lib_path)
|
||||
|
||||
import math
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
from collections import defaultdict
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torch.utils.data import Dataset, DataLoader
|
||||
from tqdm import tqdm
|
||||
|
||||
|
||||
# ============================================================
|
||||
# 数据加载(来自 train/dataset.py)
|
||||
# ============================================================
|
||||
|
||||
def _detect_has_clk(file_path):
|
||||
"""检测 CSV 文件是否包含 clk 列(5列 vs 4列格式)。
|
||||
5列格式: logid,userid,adid,clk,timestamp,sign:slot...
|
||||
4列格式: logid,userid,adid,timestamp,sign:slot...
|
||||
通过第5个字段是否包含 ':' 来判断:有 ':' 说明已经是 sign:slot,即无 clk 列。
|
||||
"""
|
||||
with open(file_path, 'r') as f:
|
||||
for line in f:
|
||||
line = line.strip()
|
||||
if not line:
|
||||
continue
|
||||
parts = line.split(',')
|
||||
if len(parts) >= 5:
|
||||
return ':' not in parts[4]
|
||||
return False
|
||||
return False
|
||||
|
||||
|
||||
def load_sample_files(sample_files_list):
|
||||
"""加载 CSV sample 文件,返回 item_dict 和 user_seq。
|
||||
自动检测每个文件是 5列(含clk)还是 4列(无clk)格式。
|
||||
"""
|
||||
sample_files = sorted([Path(f) for f in sample_files_list])
|
||||
print(f'[INFO] loading {len(sample_files)} files: {[str(f) for f in sample_files]}')
|
||||
|
||||
item_dict = {}
|
||||
user_logs = defaultdict(list)
|
||||
|
||||
for sample_file in tqdm(sample_files, desc='Loading sample files'):
|
||||
has_clk = _detect_has_clk(sample_file)
|
||||
min_parts = 5 if has_clk else 4
|
||||
print(f' {sample_file.name}: has_clk={has_clk}')
|
||||
|
||||
with open(sample_file, 'r') as f:
|
||||
for line in f:
|
||||
line = line.strip()
|
||||
if not line:
|
||||
continue
|
||||
parts = line.split(',')
|
||||
if len(parts) < min_parts:
|
||||
continue
|
||||
|
||||
logid = int(parts[0])
|
||||
userid = int(parts[1])
|
||||
adid = int(parts[2])
|
||||
|
||||
if has_clk:
|
||||
clk = int(parts[3])
|
||||
timestamp = int(parts[4])
|
||||
feat_start = 5
|
||||
else:
|
||||
clk = 0
|
||||
timestamp = int(parts[3])
|
||||
feat_start = 4
|
||||
|
||||
signs = []
|
||||
slots = []
|
||||
for pair in parts[feat_start:]:
|
||||
if ':' in pair:
|
||||
s, sl = pair.split(':', 1)
|
||||
signs.append(int(s))
|
||||
slots.append(int(sl))
|
||||
|
||||
item_dict[logid] = {
|
||||
'logid': logid,
|
||||
'userid': userid,
|
||||
'adid': adid,
|
||||
'clk': clk,
|
||||
'timestamp': timestamp,
|
||||
'signs': np.array(signs, dtype=np.int64),
|
||||
'slots': np.array(slots, dtype=np.int64),
|
||||
}
|
||||
user_logs[userid].append((timestamp, logid))
|
||||
|
||||
user_seq = {}
|
||||
for userid, logs in user_logs.items():
|
||||
logs.sort(key=lambda x: x[0])
|
||||
user_seq[userid] = [logid for _, logid in logs]
|
||||
|
||||
print(f'[INFO] loaded {len(item_dict)} records, {len(user_seq)} users')
|
||||
return item_dict, user_seq
|
||||
|
||||
|
||||
def load_logids_from_file(file_path):
|
||||
"""快速读取一个 sample 文件中的所有 logid"""
|
||||
logids = set()
|
||||
with open(file_path, 'r') as f:
|
||||
for line in f:
|
||||
line = line.strip()
|
||||
if not line:
|
||||
continue
|
||||
comma = line.index(',')
|
||||
logids.add(int(line[:comma]))
|
||||
return logids
|
||||
|
||||
|
||||
class CTRUserDataset(Dataset):
|
||||
"""按用户组织的 CTR 数据集"""
|
||||
|
||||
def __init__(self, item_dict, user_seq=None, max_feasign_per_slot=None, pred_logids=None):
|
||||
super().__init__()
|
||||
self.item_dict = item_dict
|
||||
self.user_seq = user_seq if user_seq else {}
|
||||
self.max_feasign_per_slot = max_feasign_per_slot
|
||||
self.pred_logids = pred_logids if pred_logids is not None else set()
|
||||
|
||||
self.user_items = defaultdict(list)
|
||||
for logid, rec in item_dict.items():
|
||||
userid = rec['userid']
|
||||
feasign = defaultdict(list)
|
||||
for slot, sign in zip(rec['slots'].tolist(), rec['signs'].tolist()):
|
||||
feasign[slot].append(sign)
|
||||
if max_feasign_per_slot is not None:
|
||||
feasign = {slot: signs[:max_feasign_per_slot[slot]]
|
||||
if max_feasign_per_slot.get(slot, -1) != -1 else signs
|
||||
for slot, signs in feasign.items()}
|
||||
feasign = dict(feasign)
|
||||
label = rec['clk']
|
||||
self.user_items[userid].append((logid, feasign, label))
|
||||
|
||||
self.user_ids = sorted(self.user_items.keys())
|
||||
self.num_users = len(self.user_ids)
|
||||
self.total_samples = len(item_dict)
|
||||
|
||||
all_signs = set()
|
||||
for rec in item_dict.values():
|
||||
all_signs.update(rec['signs'].tolist())
|
||||
self.max_slot_id = 28
|
||||
self.max_sign_id = max(all_signs) if all_signs else 0
|
||||
|
||||
def __len__(self):
|
||||
return self.num_users
|
||||
|
||||
def __getitem__(self, index):
|
||||
userid = self.user_ids[index]
|
||||
items = self.user_items[userid]
|
||||
|
||||
if self.user_seq and userid in self.user_seq:
|
||||
seq_order = {logid: i for i, logid in enumerate(self.user_seq[userid])}
|
||||
items.sort(key=lambda x: seq_order.get(x[0], x[0]))
|
||||
else:
|
||||
items.sort(key=lambda x: x[0])
|
||||
|
||||
feasigns = []
|
||||
labels = []
|
||||
logids = []
|
||||
for logid, feasign, label in items:
|
||||
logids.append(logid)
|
||||
feasigns.append(feasign)
|
||||
labels.append(label)
|
||||
|
||||
return {
|
||||
'userid': userid,
|
||||
'logids': logids,
|
||||
'feasigns': feasigns,
|
||||
'labels': labels,
|
||||
'pred_mask': [1 if logid in self.pred_logids else 0 for logid in logids],
|
||||
}
|
||||
|
||||
|
||||
def make_collate_fn(max_slot_id):
|
||||
def collate_user_batch(batch):
|
||||
all_userids = []
|
||||
all_logids = []
|
||||
all_labels = []
|
||||
all_pred_masks = []
|
||||
all_feasigns = []
|
||||
user_offsets = [0]
|
||||
|
||||
for item in batch:
|
||||
for i, logid in enumerate(item['logids']):
|
||||
all_userids.append(item['userid'])
|
||||
all_logids.append(logid)
|
||||
all_labels.append(item['labels'][i])
|
||||
all_pred_masks.append(item['pred_mask'][i])
|
||||
all_feasigns.append(item['feasigns'][i])
|
||||
user_offsets.append(len(all_labels))
|
||||
|
||||
slot_data = {}
|
||||
for slot in range(1, max_slot_id + 1):
|
||||
values = []
|
||||
offsets = [0]
|
||||
for feasign in all_feasigns:
|
||||
if slot in feasign:
|
||||
values.extend(feasign[slot])
|
||||
offsets.append(len(values))
|
||||
slot_data[slot] = (
|
||||
torch.tensor(values, dtype=torch.long),
|
||||
torch.tensor(offsets, dtype=torch.long),
|
||||
)
|
||||
|
||||
result = {
|
||||
'userid': torch.tensor(all_userids, dtype=torch.long),
|
||||
'logid': torch.tensor(all_logids, dtype=torch.long),
|
||||
'label': torch.tensor(all_labels, dtype=torch.float32),
|
||||
'pred_mask': torch.tensor(all_pred_masks, dtype=torch.bool),
|
||||
'user_offsets': torch.tensor(user_offsets, dtype=torch.long),
|
||||
}
|
||||
result.update(slot_data)
|
||||
return result
|
||||
|
||||
return collate_user_batch
|
||||
|
||||
|
||||
# ============================================================
|
||||
# 模型定义(来自 main.py)
|
||||
# ============================================================
|
||||
|
||||
def move_batch_to_device(batch, device):
|
||||
if isinstance(batch, dict):
|
||||
return {k: move_batch_to_device(v, device) for k, v in batch.items()}
|
||||
elif isinstance(batch, (list, tuple)):
|
||||
return [move_batch_to_device(x, device) for x in batch]
|
||||
elif torch.is_tensor(batch):
|
||||
return batch.to(device)
|
||||
else:
|
||||
return batch
|
||||
|
||||
|
||||
class RepEncoder(nn.Module):
|
||||
def __init__(self, vocab_size, emb_dim, padding_idx=0, slot_num=0, d_model=0):
|
||||
super().__init__()
|
||||
self.emb = nn.Embedding(num_embeddings=vocab_size, embedding_dim=emb_dim, padding_idx=padding_idx)
|
||||
self.emb_dim = emb_dim
|
||||
self.slot_num = slot_num
|
||||
self.input_norm = nn.LayerNorm(slot_num * emb_dim)
|
||||
self.linear = nn.Linear(in_features=slot_num * emb_dim, out_features=d_model)
|
||||
|
||||
def forward(self, batch):
|
||||
pooled_embs = []
|
||||
max_idx = self.emb.num_embeddings - 1
|
||||
for i in range(self.slot_num):
|
||||
values, offsets = batch[i + 1]
|
||||
offsets = offsets.to(values.device)
|
||||
values = values.clamp(0, max_idx) # 超出 vocab_size 的 sign id 截断,避免越界
|
||||
sign_emb = self.emb(values)
|
||||
res = torch.segment_reduce(sign_emb, reduce='sum', offsets=offsets, initial=0)
|
||||
pooled_embs.append(res)
|
||||
fused_embs = torch.cat(pooled_embs, dim=1)
|
||||
norm_emb = self.input_norm(fused_embs)
|
||||
rep_emb = self.linear(norm_emb)
|
||||
return rep_emb
|
||||
|
||||
|
||||
def scaled_dot_product(q, k, v, extension):
|
||||
d = q.size(-1)
|
||||
scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d)
|
||||
if extension is not None and "mask" in extension:
|
||||
mask = extension["mask"]
|
||||
scores = scores.masked_fill(mask == 0, float("-inf"))
|
||||
attn = torch.softmax(scores, dim=-1)
|
||||
out = torch.matmul(attn, v)
|
||||
return out
|
||||
|
||||
|
||||
class Expert(nn.Module):
|
||||
def __init__(self, d_model, dim_ff):
|
||||
super().__init__()
|
||||
self.fc1 = nn.Linear(d_model, dim_ff)
|
||||
self.fc2 = nn.Linear(dim_ff, d_model)
|
||||
|
||||
def forward(self, x):
|
||||
return self.fc2(F.relu(self.fc1(x)))
|
||||
|
||||
|
||||
class TopKGate(nn.Module):
|
||||
def __init__(self, d_model, num_experts, k=2, noisy_gating=True):
|
||||
super().__init__()
|
||||
self.w_g = nn.Linear(d_model, num_experts)
|
||||
self.num_experts = num_experts
|
||||
self.k = k
|
||||
self.noisy_gating = noisy_gating
|
||||
|
||||
def forward(self, x):
|
||||
# x: [B,S,D]
|
||||
logits = self.w_g(x) # [B,S,E]
|
||||
|
||||
if self.noisy_gating and self.training:
|
||||
logits = logits + torch.randn_like(logits) * 0.1
|
||||
|
||||
probs = torch.softmax(logits, dim=-1) # [B,S,E]
|
||||
|
||||
topk_score, topk_idx = torch.topk(probs, self.k, dim=-1) # [B,S,k]
|
||||
|
||||
return topk_idx, topk_score, probs
|
||||
|
||||
class SMoE(nn.Module):
|
||||
def __init__(self, d_model, dim_ff, num_experts, k=2):
|
||||
super().__init__()
|
||||
self.num_experts = num_experts
|
||||
self.k = k
|
||||
|
||||
self.experts = nn.ModuleList([
|
||||
Expert(d_model, dim_ff) for _ in range(num_experts)
|
||||
])
|
||||
|
||||
self.gate = TopKGate(d_model, num_experts, k=k)
|
||||
|
||||
def forward(self, x):
|
||||
# x: [B,S,D]
|
||||
B, S, D = x.shape
|
||||
|
||||
topk_idx, topk_score, probs = self.gate(x)
|
||||
|
||||
out = torch.zeros_like(x)
|
||||
|
||||
# flatten
|
||||
x_flat = x.reshape(-1, D) # [B*S, D]
|
||||
idx_flat = topk_idx.reshape(-1, self.k) # [B*S, k]
|
||||
score_flat = topk_score.reshape(-1, self.k)
|
||||
|
||||
for i in range(self.num_experts):
|
||||
# 找到被路由到 expert i 的 token
|
||||
mask = (idx_flat == i) # [B*S, k]
|
||||
|
||||
if not mask.any():
|
||||
continue
|
||||
|
||||
# 哪些 token 命中了 expert i
|
||||
token_idx, k_idx = mask.nonzero(as_tuple=True)
|
||||
|
||||
selected_x = x_flat[token_idx] # [N, D]
|
||||
|
||||
expert_out = self.experts[i](selected_x) # [N, D]
|
||||
|
||||
weight = score_flat[token_idx, k_idx].unsqueeze(-1)
|
||||
|
||||
out_flat = out.reshape(-1, D)
|
||||
out_flat[token_idx] += expert_out * weight
|
||||
|
||||
importance = probs.sum(dim=(0,1)) # [E]
|
||||
moe_loss = (importance.std() / (importance.mean() + 1e-6))
|
||||
|
||||
return out, moe_loss
|
||||
|
||||
|
||||
class TransformerEncoder(nn.Module):
|
||||
def __init__(self, d_model, n_heads, num_layers, dim_ff, act="relu",
|
||||
attention_fn=scaled_dot_product):
|
||||
super().__init__()
|
||||
self.d_model = d_model
|
||||
self.n_heads = n_heads
|
||||
self.head_dim = d_model // n_heads
|
||||
self.num_layers = num_layers
|
||||
assert d_model % n_heads == 0
|
||||
|
||||
self.qkv_proj = nn.ModuleList([nn.Linear(d_model, 3 * d_model) for _ in range(num_layers)])
|
||||
self.out_proj = nn.ModuleList([nn.Linear(d_model, d_model) for _ in range(num_layers)])
|
||||
self.ffn1 = nn.ModuleList([nn.Linear(d_model, dim_ff) for _ in range(num_layers)])
|
||||
self.ffn2 = nn.ModuleList([nn.Linear(dim_ff, d_model) for _ in range(num_layers)])
|
||||
self.norm1 = nn.ModuleList([nn.LayerNorm(d_model) for _ in range(num_layers)])
|
||||
self.norm2 = nn.ModuleList([nn.LayerNorm(d_model) for _ in range(num_layers)])
|
||||
self.act = getattr(F, act)
|
||||
self.attention_fn = attention_fn
|
||||
self.moe = nn.ModuleList([
|
||||
SMoE(d_model, dim_ff, num_experts=8, k=2)
|
||||
for _ in range(num_layers)
|
||||
])
|
||||
|
||||
def forward(self, x, extension):
|
||||
x = x.unsqueeze(0)
|
||||
B, S, D = x.shape
|
||||
|
||||
moe_loss_total = 0.0
|
||||
for i in range(self.num_layers):
|
||||
residual = x
|
||||
x = self.norm1[i](x)
|
||||
qkv = self.qkv_proj[i](x)
|
||||
qkv = qkv.view(B, S, self.n_heads, 3 * self.head_dim)
|
||||
qkv = qkv.permute(0, 2, 1, 3)
|
||||
q, k, v = torch.split(qkv, self.head_dim, dim=-1)
|
||||
attn_out = self.attention_fn(q, k, v, extension)
|
||||
attn_out = attn_out.permute(0, 2, 1, 3).reshape(B, S, D)
|
||||
x = residual + self.out_proj[i](attn_out)
|
||||
residual = x
|
||||
x = self.norm2[i](x)
|
||||
|
||||
moe_out, moe_loss = self.moe[i](x)
|
||||
|
||||
x = residual + moe_out
|
||||
|
||||
moe_loss_total = moe_loss_total + moe_loss
|
||||
|
||||
return x, moe_loss_total
|
||||
|
||||
|
||||
class CTRModel(nn.Module):
|
||||
def __init__(self, rep_encoder, seq_encoder, d_model):
|
||||
super().__init__()
|
||||
self.rep_encoder = rep_encoder
|
||||
self.seq_encoder = seq_encoder
|
||||
self.d_model = d_model
|
||||
self.linear = nn.Linear(d_model, 1)
|
||||
|
||||
def get_sequence_causal_mask(self, seq_info):
|
||||
lengths = seq_info[1:] - seq_info[:-1]
|
||||
lengths = lengths.view(-1)
|
||||
indices = torch.cumsum(torch.ones_like(lengths), dim=0) - 1
|
||||
result = torch.repeat_interleave(indices, lengths)
|
||||
a = result.view(1, -1) - result.view(-1, 1)
|
||||
out_mask = torch.tril((a == 0).to(torch.int32)).bool()
|
||||
return out_mask
|
||||
|
||||
def forward(self, batch):
|
||||
seq_input = self.rep_encoder(batch)
|
||||
seq_mask = self.get_sequence_causal_mask(batch["user_offsets"])
|
||||
encoder_output, moe_loss = self.seq_encoder(
|
||||
x=seq_input,
|
||||
extension={"mask": seq_mask.unsqueeze(0).unsqueeze(0)},
|
||||
)
|
||||
encoder_output_dim = encoder_output.shape[-1]
|
||||
encoder_output = encoder_output.reshape(1, -1, encoder_output_dim).squeeze(0)
|
||||
pred = self.linear(encoder_output)
|
||||
pred_logits = torch.clamp(pred, min=-15.0, max=15.0)
|
||||
return pred_logits, moe_loss
|
||||
|
||||
|
||||
# ============================================================
|
||||
# 模型加载入口
|
||||
# ============================================================
|
||||
|
||||
def load_model(device='cuda:0', ckpt_path=None):
|
||||
"""加载模型并返回,供 evaluation.py 调用。
|
||||
|
||||
Args:
|
||||
device: 推理设备(默认 'cuda:0')
|
||||
ckpt_path: checkpoint 文件路径,默认使用 infer.py 同目录下的 ckpt.pt
|
||||
|
||||
Returns:
|
||||
(model, device) 元组
|
||||
"""
|
||||
emb_dim = 512
|
||||
slot_num = 28
|
||||
vocab_size = 5000000
|
||||
d_model = 512
|
||||
n_heads = 8
|
||||
num_layers = 8
|
||||
dim_ff = 1024
|
||||
|
||||
rep_encoder = RepEncoder(
|
||||
vocab_size=vocab_size,
|
||||
emb_dim=emb_dim,
|
||||
padding_idx=0,
|
||||
slot_num=slot_num,
|
||||
d_model=d_model,
|
||||
)
|
||||
seq_encoder = TransformerEncoder(
|
||||
d_model=d_model,
|
||||
n_heads=n_heads,
|
||||
num_layers=num_layers,
|
||||
dim_ff=dim_ff,
|
||||
act="relu",
|
||||
)
|
||||
model = CTRModel(rep_encoder, seq_encoder, d_model=d_model)
|
||||
|
||||
dev = torch.device(device if torch.cuda.is_available() else "cpu")
|
||||
|
||||
# 加载 checkpoint
|
||||
# 若需要加载自定义修改的权重,请修改 479-488行逻辑,强制使用你文件夹中的权重
|
||||
# 测评系统默认使用原始官方权重
|
||||
if ckpt_path is None:
|
||||
ckpt_path = Path(__file__).parent / 'ckpt.pt'
|
||||
else:
|
||||
ckpt_path = Path(ckpt_path)
|
||||
if ckpt_path.exists():
|
||||
ckpt = torch.load(ckpt_path, map_location='cpu', weights_only=False)
|
||||
model.load_state_dict(ckpt['model_state_dict'])
|
||||
print(f"[INFO] Loaded checkpoint from {ckpt_path} (epoch={ckpt.get('epoch', '?')})")
|
||||
else:
|
||||
print(f"[WARNING] Checkpoint {ckpt_path} not found, using random weights")
|
||||
|
||||
model.to(dev)
|
||||
model.eval()
|
||||
print(f"[INFO] Model ready. Device: {dev}")
|
||||
|
||||
return model, dev
|
||||
|
||||
|
||||
# ============================================================
|
||||
# 打分工具(与 evaluation.py 保持一致)
|
||||
# ============================================================
|
||||
|
||||
def _read_predict(file_path):
|
||||
predictions = []
|
||||
with open(file_path, 'r') as f:
|
||||
for line in f:
|
||||
line = line.strip()
|
||||
if line:
|
||||
predictions.append(float(line))
|
||||
import numpy as np
|
||||
return np.array(predictions)
|
||||
|
||||
|
||||
def _read_label(file_path):
|
||||
labels = []
|
||||
with open(file_path, 'r') as f:
|
||||
for line in f:
|
||||
line = line.strip()
|
||||
if line:
|
||||
parts = line.split(',')
|
||||
if len(parts) >= 4:
|
||||
labels.append(float(parts[3]))
|
||||
else:
|
||||
labels.append(float(line))
|
||||
import numpy as np
|
||||
return np.array(labels)
|
||||
|
||||
|
||||
def _cal_score(predict_file, label_file, default_latency=0.0):
|
||||
import numpy as np
|
||||
from sklearn.metrics import roc_auc_score
|
||||
|
||||
predictions = _read_predict(predict_file)
|
||||
labels = _read_label(label_file)
|
||||
|
||||
unique_labels = np.unique(labels)
|
||||
if len(unique_labels) < 2:
|
||||
print('[WARNING] only one class present in labels, AUC is not defined, returning 0.5')
|
||||
auc = 0.5
|
||||
else:
|
||||
auc = roc_auc_score(labels, predictions)
|
||||
|
||||
mean_pred = np.mean(predictions)
|
||||
mean_label = np.mean(labels)
|
||||
if mean_label == 0:
|
||||
pcoc = 1.0 if mean_pred == 0 else float('inf')
|
||||
else:
|
||||
pcoc = float(mean_pred / mean_label)
|
||||
|
||||
latency = default_latency
|
||||
base_latency = 300
|
||||
score_latency = max(0.0, (base_latency - latency) / base_latency) if latency < base_latency else 0.0
|
||||
|
||||
if pcoc < 0.85 or pcoc > 1.15:
|
||||
score_model = 0.0
|
||||
else:
|
||||
score_model = ((auc - 0.65) * 1000 + (0.15 - abs(pcoc - 1)) / 0.15 * 10) / 360
|
||||
|
||||
score_all = score_latency * 70 + score_model * 30
|
||||
|
||||
return {
|
||||
'auc': auc,
|
||||
'pcoc': pcoc,
|
||||
'latency': latency,
|
||||
'score_latency': score_latency,
|
||||
'score_model': score_model,
|
||||
'score_all': score_all,
|
||||
}
|
||||
|
||||
|
||||
# ============================================================
|
||||
# main:直接运行 infer.py 进行测试
|
||||
# ============================================================
|
||||
|
||||
def main():
|
||||
import io
|
||||
import time
|
||||
import argparse
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--ckpt', type=str, default=None, help='checkpoint 文件路径,默认使用同目录下的 ckpt.pt')
|
||||
args = parser.parse_args()
|
||||
|
||||
cur_path = Path(__file__).parent.absolute()
|
||||
ref_dir = cur_path / 'dataset'
|
||||
history_dir = ref_dir / 'history'
|
||||
input_file = ref_dir / 'test.csv'
|
||||
output_file = Path('predict.txt')
|
||||
label_file = ref_dir / 'label_data.txt'
|
||||
|
||||
# ----- 数据加载,优先从缓存读取 -----
|
||||
MAX_SHARD_BYTES = 2 * 1024 * 1024 * 1024 # 2GB per shard
|
||||
batches_cache_dir = ref_dir / 'cached_batches'
|
||||
|
||||
if batches_cache_dir.exists() and any(batches_cache_dir.glob('shard_*.pt')):
|
||||
print(f'[INFO] loading cached batch shards from {batches_cache_dir}')
|
||||
all_batches = []
|
||||
shard_files = sorted(batches_cache_dir.glob('shard_*.pt'),
|
||||
key=lambda p: int(p.stem.split('_')[1]))
|
||||
for sf in shard_files:
|
||||
shard_batches = torch.load(sf, weights_only=False)
|
||||
all_batches.extend(shard_batches)
|
||||
print(f'[INFO] loaded {len(shard_batches)} batches from {sf.name}')
|
||||
print(f'[INFO] loaded {len(all_batches)} cached batches total from {len(shard_files)} shards')
|
||||
else:
|
||||
print('[INFO] start loading data from CSV')
|
||||
history_files = sorted(history_dir.glob('*.csv')) if history_dir.exists() else []
|
||||
all_files = history_files + [input_file]
|
||||
|
||||
item_dict, user_seq = load_sample_files(sample_files_list=all_files)
|
||||
test_pred_logids = load_logids_from_file(input_file)
|
||||
print(f'[INFO] Test pred logids count: {len(test_pred_logids)}')
|
||||
|
||||
max_feasign_per_slot = {1: 2}
|
||||
test_dataset = CTRUserDataset(
|
||||
item_dict, user_seq,
|
||||
max_feasign_per_slot=max_feasign_per_slot,
|
||||
pred_logids=test_pred_logids,
|
||||
)
|
||||
print(f'[INFO] num_users={test_dataset.num_users}, '
|
||||
f'total_samples={test_dataset.total_samples}, '
|
||||
f'pred_samples={len(test_pred_logids)}, '
|
||||
f'max_sign_id={test_dataset.max_sign_id}')
|
||||
|
||||
test_loader = DataLoader(
|
||||
test_dataset,
|
||||
batch_size=50,
|
||||
shuffle=False,
|
||||
num_workers=0,
|
||||
collate_fn=make_collate_fn(test_dataset.max_slot_id),
|
||||
)
|
||||
|
||||
# 收集 batches 并按分片缓存
|
||||
print('[INFO] collecting batches and saving sharded cache...')
|
||||
all_batches = [batch for batch in test_loader]
|
||||
|
||||
batches_cache_dir.mkdir(parents=True, exist_ok=True)
|
||||
shard_idx = 0
|
||||
current_shard = []
|
||||
current_size = 0
|
||||
for batch in all_batches:
|
||||
buf = io.BytesIO()
|
||||
torch.save(batch, buf)
|
||||
batch_size_bytes = buf.tell()
|
||||
if current_shard and current_size + batch_size_bytes > MAX_SHARD_BYTES:
|
||||
shard_path = batches_cache_dir / f'shard_{shard_idx:04d}.pt'
|
||||
torch.save(current_shard, shard_path)
|
||||
print(f'[INFO] saved shard {shard_path.name}: {len(current_shard)} batches, '
|
||||
f'~{current_size / 1024**3:.2f}GB')
|
||||
shard_idx += 1
|
||||
current_shard = []
|
||||
current_size = 0
|
||||
current_shard.append(batch)
|
||||
current_size += batch_size_bytes
|
||||
if current_shard:
|
||||
shard_path = batches_cache_dir / f'shard_{shard_idx:04d}.pt'
|
||||
torch.save(current_shard, shard_path)
|
||||
print(f'[INFO] saved shard {shard_path.name}: {len(current_shard)} batches, '
|
||||
f'~{current_size / 1024**3:.2f}GB')
|
||||
shard_idx += 1
|
||||
print(f'[INFO] saved {len(all_batches)} batches to {shard_idx} shards in {batches_cache_dir}')
|
||||
|
||||
print('[INFO] data loading done')
|
||||
|
||||
# ----- 加载模型 -----
|
||||
model, dev = load_model(ckpt_path=args.ckpt)
|
||||
|
||||
# ----- 推理 -----
|
||||
print('*' * 20 + ' start inference ' + '*' * 20)
|
||||
all_logids = []
|
||||
all_probs = []
|
||||
time_sum = 0.0
|
||||
|
||||
with torch.no_grad():
|
||||
for batch in tqdm(all_batches, desc="Inference"):
|
||||
batch = move_batch_to_device(batch, dev)
|
||||
pred_mask = batch["pred_mask"].bool()
|
||||
|
||||
t_start = time.time()
|
||||
logits, moe_loss = model(batch)
|
||||
logits = logits.squeeze(-1)
|
||||
probs = torch.sigmoid(logits)
|
||||
time_sum += time.time() - t_start
|
||||
|
||||
masked_logids = batch["logid"][pred_mask].cpu().tolist()
|
||||
masked_probs = probs[pred_mask].cpu().tolist()
|
||||
all_logids.extend(masked_logids)
|
||||
all_probs.extend(masked_probs)
|
||||
|
||||
print(f'[INFO] inference time: {round(time_sum, 4)}s')
|
||||
print('*' * 20 + ' end inference ' + '*' * 20)
|
||||
|
||||
# ----- 按 test.csv 顺序写预测文件 -----
|
||||
logid_to_prob = dict(zip(all_logids, all_probs))
|
||||
test_logids_in_order = []
|
||||
with open(input_file, 'r') as f:
|
||||
for line in f:
|
||||
line = line.strip()
|
||||
if line:
|
||||
test_logids_in_order.append(int(line.split(',')[0]))
|
||||
output_file.parent.mkdir(parents=True, exist_ok=True)
|
||||
with open(output_file, 'w') as f:
|
||||
for logid in test_logids_in_order:
|
||||
f.write(f"{logid_to_prob[logid]}\n")
|
||||
print(f'[INFO] predictions written to {output_file}, total: {len(test_logids_in_order)}')
|
||||
|
||||
# ----- 打分 -----
|
||||
if label_file.exists():
|
||||
result = _cal_score(output_file, label_file, default_latency=time_sum)
|
||||
print(f'[INFO] AUC: {result["auc"]:.6f}')
|
||||
print(f'[INFO] PCOC: {result["pcoc"]:.6f}')
|
||||
print(f'[INFO] Latency: {result["latency"]:.4f}s')
|
||||
print(f'[INFO] score_latency: {result["score_latency"]:.6f}')
|
||||
print(f'[INFO] score_model: {result["score_model"]:.6f}')
|
||||
print(f'[INFO] score_all: {result["score_all"]:.6f}')
|
||||
return result
|
||||
else:
|
||||
print(f'[WARNING] label file {label_file} not found, skipping scoring')
|
||||
return None
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
|
||||
@@ -0,0 +1,29 @@
|
||||
filelock==3.25.2
|
||||
fsspec==2026.2.0
|
||||
Jinja2==3.1.6
|
||||
joblib==1.5.3
|
||||
MarkupSafe==3.0.3
|
||||
mpmath==1.3.0
|
||||
networkx==3.4.2
|
||||
numpy==2.2.6
|
||||
nvidia-cublas-cu12==12.4.5.8
|
||||
nvidia-cuda-cupti-cu12==12.4.127
|
||||
nvidia-cuda-nvrtc-cu12==12.4.127
|
||||
nvidia-cuda-runtime-cu12==12.4.127
|
||||
nvidia-cudnn-cu12==9.1.0.70
|
||||
nvidia-cufft-cu12==11.2.1.3
|
||||
nvidia-curand-cu12==10.3.5.147
|
||||
nvidia-cusolver-cu12==11.6.1.9
|
||||
nvidia-cusparse-cu12==12.3.1.170
|
||||
nvidia-cusparselt-cu12==0.6.2
|
||||
nvidia-nccl-cu12==2.21.5
|
||||
nvidia-nvjitlink-cu12==12.4.127
|
||||
nvidia-nvtx-cu12==12.4.127
|
||||
scikit-learn==1.7.2
|
||||
scipy==1.15.3
|
||||
sympy==1.13.1
|
||||
threadpoolctl==3.6.0
|
||||
torch==2.6.0
|
||||
tqdm==4.67.3
|
||||
triton==3.2.0
|
||||
typing_extensions==4.15.0
|
||||
Reference in New Issue
Block a user