c081620ffd
- 仅路由到 Top-1 expert(节省 50% FFN 计算) - gate 输出 top-2 概率,用 p1+p2 作为输出权重 - 近似 k=2 的输出幅度,避免 PCOC 偏移 - 是参数调整修正,非方案本身错误
743 lines
26 KiB
Python
743 lines
26 KiB
Python
import sys
|
||
import os
|
||
|
||
# 获取当前环境脚本所在目录或指定绝对路径
|
||
if os.path.exists("../libraries"):
|
||
lib_path = os.path.abspath("../libraries")
|
||
sys.path.append(lib_path)
|
||
|
||
import math
|
||
import argparse
|
||
from pathlib import Path
|
||
from collections import defaultdict
|
||
import numpy as np
|
||
import torch
|
||
import torch.nn as nn
|
||
import torch.nn.functional as F
|
||
from torch.utils.data import Dataset, DataLoader
|
||
from tqdm import tqdm
|
||
|
||
|
||
# ============================================================
|
||
# 数据加载(来自 train/dataset.py)
|
||
# ============================================================
|
||
|
||
def _detect_has_clk(file_path):
|
||
"""检测 CSV 文件是否包含 clk 列(5列 vs 4列格式)。
|
||
5列格式: logid,userid,adid,clk,timestamp,sign:slot...
|
||
4列格式: logid,userid,adid,timestamp,sign:slot...
|
||
通过第5个字段是否包含 ':' 来判断:有 ':' 说明已经是 sign:slot,即无 clk 列。
|
||
"""
|
||
with open(file_path, 'r') as f:
|
||
for line in f:
|
||
line = line.strip()
|
||
if not line:
|
||
continue
|
||
parts = line.split(',')
|
||
if len(parts) >= 5:
|
||
return ':' not in parts[4]
|
||
return False
|
||
return False
|
||
|
||
|
||
def load_sample_files(sample_files_list):
|
||
"""加载 CSV sample 文件,返回 item_dict 和 user_seq。
|
||
自动检测每个文件是 5列(含clk)还是 4列(无clk)格式。
|
||
"""
|
||
sample_files = sorted([Path(f) for f in sample_files_list])
|
||
print(f'[INFO] loading {len(sample_files)} files: {[str(f) for f in sample_files]}')
|
||
|
||
item_dict = {}
|
||
user_logs = defaultdict(list)
|
||
|
||
for sample_file in tqdm(sample_files, desc='Loading sample files'):
|
||
has_clk = _detect_has_clk(sample_file)
|
||
min_parts = 5 if has_clk else 4
|
||
print(f' {sample_file.name}: has_clk={has_clk}')
|
||
|
||
with open(sample_file, 'r') as f:
|
||
for line in f:
|
||
line = line.strip()
|
||
if not line:
|
||
continue
|
||
parts = line.split(',')
|
||
if len(parts) < min_parts:
|
||
continue
|
||
|
||
logid = int(parts[0])
|
||
userid = int(parts[1])
|
||
adid = int(parts[2])
|
||
|
||
if has_clk:
|
||
clk = int(parts[3])
|
||
timestamp = int(parts[4])
|
||
feat_start = 5
|
||
else:
|
||
clk = 0
|
||
timestamp = int(parts[3])
|
||
feat_start = 4
|
||
|
||
signs = []
|
||
slots = []
|
||
for pair in parts[feat_start:]:
|
||
if ':' in pair:
|
||
s, sl = pair.split(':', 1)
|
||
signs.append(int(s))
|
||
slots.append(int(sl))
|
||
|
||
item_dict[logid] = {
|
||
'logid': logid,
|
||
'userid': userid,
|
||
'adid': adid,
|
||
'clk': clk,
|
||
'timestamp': timestamp,
|
||
'signs': np.array(signs, dtype=np.int64),
|
||
'slots': np.array(slots, dtype=np.int64),
|
||
}
|
||
user_logs[userid].append((timestamp, logid))
|
||
|
||
user_seq = {}
|
||
for userid, logs in user_logs.items():
|
||
logs.sort(key=lambda x: x[0])
|
||
user_seq[userid] = [logid for _, logid in logs]
|
||
|
||
print(f'[INFO] loaded {len(item_dict)} records, {len(user_seq)} users')
|
||
return item_dict, user_seq
|
||
|
||
|
||
def load_logids_from_file(file_path):
|
||
"""快速读取一个 sample 文件中的所有 logid"""
|
||
logids = set()
|
||
with open(file_path, 'r') as f:
|
||
for line in f:
|
||
line = line.strip()
|
||
if not line:
|
||
continue
|
||
comma = line.index(',')
|
||
logids.add(int(line[:comma]))
|
||
return logids
|
||
|
||
|
||
class CTRTestSeqDataset(Dataset):
|
||
"""按用户组织的 CTR 测试数据集(对齐评测接口)"""
|
||
|
||
def __init__(self, test_logids_ordered, item_dict, user_seq=None,
|
||
max_feasign_per_slot=None, max_ctx_len=None):
|
||
super().__init__()
|
||
self.item_dict = item_dict
|
||
self.user_seq = user_seq if user_seq else {}
|
||
self.max_feasign_per_slot = max_feasign_per_slot
|
||
self.max_ctx_len = max_ctx_len
|
||
self.pred_logids = set(test_logids_ordered) if test_logids_ordered else set()
|
||
|
||
self.user_items = defaultdict(list)
|
||
for logid, rec in item_dict.items():
|
||
userid = rec['userid']
|
||
feasign = defaultdict(list)
|
||
for slot, sign in zip(rec['slots'].tolist(), rec['signs'].tolist()):
|
||
feasign[slot].append(sign)
|
||
if max_feasign_per_slot is not None:
|
||
feasign = {slot: signs[:max_feasign_per_slot[slot]]
|
||
if max_feasign_per_slot.get(slot, -1) != -1 else signs
|
||
for slot, signs in feasign.items()}
|
||
feasign = dict(feasign)
|
||
label = rec['clk']
|
||
self.user_items[userid].append((logid, feasign, label))
|
||
|
||
self.user_ids = sorted(self.user_items.keys())
|
||
self.num_users = len(self.user_ids)
|
||
self.total_samples = len(item_dict)
|
||
|
||
all_signs = set()
|
||
for rec in item_dict.values():
|
||
all_signs.update(rec['signs'].tolist())
|
||
self.max_slot_id = 28
|
||
self.max_sign_id = max(all_signs) if all_signs else 0
|
||
|
||
def __len__(self):
|
||
return self.num_users
|
||
|
||
def __getitem__(self, index):
|
||
userid = self.user_ids[index]
|
||
items = self.user_items[userid]
|
||
|
||
if self.user_seq and userid in self.user_seq:
|
||
seq_order = {logid: i for i, logid in enumerate(self.user_seq[userid])}
|
||
items.sort(key=lambda x: seq_order.get(x[0], x[0]))
|
||
else:
|
||
items.sort(key=lambda x: x[0])
|
||
|
||
feasigns = []
|
||
labels = []
|
||
logids = []
|
||
for logid, feasign, label in items:
|
||
logids.append(logid)
|
||
feasigns.append(feasign)
|
||
labels.append(label)
|
||
|
||
return {
|
||
'userid': userid,
|
||
'logids': logids,
|
||
'feasigns': feasigns,
|
||
'labels': labels,
|
||
'pred_mask': [1 if logid in self.pred_logids else 0 for logid in logids],
|
||
}
|
||
|
||
|
||
def make_collate_fn(max_slot_id):
|
||
def collate_user_batch(batch):
|
||
all_userids = []
|
||
all_logids = []
|
||
all_labels = []
|
||
all_pred_masks = []
|
||
all_feasigns = []
|
||
user_offsets = [0]
|
||
|
||
for item in batch:
|
||
for i, logid in enumerate(item['logids']):
|
||
all_userids.append(item['userid'])
|
||
all_logids.append(logid)
|
||
all_labels.append(item['labels'][i])
|
||
all_pred_masks.append(item['pred_mask'][i])
|
||
all_feasigns.append(item['feasigns'][i])
|
||
user_offsets.append(len(all_labels))
|
||
|
||
slot_data = {}
|
||
for slot in range(1, max_slot_id + 1):
|
||
values = []
|
||
offsets = [0]
|
||
for feasign in all_feasigns:
|
||
if slot in feasign:
|
||
values.extend(feasign[slot])
|
||
offsets.append(len(values))
|
||
slot_data[slot] = (
|
||
torch.tensor(values, dtype=torch.long),
|
||
torch.tensor(offsets, dtype=torch.long),
|
||
)
|
||
|
||
result = {
|
||
'userid': torch.tensor(all_userids, dtype=torch.long),
|
||
'logid': torch.tensor(all_logids, dtype=torch.long),
|
||
'label': torch.tensor(all_labels, dtype=torch.float32),
|
||
'pred_mask': torch.tensor(all_pred_masks, dtype=torch.bool),
|
||
'user_offsets': torch.tensor(user_offsets, dtype=torch.long),
|
||
}
|
||
result.update(slot_data)
|
||
return result
|
||
|
||
return collate_user_batch
|
||
|
||
|
||
# ============================================================
|
||
# 模型定义(来自 main.py)
|
||
# ============================================================
|
||
|
||
def move_batch_to_device(batch, device):
|
||
if isinstance(batch, dict):
|
||
return {k: move_batch_to_device(v, device) for k, v in batch.items()}
|
||
elif isinstance(batch, (list, tuple)):
|
||
return [move_batch_to_device(x, device) for x in batch]
|
||
elif torch.is_tensor(batch):
|
||
x = batch.to(device)
|
||
# 浮点 tensor → FP16,整数 tensor 保持不变
|
||
if x.dtype == torch.float32:
|
||
x = x.half()
|
||
return x
|
||
else:
|
||
return batch
|
||
|
||
|
||
class RepEncoder(nn.Module):
|
||
def __init__(self, vocab_size, emb_dim, padding_idx=0, slot_num=0, d_model=0):
|
||
super().__init__()
|
||
self.emb = nn.Embedding(num_embeddings=vocab_size, embedding_dim=emb_dim, padding_idx=padding_idx)
|
||
self.emb_dim = emb_dim
|
||
self.slot_num = slot_num
|
||
self.input_norm = nn.LayerNorm(slot_num * emb_dim)
|
||
self.linear = nn.Linear(in_features=slot_num * emb_dim, out_features=d_model)
|
||
|
||
def forward(self, batch):
|
||
pooled_embs = []
|
||
max_idx = self.emb.num_embeddings - 1
|
||
target_dtype = self.input_norm.weight.dtype # 后续层 dtype(FP16 时为 torch.float16)
|
||
for i in range(self.slot_num):
|
||
values, offsets = batch[i + 1]
|
||
offsets = offsets.to(values.device)
|
||
values = values.clamp(0, max_idx) # 超出 vocab_size 的 sign id 截断,避免越界
|
||
sign_emb = self.emb(values).to(target_dtype)
|
||
res = torch.segment_reduce(sign_emb, reduce='sum', offsets=offsets, initial=0)
|
||
pooled_embs.append(res)
|
||
fused_embs = torch.cat(pooled_embs, dim=1)
|
||
norm_emb = self.input_norm(fused_embs)
|
||
rep_emb = self.linear(norm_emb)
|
||
return rep_emb
|
||
|
||
|
||
def scaled_dot_product(q, k, v, extension):
|
||
"""使用 PyTorch SDPA 后端(自动启用 Flash Attention / Memory Efficient Attention)"""
|
||
if extension is not None and "mask" in extension:
|
||
attn_mask = extension["mask"].to(device=q.device)
|
||
else:
|
||
attn_mask = None
|
||
|
||
return F.scaled_dot_product_attention(
|
||
q, k, v,
|
||
attn_mask=attn_mask,
|
||
dropout_p=0.0,
|
||
is_causal=False,
|
||
)
|
||
|
||
|
||
class Expert(nn.Module):
|
||
def __init__(self, d_model, dim_ff):
|
||
super().__init__()
|
||
self.fc1 = nn.Linear(d_model, dim_ff)
|
||
self.fc2 = nn.Linear(dim_ff, d_model)
|
||
|
||
def forward(self, x):
|
||
return self.fc2(F.relu(self.fc1(x)))
|
||
|
||
|
||
class TopKGate(nn.Module):
|
||
def __init__(self, d_model, num_experts, k=2, noisy_gating=True):
|
||
super().__init__()
|
||
self.w_g = nn.Linear(d_model, num_experts)
|
||
self.num_experts = num_experts
|
||
self.k = k
|
||
self.noisy_gating = noisy_gating
|
||
|
||
def forward(self, x):
|
||
# x: [B,S,D]
|
||
logits = self.w_g(x) # [B,S,E]
|
||
|
||
if self.noisy_gating and self.training:
|
||
logits = logits + torch.randn_like(logits) * 0.1
|
||
|
||
probs = torch.softmax(logits, dim=-1) # [B,S,E]
|
||
|
||
topk_score, topk_idx = torch.topk(probs, self.k, dim=-1) # [B,S,k]
|
||
|
||
return topk_idx, topk_score, probs
|
||
|
||
class SMoE(nn.Module):
|
||
def __init__(self, d_model, dim_ff, num_experts, k=1):
|
||
super().__init__()
|
||
self.num_experts = num_experts
|
||
self.k = k
|
||
|
||
self.experts = nn.ModuleList([
|
||
Expert(d_model, dim_ff) for _ in range(num_experts)
|
||
])
|
||
|
||
self.gate = TopKGate(d_model, num_experts, k=2) # gate 内部用 k=2 获取补偿权重
|
||
|
||
def forward(self, x):
|
||
# x: [B,S,D]
|
||
B, S, D = x.shape
|
||
|
||
topk_idx, topk_score, probs = self.gate(x)
|
||
# topk_idx: [B, S, 2], topk_score: [B, S, 2]
|
||
|
||
# 仅路由到 Top-1 expert,但用 (p1+p2) 作为权重补偿
|
||
route_idx = topk_idx[:, :, :1] # [B, S, 1] — 只取 top-1
|
||
weight_sum = topk_score.sum(dim=-1) # [B, S] — p1 + p2 作为总权重
|
||
|
||
out = torch.zeros_like(x)
|
||
|
||
x_flat = x.reshape(-1, D)
|
||
idx_flat = route_idx.reshape(-1) # [B*S]
|
||
weight_flat = weight_sum.reshape(-1) # [B*S]
|
||
out_flat = out.reshape(-1, D)
|
||
|
||
for i in range(self.num_experts):
|
||
mask = (idx_flat == i) # [B*S]
|
||
token_idx = mask.nonzero(as_tuple=True)[0]
|
||
if token_idx.numel() == 0:
|
||
continue
|
||
selected_x = x_flat[token_idx]
|
||
expert_out = self.experts[i](selected_x)
|
||
out_flat[token_idx] = expert_out * weight_flat[token_idx].unsqueeze(-1)
|
||
|
||
importance = probs.sum(dim=(0,1)) # [E]
|
||
moe_loss = (importance.std() / (importance.mean() + 1e-6))
|
||
|
||
return out, moe_loss
|
||
|
||
|
||
class TransformerEncoder(nn.Module):
|
||
def __init__(self, d_model, n_heads, num_layers, dim_ff, act="relu",
|
||
attention_fn=scaled_dot_product):
|
||
super().__init__()
|
||
self.d_model = d_model
|
||
self.n_heads = n_heads
|
||
self.head_dim = d_model // n_heads
|
||
self.num_layers = num_layers
|
||
assert d_model % n_heads == 0
|
||
|
||
self.qkv_proj = nn.ModuleList([nn.Linear(d_model, 3 * d_model) for _ in range(num_layers)])
|
||
self.out_proj = nn.ModuleList([nn.Linear(d_model, d_model) for _ in range(num_layers)])
|
||
self.ffn1 = nn.ModuleList([nn.Linear(d_model, dim_ff) for _ in range(num_layers)])
|
||
self.ffn2 = nn.ModuleList([nn.Linear(dim_ff, d_model) for _ in range(num_layers)])
|
||
self.norm1 = nn.ModuleList([nn.LayerNorm(d_model) for _ in range(num_layers)])
|
||
self.norm2 = nn.ModuleList([nn.LayerNorm(d_model) for _ in range(num_layers)])
|
||
self.act = getattr(F, act)
|
||
self.attention_fn = attention_fn
|
||
self.moe = nn.ModuleList([
|
||
SMoE(d_model, dim_ff, num_experts=8, k=1) # Top-1 路由 + (p1+p2) 权重补偿
|
||
for _ in range(num_layers)
|
||
])
|
||
|
||
def forward(self, x, extension):
|
||
x = x.unsqueeze(0)
|
||
B, S, D = x.shape
|
||
|
||
moe_loss_total = 0.0
|
||
for i in range(self.num_layers):
|
||
residual = x
|
||
x = self.norm1[i](x)
|
||
qkv = self.qkv_proj[i](x)
|
||
qkv = qkv.view(B, S, self.n_heads, 3 * self.head_dim)
|
||
qkv = qkv.permute(0, 2, 1, 3)
|
||
q, k, v = torch.split(qkv, self.head_dim, dim=-1)
|
||
attn_out = self.attention_fn(q, k, v, extension)
|
||
attn_out = attn_out.permute(0, 2, 1, 3).reshape(B, S, D)
|
||
x = residual + self.out_proj[i](attn_out)
|
||
residual = x
|
||
x = self.norm2[i](x)
|
||
|
||
moe_out, moe_loss = self.moe[i](x)
|
||
|
||
x = residual + moe_out
|
||
|
||
moe_loss_total = moe_loss_total + moe_loss
|
||
|
||
return x, moe_loss_total
|
||
|
||
|
||
class CTRModel(nn.Module):
|
||
def __init__(self, rep_encoder, seq_encoder, d_model):
|
||
super().__init__()
|
||
self.rep_encoder = rep_encoder
|
||
self.seq_encoder = seq_encoder
|
||
self.d_model = d_model
|
||
self.linear = nn.Linear(d_model, 1)
|
||
|
||
def get_sequence_causal_mask(self, seq_info):
|
||
lengths = seq_info[1:] - seq_info[:-1]
|
||
lengths = lengths.view(-1)
|
||
indices = torch.cumsum(torch.ones_like(lengths), dim=0) - 1
|
||
result = torch.repeat_interleave(indices, lengths)
|
||
a = result.view(1, -1) - result.view(-1, 1)
|
||
out_mask = torch.tril((a == 0).to(torch.int32)).bool()
|
||
return out_mask
|
||
|
||
def forward(self, batch):
|
||
seq_input = self.rep_encoder(batch)
|
||
seq_mask = self.get_sequence_causal_mask(batch["user_offsets"])
|
||
encoder_output, moe_loss = self.seq_encoder(
|
||
x=seq_input,
|
||
extension={"mask": seq_mask.unsqueeze(0).unsqueeze(0)},
|
||
)
|
||
encoder_output = encoder_output.squeeze(0)
|
||
pred = self.linear(encoder_output)
|
||
pred_logits = torch.clamp(pred, min=-15.0, max=15.0)
|
||
return pred_logits, moe_loss
|
||
|
||
|
||
# ============================================================
|
||
# 模型加载入口
|
||
# ============================================================
|
||
|
||
def load_model(ckpt_path, device='cuda:0'):
|
||
"""加载模型并返回,供 evaluation.py 调用。
|
||
|
||
Args:
|
||
ckpt_path: checkpoint 文件路径(评测系统传入 Path 对象)
|
||
device: 推理设备(默认 'cuda:0')
|
||
|
||
Returns:
|
||
(model, device) 元组
|
||
"""
|
||
emb_dim = 512
|
||
slot_num = 28
|
||
vocab_size = 5000000
|
||
d_model = 512
|
||
n_heads = 8
|
||
num_layers = 8
|
||
dim_ff = 1024
|
||
|
||
rep_encoder = RepEncoder(
|
||
vocab_size=vocab_size,
|
||
emb_dim=emb_dim,
|
||
padding_idx=0,
|
||
slot_num=slot_num,
|
||
d_model=d_model,
|
||
)
|
||
seq_encoder = TransformerEncoder(
|
||
d_model=d_model,
|
||
n_heads=n_heads,
|
||
num_layers=num_layers,
|
||
dim_ff=dim_ff,
|
||
act="relu",
|
||
)
|
||
model = CTRModel(rep_encoder, seq_encoder, d_model=d_model)
|
||
|
||
dev = torch.device(device if torch.cuda.is_available() else "cpu")
|
||
|
||
# 加载 checkpoint
|
||
# 若需要加载自定义修改的权重,请修改 479-488行逻辑,强制使用你文件夹中的权重
|
||
# 测评系统默认使用原始官方权重
|
||
if ckpt_path is None:
|
||
ckpt_path = Path(__file__).parent / 'ckpt.pt'
|
||
else:
|
||
ckpt_path = Path(ckpt_path)
|
||
if ckpt_path.exists():
|
||
ckpt = torch.load(ckpt_path, map_location='cpu', weights_only=False)
|
||
model.load_state_dict(ckpt['model_state_dict'])
|
||
print(f"[INFO] Loaded checkpoint from {ckpt_path} (epoch={ckpt.get('epoch', '?')})")
|
||
|
||
# === FP16 量化:模型参数转半精度,Embedding 保留 FP32 ===
|
||
model = model.half()
|
||
model.rep_encoder.emb = model.rep_encoder.emb.to(torch.float32)
|
||
print("[INFO] Model converted to FP16 (embedding kept in FP32)")
|
||
else:
|
||
print(f"[WARNING] Checkpoint {ckpt_path} not found, using random weights")
|
||
|
||
model.to(dev)
|
||
model.eval()
|
||
print(f"[INFO] Model ready. Device: {dev}")
|
||
|
||
return model, dev
|
||
|
||
|
||
# ============================================================
|
||
# 打分工具(与 evaluation.py 保持一致)
|
||
# ============================================================
|
||
|
||
def _read_predict(file_path):
|
||
predictions = []
|
||
with open(file_path, 'r') as f:
|
||
for line in f:
|
||
line = line.strip()
|
||
if line:
|
||
predictions.append(float(line))
|
||
import numpy as np
|
||
return np.array(predictions)
|
||
|
||
|
||
def _read_label(file_path):
|
||
labels = []
|
||
with open(file_path, 'r') as f:
|
||
for line in f:
|
||
line = line.strip()
|
||
if line:
|
||
parts = line.split(',')
|
||
if len(parts) >= 4:
|
||
labels.append(float(parts[3]))
|
||
else:
|
||
labels.append(float(line))
|
||
import numpy as np
|
||
return np.array(labels)
|
||
|
||
|
||
def _cal_score(predict_file, label_file, default_latency=0.0):
|
||
import numpy as np
|
||
from sklearn.metrics import roc_auc_score
|
||
|
||
predictions = _read_predict(predict_file)
|
||
labels = _read_label(label_file)
|
||
|
||
unique_labels = np.unique(labels)
|
||
if len(unique_labels) < 2:
|
||
print('[WARNING] only one class present in labels, AUC is not defined, returning 0.5')
|
||
auc = 0.5
|
||
else:
|
||
auc = roc_auc_score(labels, predictions)
|
||
|
||
mean_pred = np.mean(predictions)
|
||
mean_label = np.mean(labels)
|
||
if mean_label == 0:
|
||
pcoc = 1.0 if mean_pred == 0 else float('inf')
|
||
else:
|
||
pcoc = float(mean_pred / mean_label)
|
||
|
||
latency = default_latency
|
||
base_latency = 300
|
||
score_latency = max(0.0, (base_latency - latency) / base_latency) if latency < base_latency else 0.0
|
||
|
||
if pcoc < 0.85 or pcoc > 1.15:
|
||
score_model = 0.0
|
||
else:
|
||
score_model = ((auc - 0.65) * 1000 + (0.15 - abs(pcoc - 1)) / 0.15 * 10) / 360
|
||
|
||
score_all = score_latency * 70 + score_model * 30
|
||
|
||
return {
|
||
'auc': auc,
|
||
'pcoc': pcoc,
|
||
'latency': latency,
|
||
'score_latency': score_latency,
|
||
'score_model': score_model,
|
||
'score_all': score_all,
|
||
}
|
||
|
||
|
||
# ============================================================
|
||
# main:直接运行 infer.py 进行测试
|
||
# ============================================================
|
||
|
||
def main():
|
||
import io
|
||
import time
|
||
import argparse
|
||
|
||
parser = argparse.ArgumentParser()
|
||
parser.add_argument('--ckpt', type=str, default=None, help='checkpoint 文件路径,默认使用同目录下的 ckpt.pt')
|
||
args = parser.parse_args()
|
||
|
||
cur_path = Path(__file__).parent.absolute()
|
||
ref_dir = cur_path / 'dataset'
|
||
history_dir = ref_dir / 'history'
|
||
input_file = ref_dir / 'test.csv'
|
||
output_file = Path('predict.txt')
|
||
label_file = ref_dir / 'label_data.txt'
|
||
|
||
# ----- 数据加载,优先从缓存读取 -----
|
||
MAX_SHARD_BYTES = 2 * 1024 * 1024 * 1024 # 2GB per shard
|
||
batches_cache_dir = ref_dir / 'cached_batches'
|
||
|
||
if batches_cache_dir.exists() and any(batches_cache_dir.glob('shard_*.pt')):
|
||
print(f'[INFO] loading cached batch shards from {batches_cache_dir}')
|
||
all_batches = []
|
||
shard_files = sorted(batches_cache_dir.glob('shard_*.pt'),
|
||
key=lambda p: int(p.stem.split('_')[1]))
|
||
for sf in shard_files:
|
||
shard_batches = torch.load(sf, weights_only=False)
|
||
all_batches.extend(shard_batches)
|
||
print(f'[INFO] loaded {len(shard_batches)} batches from {sf.name}')
|
||
print(f'[INFO] loaded {len(all_batches)} cached batches total from {len(shard_files)} shards')
|
||
else:
|
||
print('[INFO] start loading data from CSV')
|
||
history_files = sorted(history_dir.glob('*.csv')) if history_dir.exists() else []
|
||
all_files = history_files + [input_file]
|
||
|
||
item_dict, user_seq = load_sample_files(sample_files_list=all_files)
|
||
test_pred_logids = load_logids_from_file(input_file)
|
||
print(f'[INFO] Test pred logids count: {len(test_pred_logids)}')
|
||
|
||
max_feasign_per_slot = {1: 2}
|
||
test_dataset = CTRTestSeqDataset(
|
||
test_logids_ordered=list(test_pred_logids),
|
||
item_dict=item_dict,
|
||
user_seq=user_seq,
|
||
max_feasign_per_slot=max_feasign_per_slot,
|
||
)
|
||
print(f'[INFO] num_users={test_dataset.num_users}, '
|
||
f'total_samples={test_dataset.total_samples}, '
|
||
f'pred_samples={len(test_pred_logids)}, '
|
||
f'max_sign_id={test_dataset.max_sign_id}')
|
||
|
||
test_loader = DataLoader(
|
||
test_dataset,
|
||
batch_size=50,
|
||
shuffle=False,
|
||
num_workers=0,
|
||
collate_fn=make_collate_fn(test_dataset.max_slot_id),
|
||
)
|
||
|
||
# 收集 batches,预转 FP16 后按分片缓存
|
||
print('[INFO] collecting batches (pre-converting to FP16) and saving sharded cache...')
|
||
all_batches = []
|
||
for batch in test_loader:
|
||
batch = move_batch_to_device(batch, torch.device('cpu'))
|
||
all_batches.append(batch)
|
||
|
||
batches_cache_dir.mkdir(parents=True, exist_ok=True)
|
||
shard_idx = 0
|
||
current_shard = []
|
||
current_size = 0
|
||
for batch in all_batches:
|
||
buf = io.BytesIO()
|
||
torch.save(batch, buf)
|
||
batch_size_bytes = buf.tell()
|
||
if current_shard and current_size + batch_size_bytes > MAX_SHARD_BYTES:
|
||
shard_path = batches_cache_dir / f'shard_{shard_idx:04d}.pt'
|
||
torch.save(current_shard, shard_path)
|
||
print(f'[INFO] saved shard {shard_path.name}: {len(current_shard)} batches, '
|
||
f'~{current_size / 1024**3:.2f}GB')
|
||
shard_idx += 1
|
||
current_shard = []
|
||
current_size = 0
|
||
current_shard.append(batch)
|
||
current_size += batch_size_bytes
|
||
if current_shard:
|
||
shard_path = batches_cache_dir / f'shard_{shard_idx:04d}.pt'
|
||
torch.save(current_shard, shard_path)
|
||
print(f'[INFO] saved shard {shard_path.name}: {len(current_shard)} batches, '
|
||
f'~{current_size / 1024**3:.2f}GB')
|
||
shard_idx += 1
|
||
print(f'[INFO] saved {len(all_batches)} batches to {shard_idx} shards in {batches_cache_dir}')
|
||
|
||
print('[INFO] data loading done')
|
||
|
||
# ----- 加载模型 -----
|
||
model, dev = load_model(ckpt_path=args.ckpt)
|
||
|
||
# ----- 推理 -----
|
||
print('*' * 20 + ' start inference ' + '*' * 20)
|
||
all_logids = []
|
||
all_probs = []
|
||
time_sum = 0.0
|
||
|
||
with torch.inference_mode():
|
||
for batch in tqdm(all_batches, desc="Inference"):
|
||
batch = move_batch_to_device(batch, dev)
|
||
pred_mask = batch["pred_mask"].bool()
|
||
|
||
t_start = time.time()
|
||
logits, moe_loss = model(batch)
|
||
logits = logits.squeeze(-1)
|
||
probs = torch.sigmoid(logits)
|
||
time_sum += time.time() - t_start
|
||
|
||
masked_logids = batch["logid"][pred_mask].cpu().tolist()
|
||
masked_probs = probs[pred_mask].cpu().tolist()
|
||
all_logids.extend(masked_logids)
|
||
all_probs.extend(masked_probs)
|
||
|
||
print(f'[INFO] inference time: {round(time_sum, 4)}s')
|
||
print('*' * 20 + ' end inference ' + '*' * 20)
|
||
|
||
# ----- 按 test.csv 顺序写预测文件 -----
|
||
logid_to_prob = dict(zip(all_logids, all_probs))
|
||
test_logids_in_order = []
|
||
with open(input_file, 'r') as f:
|
||
for line in f:
|
||
line = line.strip()
|
||
if line:
|
||
test_logids_in_order.append(int(line.split(',')[0]))
|
||
output_file.parent.mkdir(parents=True, exist_ok=True)
|
||
with open(output_file, 'w') as f:
|
||
for logid in test_logids_in_order:
|
||
f.write(f"{logid_to_prob[logid]}\n")
|
||
print(f'[INFO] predictions written to {output_file}, total: {len(test_logids_in_order)}')
|
||
|
||
# ----- 打分 -----
|
||
if label_file.exists():
|
||
result = _cal_score(output_file, label_file, default_latency=time_sum)
|
||
print(f'[INFO] AUC: {result["auc"]:.6f}')
|
||
print(f'[INFO] PCOC: {result["pcoc"]:.6f}')
|
||
print(f'[INFO] Latency: {result["latency"]:.4f}s')
|
||
print(f'[INFO] score_latency: {result["score_latency"]:.6f}')
|
||
print(f'[INFO] score_model: {result["score_model"]:.6f}')
|
||
print(f'[INFO] score_all: {result["score_all"]:.6f}')
|
||
return result
|
||
else:
|
||
print(f'[WARNING] label file {label_file} not found, skipping scoring')
|
||
return None
|
||
|
||
|
||
if __name__ == '__main__':
|
||
main()
|
||
|