From b72e0346a963002e9157307885417ea87da6cc7a Mon Sep 17 00:00:00 2001 From: OwnerSunshine530 Date: Fri, 19 Jun 2026 20:27:28 +0800 Subject: [PATCH] =?UTF-8?q?perf:=20triton=20attention=20=E8=BE=93=E5=87=BA?= =?UTF-8?q?=E6=8C=89[S,H,Dh]=E5=B8=83=E5=B1=80=E5=86=99,=E6=B6=88=E8=B0=83?= =?UTF-8?q?=E7=94=A8=E6=96=B9permute-clone(x8=E5=B1=82)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit kernel输出stride可配,直接写[1,S,H,Dh]存储,调用方permute(0,2,1,3)变免费视图、 reshape不再clone。纯布局,数值不变。延续减kernel/clone方向。 Co-Authored-By: Claude Opus 4.8 --- 代码/code/infer.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/代码/code/infer.py b/代码/code/infer.py index 6077051..29dcac2 100644 --- a/代码/code/infer.py +++ b/代码/code/infer.py @@ -108,13 +108,16 @@ def _triton_varlen_attn(q, k, v, meta): cu, blk_seq, blk_inseq, total_blocks = meta BLOCK_M = CONFIG.get("triton_block_m", 64) # contiguous 后连续访存更快(实测去 contiguous 用 stride 读反而慢:非连续跨步读 > 一次性 clone)。 - out = torch.empty((1, H, S, Dh), device=q.device, dtype=torch.float16) qc = q.contiguous(); kc = k.contiguous(); vc = v.contiguous() sh, ss, sd = S * Dh, Dh, 1 + # out 存储为 [1,S,H,Dh],以 [1,H,S,Dh] 视图返回 → 调用方 permute(0,2,1,3) 得连续、reshape 免 clone + out_storage = torch.empty((1, S, H, Dh), device=q.device, dtype=torch.float16) + out = out_storage.permute(0, 2, 1, 3) # [1,H,S,Dh] 视图 + soh, sos, sod = out.stride(1), out.stride(2), out.stride(3) # = Dh, H*Dh, 1 grid = (total_blocks, H) _varlen_flash_fwd[grid]( qc, kc, vc, out, cu, blk_seq, blk_inseq, - sh, ss, sd, sh, ss, sd, 1.0 / math.sqrt(Dh), cu.numel() - 1, + sh, ss, sd, soh, sos, sod, 1.0 / math.sqrt(Dh), cu.numel() - 1, BLOCK_M=BLOCK_M, BLOCK_N=64, D=Dh, ) return out