feat: interface alignment + FP16 quantization (first-pass optimization plan)

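- CTRUserDataset → CTRTestSeqDataset; constructor parameters aligned with the evaluation interface
- load_model signature fixed: ckpt_path is now the first parameter
- FP16 quantization: model.half(), with Embedding layers kept in FP32
- move_batch_to_device automatically converts FP32 tensors to FP16
- Data is pre-converted to FP16 when cached, reducing overhead in the inference loop
- requirements.txt trimmed (nvidia-* packages removed)
- build_env.sh standardized (set -e + pip install)
- CLAUDE.md updated with development commands, code architecture, and key interface notes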
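
The FP16 items above can be illustrated with a minimal sketch. It is not the code from this commit: `TinyCTR` and `to_fp16_keep_embeddings` are hypothetical names, and the body of `move_batch_to_device` is an assumption about how the automatic FP32 → FP16 conversion might work. It only assumes that the batch is a dict of tensors and that embedding outputs are cast to the dense dtype inside `forward`.

```python
import torch
import torch.nn as nn


class TinyCTR(nn.Module):
    """Hypothetical stand-in model: one embedding table plus one linear layer."""

    def __init__(self, vocab_size=100, dim=8):
        super().__init__()
        self.emb = nn.Embedding(vocab_size, dim)
        self.mlp = nn.Linear(dim + 1, 1)

    def forward(self, ids, dense):
        # Embedding weights stay FP32, so cast the lookup result to the
        # dtype of the dense features before concatenating.
        e = self.emb(ids).to(dense.dtype)
        return self.mlp(torch.cat([e, dense], dim=-1))


def to_fp16_keep_embeddings(model):
    """Cast the whole model to FP16, then restore Embedding layers to FP32."""
    model.half()
    for module in model.modules():
        if isinstance(module, nn.Embedding):
            module.float()
    return model


def move_batch_to_device(batch, device, half=True):
    """Move tensors to `device`; convert only FP32 tensors to FP16.

    Integer tensors (e.g. ID features) and non-tensor values are left as-is.
    """
    out = {}
    for key, value in batch.items():
        if torch.is_tensor(value):
            value = value.to(device)
            if half and value.dtype == torch.float32:
                value = value.half()
        out[key] = value
    return out
```

Applying the same conversion once when batches are cached, rather than on every iteration, is what the "pre-converted to FP16 when cached" item refers to; in this sketch that would mean calling `move_batch_to_device` before storing the batch.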
2026-06-12 20:47:12 +08:00
parent b0ea305ad0
commit 4ee08adff5
4 changed files with 147 additions and 84 deletions
+3 -27
@@ -1,29 +1,5 @@
filelock==3.25.2
fsspec==2026.2.0
Jinja2==3.1.6
joblib==1.5.3
MarkupSafe==3.0.3
mpmath==1.3.0
networkx==3.4.2
numpy==2.2.6
nvidia-cublas-cu12==12.4.5.8
nvidia-cuda-cupti-cu12==12.4.127
nvidia-cuda-nvrtc-cu12==12.4.127
nvidia-cuda-runtime-cu12==12.4.127
nvidia-cudnn-cu12==9.1.0.70
nvidia-cufft-cu12==11.2.1.3
nvidia-curand-cu12==10.3.5.147
nvidia-cusolver-cu12==11.6.1.9
nvidia-cusparse-cu12==12.3.1.170
nvidia-cusparselt-cu12==0.6.2
nvidia-nccl-cu12==2.21.5
nvidia-nvjitlink-cu12==12.4.127
nvidia-nvtx-cu12==12.4.127
scikit-learn==1.7.2
scipy==1.15.3
sympy==1.13.1
threadpoolctl==3.6.0
torch==2.6.0
tqdm==4.67.3
triton==3.2.0
typing_extensions==4.15.0
numpy==2.2.6
scikit-learn==1.7.2
tqdm==4.67.3