perf: triton attention 输出按[S,H,Dh]布局写,消调用方permute-clone(x8层)

kernel输出stride可配,直接写[1,S,H,Dh]存储,调用方permute(0,2,1,3)变免费视图、
reshape不再clone。纯布局,数值不变。延续减kernel/clone方向。

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
OwnerSunshine530
2026-06-19 20:27:28 +08:00
parent 9f73505caa
commit b72e0346a9
+5 -2
View File
@@ -108,13 +108,16 @@ def _triton_varlen_attn(q, k, v, meta):
cu, blk_seq, blk_inseq, total_blocks = meta
BLOCK_M = CONFIG.get("triton_block_m", 64)
# contiguous 后连续访存更快(实测去 contiguous 用 stride 读反而慢:非连续跨步读 > 一次性 clone)。
out = torch.empty((1, H, S, Dh), device=q.device, dtype=torch.float16)
qc = q.contiguous(); kc = k.contiguous(); vc = v.contiguous()
sh, ss, sd = S * Dh, Dh, 1
# out 存储为 [1,S,H,Dh],以 [1,H,S,Dh] 视图返回 → 调用方 permute(0,2,1,3) 得连续、reshape 免 clone
out_storage = torch.empty((1, S, H, Dh), device=q.device, dtype=torch.float16)
out = out_storage.permute(0, 2, 1, 3) # [1,H,S,Dh] 视图
soh, sos, sod = out.stride(1), out.stride(2), out.stride(3) # = Dh, H*Dh, 1
grid = (total_blocks, H)
_varlen_flash_fwd[grid](
qc, kc, vc, out, cu, blk_seq, blk_inseq,
sh, ss, sd, sh, ss, sd, 1.0 / math.sqrt(Dh), cu.numel() - 1,
sh, ss, sd, soh, sos, sod, 1.0 / math.sqrt(Dh), cu.numel() - 1,
BLOCK_M=BLOCK_M, BLOCK_N=64, D=Dh,
)
return out