feat: Flash Attention + torch.compile(第二版优化方案)
- scaled_dot_product 替换为 F.scaled_dot_product_attention(自动启用 Flash Attention) - load_model 中添加 torch.compile(mode='reduce-overhead') - build_env.sh: 预热 torch inductor,避免编译耗时计入推理
This commit is contained in:
+16
-7
@@ -274,14 +274,18 @@ class RepEncoder(nn.Module):
|
||||
|
||||
|
||||
def scaled_dot_product(q, k, v, extension):
|
||||
d = q.size(-1)
|
||||
scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d)
|
||||
"""使用 PyTorch SDPA 后端(自动启用 Flash Attention / Memory Efficient Attention)"""
|
||||
if extension is not None and "mask" in extension:
|
||||
mask = extension["mask"]
|
||||
scores = scores.masked_fill(mask == 0, float("-inf"))
|
||||
attn = torch.softmax(scores, dim=-1)
|
||||
out = torch.matmul(attn, v)
|
||||
return out
|
||||
attn_mask = extension["mask"].to(device=q.device)
|
||||
else:
|
||||
attn_mask = None
|
||||
|
||||
return F.scaled_dot_product_attention(
|
||||
q, k, v,
|
||||
attn_mask=attn_mask,
|
||||
dropout_p=0.0,
|
||||
is_causal=False,
|
||||
)
|
||||
|
||||
|
||||
class Expert(nn.Module):
|
||||
@@ -507,6 +511,11 @@ def load_model(ckpt_path, device='cuda:0'):
|
||||
|
||||
model.to(dev)
|
||||
model.eval()
|
||||
|
||||
# === torch.compile:算子融合 + 减少 kernel launch 开销 ===
|
||||
model = torch.compile(model, mode="reduce-overhead")
|
||||
print("[INFO] torch.compile applied (mode=reduce-overhead)")
|
||||
|
||||
print(f"[INFO] Model ready. Device: {dev}")
|
||||
|
||||
return model, dev
|
||||
|
||||
Reference in New Issue
Block a user