diff --git a/代码/code/infer.py b/代码/code/infer.py index 6077051..29dcac2 100644 --- a/代码/code/infer.py +++ b/代码/code/infer.py @@ -108,13 +108,16 @@ def _triton_varlen_attn(q, k, v, meta): cu, blk_seq, blk_inseq, total_blocks = meta BLOCK_M = CONFIG.get("triton_block_m", 64) # contiguous 后连续访存更快(实测去 contiguous 用 stride 读反而慢:非连续跨步读 > 一次性 clone)。 - out = torch.empty((1, H, S, Dh), device=q.device, dtype=torch.float16) qc = q.contiguous(); kc = k.contiguous(); vc = v.contiguous() sh, ss, sd = S * Dh, Dh, 1 + # out 存储为 [1,S,H,Dh],以 [1,H,S,Dh] 视图返回 → 调用方 permute(0,2,1,3) 得连续、reshape 免 clone + out_storage = torch.empty((1, S, H, Dh), device=q.device, dtype=torch.float16) + out = out_storage.permute(0, 2, 1, 3) # [1,H,S,Dh] 视图 + soh, sos, sod = out.stride(1), out.stride(2), out.stride(3) # = Dh, H*Dh, 1 grid = (total_blocks, H) _varlen_flash_fwd[grid]( qc, kc, vc, out, cu, blk_seq, blk_inseq, - sh, ss, sd, sh, ss, sd, 1.0 / math.sqrt(Dh), cu.numel() - 1, + sh, ss, sd, soh, sos, sod, 1.0 / math.sqrt(Dh), cu.numel() - 1, BLOCK_M=BLOCK_M, BLOCK_N=64, D=Dh, ) return out