perf: triton wrapper 去掉 q/k/v.contiguous(),用实际stride读非连续(省13% clone开销)

profile显示triton的.contiguous()产生492次clone占13%。kernel本就用stride参数,
传q.stride()+out.stride()直接读split+permute后的非连续qkv,免clone。

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
OwnerSunshine530
2026-06-17 13:44:10 +08:00
parent 74bb95a7bd
commit 6114c78354
+12 -10
View File
@@ -42,7 +42,7 @@ if _HAS_TRITON:
def _varlen_flash_fwd(
Q, K, V, Out,
cu_seqlens, blk_seq, blk_inseq,
stride_h, stride_s, stride_d,
sqh, sqs, sqd, soh, sos, sod,
scale, n_seq,
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, D: tl.constexpr,
):
@@ -57,7 +57,7 @@ if _HAS_TRITON:
offs_m = q_row0 + tl.arange(0, BLOCK_M) # query token 全局行号
offs_d = tl.arange(0, D)
q_mask = offs_m < seq_end
q_ptrs = Q + h * stride_h + offs_m[:, None] * stride_s + offs_d[None, :] * stride_d
q_ptrs = Q + h * sqh + offs_m[:, None] * sqs + offs_d[None, :] * sqd
q = tl.load(q_ptrs, mask=q_mask[:, None], other=0.0) # 保持 fp16dot 走 Tensor Core
m_i = tl.full([BLOCK_M], -float("inf"), tl.float32)
@@ -69,7 +69,7 @@ if _HAS_TRITON:
for kn in range(seq_start, kv_end, BLOCK_N):
offs_n = kn + tl.arange(0, BLOCK_N)
k_mask = offs_n < seq_end
k_ptrs = K + h * stride_h + offs_n[:, None] * stride_s + offs_d[None, :] * stride_d
k_ptrs = K + h * sqh + offs_n[:, None] * sqs + offs_d[None, :] * sqd
k = tl.load(k_ptrs, mask=k_mask[:, None], other=0.0) # fp16
qk = tl.dot(q, tl.trans(k)).to(tl.float32) * scale # fp16 Tensor Core → fp32
k_pos = offs_n - seq_start
@@ -79,13 +79,13 @@ if _HAS_TRITON:
p = tl.exp(qk - m_new[:, None])
alpha = tl.exp(m_i - m_new)
l_i = l_i * alpha + tl.sum(p, 1)
v_ptrs = V + h * stride_h + offs_n[:, None] * stride_s + offs_d[None, :] * stride_d
v_ptrs = V + h * sqh + offs_n[:, None] * sqs + offs_d[None, :] * sqd
v = tl.load(v_ptrs, mask=k_mask[:, None], other=0.0) # fp16
acc = acc * alpha[:, None] + tl.dot(p.to(tl.float16), v) # fp16 Tensor Core → fp32
m_i = m_new
acc = acc / l_i[:, None]
o_ptrs = Out + h * stride_h + offs_m[:, None] * stride_s + offs_d[None, :] * stride_d
o_ptrs = Out + h * soh + offs_m[:, None] * sos + offs_d[None, :] * sod
tl.store(o_ptrs, acc.to(tl.float16), mask=q_mask[:, None])
@@ -107,13 +107,15 @@ def _triton_varlen_attn(q, k, v, meta):
_, H, S, Dh = q.shape
cu, blk_seq, blk_inseq, total_blocks = meta
BLOCK_M = CONFIG.get("triton_block_m", 64)
out = torch.empty_like(q)
qc = q.contiguous(); kc = k.contiguous(); vc = v.contiguous(); out = out.contiguous()
stride_h, stride_s, stride_d = S * Dh, Dh, 1
# 不强制 contiguouskernel 用实际 stride 读非连续的 q/k/v(来自 qkv split+permute)。
# q,k,v split 同源、stride 相同(k,v 含各自 storage_offsetTriton 用其 data_ptr 自动处理)。
out = torch.empty((1, H, S, Dh), device=q.device, dtype=torch.float16)
sqh, sqs, sqd = q.stride(1), q.stride(2), q.stride(3)
soh, sos, sod = out.stride(1), out.stride(2), out.stride(3)
grid = (total_blocks, H)
_varlen_flash_fwd[grid](
qc, kc, vc, out, cu, blk_seq, blk_inseq,
stride_h, stride_s, stride_d, 1.0 / math.sqrt(Dh), cu.numel() - 1,
q, k, v, out, cu, blk_seq, blk_inseq,
sqh, sqs, sqd, soh, sos, sod, 1.0 / math.sqrt(Dh), cu.numel() - 1,
BLOCK_M=BLOCK_M, BLOCK_N=64, D=Dh,
)
return out