feat: Flash Attention + torch.compile(第二版优化方案)

- scaled_dot_product 替换为 F.scaled_dot_product_attention(自动启用 Flash Attention)
- load_model 中添加 torch.compile(mode='reduce-overhead')
- build_env.sh: 预热 torch inductor,避免编译耗时计入推理
This commit is contained in:
2026-06-12 21:39:43 +08:00
parent 97c4cc84a0
commit 574399e8ac
2 changed files with 28 additions and 9 deletions
+12 -2
View File
@@ -1,7 +1,17 @@
#!/bin/bash
set -e
# 安装 Python 依赖(评测系统使用阿里云 PyPI 镜像)
pip install -r requirements.txt
# 预热 torch inductor,避免推理时编译
python -c "
import torch
@torch.compile(mode='reduce-overhead')
def _warmup(x):
return x * 2
x = torch.randn(100, 100, device='cuda')
_warmup(x)
print('Inductor cache ready')
"
echo "build env success"
+16 -7
View File
@@ -274,14 +274,18 @@ class RepEncoder(nn.Module):
def scaled_dot_product(q, k, v, extension):
d = q.size(-1)
scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d)
"""使用 PyTorch SDPA 后端(自动启用 Flash Attention / Memory Efficient Attention"""
if extension is not None and "mask" in extension:
mask = extension["mask"]
scores = scores.masked_fill(mask == 0, float("-inf"))
attn = torch.softmax(scores, dim=-1)
out = torch.matmul(attn, v)
return out
attn_mask = extension["mask"].to(device=q.device)
else:
attn_mask = None
return F.scaled_dot_product_attention(
q, k, v,
attn_mask=attn_mask,
dropout_p=0.0,
is_causal=False,
)
class Expert(nn.Module):
@@ -507,6 +511,11 @@ def load_model(ckpt_path, device='cuda:0'):
model.to(dev)
model.eval()
# === torch.compile:算子融合 + 减少 kernel launch 开销 ===
model = torch.compile(model, mode="reduce-overhead")
print("[INFO] torch.compile applied (mode=reduce-overhead)")
print(f"[INFO] Model ready. Device: {dev}")
return model, dev