feat: 接口对齐 + FP16 量化(第一版优化方案)
- CTRUserDataset → CTRTestSeqDataset,构造参数对齐评测接口
- load_model 签名修正:ckpt_path 作为第一参数
- FP16 量化:model.half() + Embedding 保留 FP32
- move_batch_to_device 自动 FP32→FP16 转换
- 缓存时预转 FP16,减少推理循环开销
- requirements.txt 精简(去除 nvidia-* 包)
- build_env.sh 标准化(set -e + pip install)
- CLAUDE.md 更新开发命令、代码架构、关键接口说明
This commit is contained in:
@@ -1,4 +1,7 @@
|
||||
#!/bin/bash
|
||||
set -e
|
||||
|
||||
# 安装 Python 依赖(评测系统使用阿里云 PyPI 镜像)
|
||||
pip install -r requirements.txt
|
||||
|
||||
echo "build env succeess"
|
||||
echo "build env success"
|
||||
|
||||
+28
-13
@@ -118,15 +118,17 @@ def load_logids_from_file(file_path):
|
||||
return logids
|
||||
|
||||
|
||||
class CTRUserDataset(Dataset):
|
||||
"""按用户组织的 CTR 数据集"""
|
||||
class CTRTestSeqDataset(Dataset):
|
||||
"""按用户组织的 CTR 测试数据集(对齐评测接口)"""
|
||||
|
||||
def __init__(self, item_dict, user_seq=None, max_feasign_per_slot=None, pred_logids=None):
|
||||
def __init__(self, test_logids_ordered, item_dict, user_seq=None,
|
||||
max_feasign_per_slot=None, max_ctx_len=None):
|
||||
super().__init__()
|
||||
self.item_dict = item_dict
|
||||
self.user_seq = user_seq if user_seq else {}
|
||||
self.max_feasign_per_slot = max_feasign_per_slot
|
||||
self.pred_logids = pred_logids if pred_logids is not None else set()
|
||||
self.max_ctx_len = max_ctx_len
|
||||
self.pred_logids = set(test_logids_ordered) if test_logids_ordered else set()
|
||||
|
||||
self.user_items = defaultdict(list)
|
||||
for logid, rec in item_dict.items():
|
||||
@@ -236,7 +238,11 @@ def move_batch_to_device(batch, device):
|
||||
elif isinstance(batch, (list, tuple)):
|
||||
return [move_batch_to_device(x, device) for x in batch]
|
||||
elif torch.is_tensor(batch):
|
||||
return batch.to(device)
|
||||
x = batch.to(device)
|
||||
# 浮点 tensor → FP16,整数 tensor 保持不变
|
||||
if x.dtype == torch.float32:
|
||||
x = x.half()
|
||||
return x
|
||||
else:
|
||||
return batch
|
||||
|
||||
@@ -443,12 +449,12 @@ class CTRModel(nn.Module):
|
||||
# 模型加载入口
|
||||
# ============================================================
|
||||
|
||||
def load_model(device='cuda:0', ckpt_path=None):
|
||||
def load_model(ckpt_path, device='cuda:0'):
|
||||
"""加载模型并返回,供 evaluation.py 调用。
|
||||
|
||||
Args:
|
||||
ckpt_path: checkpoint 文件路径(评测系统传入 Path 对象)
|
||||
device: 推理设备(默认 'cuda:0')
|
||||
ckpt_path: checkpoint 文件路径,默认使用 infer.py 同目录下的 ckpt.pt
|
||||
|
||||
Returns:
|
||||
(model, device) 元组
|
||||
@@ -490,6 +496,11 @@ def load_model(device='cuda:0', ckpt_path=None):
|
||||
ckpt = torch.load(ckpt_path, map_location='cpu', weights_only=False)
|
||||
model.load_state_dict(ckpt['model_state_dict'])
|
||||
print(f"[INFO] Loaded checkpoint from {ckpt_path} (epoch={ckpt.get('epoch', '?')})")
|
||||
|
||||
# === FP16 量化:模型参数转半精度,Embedding 保留 FP32 ===
|
||||
model = model.half()
|
||||
model.rep_encoder.emb = model.rep_encoder.emb.to(torch.float32)
|
||||
print("[INFO] Model converted to FP16 (embedding kept in FP32)")
|
||||
else:
|
||||
print(f"[WARNING] Checkpoint {ckpt_path} not found, using random weights")
|
||||
|
||||
@@ -616,10 +627,11 @@ def main():
|
||||
print(f'[INFO] Test pred logids count: {len(test_pred_logids)}')
|
||||
|
||||
max_feasign_per_slot = {1: 2}
|
||||
test_dataset = CTRUserDataset(
|
||||
item_dict, user_seq,
|
||||
test_dataset = CTRTestSeqDataset(
|
||||
test_logids_ordered=list(test_pred_logids),
|
||||
item_dict=item_dict,
|
||||
user_seq=user_seq,
|
||||
max_feasign_per_slot=max_feasign_per_slot,
|
||||
pred_logids=test_pred_logids,
|
||||
)
|
||||
print(f'[INFO] num_users={test_dataset.num_users}, '
|
||||
f'total_samples={test_dataset.total_samples}, '
|
||||
@@ -634,9 +646,12 @@ def main():
|
||||
collate_fn=make_collate_fn(test_dataset.max_slot_id),
|
||||
)
|
||||
|
||||
# 收集 batches 并按分片缓存
|
||||
print('[INFO] collecting batches and saving sharded cache...')
|
||||
all_batches = [batch for batch in test_loader]
|
||||
# 收集 batches,预转 FP16 后按分片缓存
|
||||
print('[INFO] collecting batches (pre-converting to FP16) and saving sharded cache...')
|
||||
all_batches = []
|
||||
for batch in test_loader:
|
||||
batch = move_batch_to_device(batch, torch.device('cpu'))
|
||||
all_batches.append(batch)
|
||||
|
||||
batches_cache_dir.mkdir(parents=True, exist_ok=True)
|
||||
shard_idx = 0
|
||||
|
||||
@@ -1,29 +1,5 @@
|
||||
filelock==3.25.2
|
||||
fsspec==2026.2.0
|
||||
Jinja2==3.1.6
|
||||
joblib==1.5.3
|
||||
MarkupSafe==3.0.3
|
||||
mpmath==1.3.0
|
||||
networkx==3.4.2
|
||||
numpy==2.2.6
|
||||
nvidia-cublas-cu12==12.4.5.8
|
||||
nvidia-cuda-cupti-cu12==12.4.127
|
||||
nvidia-cuda-nvrtc-cu12==12.4.127
|
||||
nvidia-cuda-runtime-cu12==12.4.127
|
||||
nvidia-cudnn-cu12==9.1.0.70
|
||||
nvidia-cufft-cu12==11.2.1.3
|
||||
nvidia-curand-cu12==10.3.5.147
|
||||
nvidia-cusolver-cu12==11.6.1.9
|
||||
nvidia-cusparse-cu12==12.3.1.170
|
||||
nvidia-cusparselt-cu12==0.6.2
|
||||
nvidia-nccl-cu12==2.21.5
|
||||
nvidia-nvjitlink-cu12==12.4.127
|
||||
nvidia-nvtx-cu12==12.4.127
|
||||
scikit-learn==1.7.2
|
||||
scipy==1.15.3
|
||||
sympy==1.13.1
|
||||
threadpoolctl==3.6.0
|
||||
torch==2.6.0
|
||||
tqdm==4.67.3
|
||||
triton==3.2.0
|
||||
typing_extensions==4.15.0
|
||||
numpy==2.2.6
|
||||
scikit-learn==1.7.2
|
||||
tqdm==4.67.3
|
||||
|
||||
Reference in New Issue
Block a user