574399e8ac
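- Replace scaled_dot_product with F.scaled_dot_product_attention (dispatches to Flash Attention automatically when the device, dtype and input shapes support it)
- Add torch.compile(mode='reduce-overhead') in load_model
- build_env.sh: pre-warm torch inductor so that compilation time is not counted toward inference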
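For reference, both the old `scaled_dot_product` and `F.scaled_dot_product_attention` compute softmax(Q Kᵀ / √d) V when no mask or dropout is used. Below is a minimal pure-Python sketch of that math, not the repository's original implementation (whose body is not shown here). It covers a single unbatched head, with matrices given as lists of lists. The name `reference_sdpa` is hypothetical.

```python
import math

def softmax(row):
    # Subtract the max before exponentiating for numerical stability.
    m = max(row)
    exps = [math.exp(v - m) for v in row]
    total = sum(exps)
    return [e / total for e in exps]

def reference_sdpa(q, k, v):
    """Hypothetical reference: softmax(q @ k^T / sqrt(d)) @ v, one head, no mask/dropout.

    q: n_q x d, k: n_k x d, v: n_k x d_v (lists of lists of floats).
    """
    d = len(q[0])
    scale = 1.0 / math.sqrt(d)
    out = []
    for qi in q:
        # Scaled dot-product score of this query against every key.
        scores = [scale * sum(a * b for a, b in zip(qi, kj)) for kj in k]
        weights = softmax(scores)
        # Output row = attention-weighted sum of the value rows.
        out.append([sum(w * vj[c] for w, vj in zip(weights, v)) for c in range(len(v[0]))])
    return out
```

In PyTorch the equivalent call is `torch.nn.functional.scaled_dot_product_attention(query, key, value)` on tensors shaped `(..., L, E)`. Its default scale is also 1/√E, and it can pick a fused kernel such as Flash Attention instead of materializing the full score matrix.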
18 lines
277 B
Bash
#!/bin/bash
set -e

# Pre-warm torch inductor so that compilation does not happen at inference time
python -c "
import torch

@torch.compile(mode='reduce-overhead')
def _warmup(x):
    return x * 2

x = torch.randn(100, 100, device='cuda')
_warmup(x)
print('Inductor cache ready')
"

echo "build env success"