diff --git a/代码/code/infer.py b/代码/code/infer.py index f66540f..8c4b1d2 100644 --- a/代码/code/infer.py +++ b/代码/code/infer.py @@ -42,7 +42,7 @@ if _HAS_TRITON: def _varlen_flash_fwd( Q, K, V, Out, cu_seqlens, blk_seq, blk_inseq, - stride_h, stride_s, stride_d, + sqh, sqs, sqd, soh, sos, sod, scale, n_seq, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, D: tl.constexpr, ): @@ -57,7 +57,7 @@ if _HAS_TRITON: offs_m = q_row0 + tl.arange(0, BLOCK_M) # query token 全局行号 offs_d = tl.arange(0, D) q_mask = offs_m < seq_end - q_ptrs = Q + h * stride_h + offs_m[:, None] * stride_s + offs_d[None, :] * stride_d + q_ptrs = Q + h * sqh + offs_m[:, None] * sqs + offs_d[None, :] * sqd q = tl.load(q_ptrs, mask=q_mask[:, None], other=0.0) # 保持 fp16,dot 走 Tensor Core m_i = tl.full([BLOCK_M], -float("inf"), tl.float32) @@ -69,7 +69,7 @@ if _HAS_TRITON: for kn in range(seq_start, kv_end, BLOCK_N): offs_n = kn + tl.arange(0, BLOCK_N) k_mask = offs_n < seq_end - k_ptrs = K + h * stride_h + offs_n[:, None] * stride_s + offs_d[None, :] * stride_d + k_ptrs = K + h * sqh + offs_n[:, None] * sqs + offs_d[None, :] * sqd k = tl.load(k_ptrs, mask=k_mask[:, None], other=0.0) # fp16 qk = tl.dot(q, tl.trans(k)).to(tl.float32) * scale # fp16 Tensor Core → fp32 k_pos = offs_n - seq_start @@ -79,13 +79,13 @@ if _HAS_TRITON: p = tl.exp(qk - m_new[:, None]) alpha = tl.exp(m_i - m_new) l_i = l_i * alpha + tl.sum(p, 1) - v_ptrs = V + h * stride_h + offs_n[:, None] * stride_s + offs_d[None, :] * stride_d + v_ptrs = V + h * sqh + offs_n[:, None] * sqs + offs_d[None, :] * sqd v = tl.load(v_ptrs, mask=k_mask[:, None], other=0.0) # fp16 acc = acc * alpha[:, None] + tl.dot(p.to(tl.float16), v) # fp16 Tensor Core → fp32 m_i = m_new acc = acc / l_i[:, None] - o_ptrs = Out + h * stride_h + offs_m[:, None] * stride_s + offs_d[None, :] * stride_d + o_ptrs = Out + h * soh + offs_m[:, None] * sos + offs_d[None, :] * sod tl.store(o_ptrs, acc.to(tl.float16), mask=q_mask[:, None]) @@ -107,13 +107,15 @@ def _triton_varlen_attn(q, k, v, meta): _, H, S, Dh = q.shape cu, blk_seq, blk_inseq, total_blocks = meta BLOCK_M = CONFIG.get("triton_block_m", 64) - out = torch.empty_like(q) - qc = q.contiguous(); kc = k.contiguous(); vc = v.contiguous(); out = out.contiguous() - stride_h, stride_s, stride_d = S * Dh, Dh, 1 + # 不强制 contiguous:kernel 用实际 stride 读非连续的 q/k/v(来自 qkv split+permute)。 + # q,k,v split 同源、stride 相同(k,v 含各自 storage_offset,Triton 用其 data_ptr 自动处理)。 + out = torch.empty((1, H, S, Dh), device=q.device, dtype=torch.float16) + sqh, sqs, sqd = q.stride(1), q.stride(2), q.stride(3) + soh, sos, sod = out.stride(1), out.stride(2), out.stride(3) grid = (total_blocks, H) _varlen_flash_fwd[grid]( - qc, kc, vc, out, cu, blk_seq, blk_inseq, - stride_h, stride_s, stride_d, 1.0 / math.sqrt(Dh), cu.numel() - 1, + q, k, v, out, cu, blk_seq, blk_inseq, + sqh, sqs, sqd, soh, sos, sod, 1.0 / math.sqrt(Dh), cu.numel() - 1, BLOCK_M=BLOCK_M, BLOCK_N=64, D=Dh, ) return out