feat: interface alignment + FP16 quantization (first-pass optimization plan)
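- CTRUserDataset → CTRTestSeqDataset; constructor parameters aligned with the evaluation interface
- load_model signature fixed: ckpt_path is now the first parameter
- FP16 quantization: model.half(), with Embedding layers kept in FP32
- move_batch_to_device automatically converts FP32 → FP16
- Data is pre-converted to FP16 when cached, reducing inference-loop overhead
- requirements.txt trimmed (nvidia-* packages removed)
- build_env.sh standardized (set -e + pip install)
- CLAUDE.md updated with development commands, code architecture, and key interface notes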
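
The FP16 items above can be illustrated with a minimal sketch. It is not the code from this commit: `TinyCTR`, `to_fp16_keep_embeddings`, and the `use_fp16` flag are hypothetical, and the `move_batch_to_device` signature shown here is an assumption (the commit message does not give it). The sketch assumes the batch is a dict of tensors and that only `nn.Embedding` modules should stay in FP32.

```python
import torch
from torch import nn


def to_fp16_keep_embeddings(model: nn.Module) -> nn.Module:
    """Cast the model to FP16, then restore nn.Embedding tables to FP32.

    Hypothetical helper; the commit only states "model.half() + Embedding kept FP32".
    """
    model.half()  # casts all floating-point parameters and buffers to FP16
    for module in model.modules():
        if isinstance(module, nn.Embedding):
            module.float()  # cast this module's weights back to FP32
    return model


def move_batch_to_device(batch: dict, device: torch.device, use_fp16: bool = True) -> dict:
    """Move every tensor in the batch to `device`, casting FP32 tensors to FP16.

    Assumed signature. Integer tensors (e.g. feature ids) keep their dtype.
    """
    out = {}
    for key, value in batch.items():
        if not torch.is_tensor(value):
            out[key] = value  # leave non-tensor entries untouched
        elif use_fp16 and value.dtype == torch.float32:
            out[key] = value.to(device=device, dtype=torch.float16)
        else:
            out[key] = value.to(device)
    return out


class TinyCTR(nn.Module):
    """Hypothetical stand-in model: one embedding table and one linear layer."""

    def __init__(self) -> None:
        super().__init__()
        self.emb = nn.Embedding(10, 4)
        self.fc = nn.Linear(4, 1)

    def forward(self, ids: torch.Tensor) -> torch.Tensor:
        # The embedding output is FP32, so cast it to the dtype of the FP16 layer.
        return self.fc(self.emb(ids).to(self.fc.weight.dtype))


model = to_fp16_keep_embeddings(TinyCTR().eval())
batch = move_batch_to_device(
    {"ids": torch.tensor([1, 2]), "dense": torch.rand(2, 3)},
    torch.device("cpu"),
)
```

Because the embedding tables stay in FP32 while the remaining layers are FP16, the lookup result has to be cast before it reaches an FP16 layer, as `TinyCTR.forward` does here; how the actual model handles this is not shown in the commit.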
@@ -1,29 +1,5 @@
-filelock==3.25.2
-fsspec==2026.2.0
-Jinja2==3.1.6
-joblib==1.5.3
-MarkupSafe==3.0.3
-mpmath==1.3.0
-networkx==3.4.2
-numpy==2.2.6
-nvidia-cublas-cu12==12.4.5.8
-nvidia-cuda-cupti-cu12==12.4.127
-nvidia-cuda-nvrtc-cu12==12.4.127
-nvidia-cuda-runtime-cu12==12.4.127
-nvidia-cudnn-cu12==9.1.0.70
-nvidia-cufft-cu12==11.2.1.3
-nvidia-curand-cu12==10.3.5.147
-nvidia-cusolver-cu12==11.6.1.9
-nvidia-cusparse-cu12==12.3.1.170
-nvidia-cusparselt-cu12==0.6.2
-nvidia-nccl-cu12==2.21.5
-nvidia-nvjitlink-cu12==12.4.127
-nvidia-nvtx-cu12==12.4.127
-scikit-learn==1.7.2
-scipy==1.15.3
-sympy==1.13.1
-threadpoolctl==3.6.0
-torch==2.6.0
-tqdm==4.67.3
-triton==3.2.0
-typing_extensions==4.15.0
+numpy==2.2.6
+scikit-learn==1.7.2
+tqdm==4.67.3