perf: triton wrapper 去掉 q/k/v.contiguous(),用实际stride读非连续(省13% clone开销)
profile显示triton的.contiguous()产生492次clone占13%。kernel本就用stride参数, 传q.stride()+out.stride()直接读split+permute后的非连续qkv,免clone。 Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
+12
-10
@@ -42,7 +42,7 @@ if _HAS_TRITON:
|
|||||||
def _varlen_flash_fwd(
|
def _varlen_flash_fwd(
|
||||||
Q, K, V, Out,
|
Q, K, V, Out,
|
||||||
cu_seqlens, blk_seq, blk_inseq,
|
cu_seqlens, blk_seq, blk_inseq,
|
||||||
stride_h, stride_s, stride_d,
|
sqh, sqs, sqd, soh, sos, sod,
|
||||||
scale, n_seq,
|
scale, n_seq,
|
||||||
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, D: tl.constexpr,
|
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, D: tl.constexpr,
|
||||||
):
|
):
|
||||||
@@ -57,7 +57,7 @@ if _HAS_TRITON:
|
|||||||
offs_m = q_row0 + tl.arange(0, BLOCK_M) # query token 全局行号
|
offs_m = q_row0 + tl.arange(0, BLOCK_M) # query token 全局行号
|
||||||
offs_d = tl.arange(0, D)
|
offs_d = tl.arange(0, D)
|
||||||
q_mask = offs_m < seq_end
|
q_mask = offs_m < seq_end
|
||||||
q_ptrs = Q + h * stride_h + offs_m[:, None] * stride_s + offs_d[None, :] * stride_d
|
q_ptrs = Q + h * sqh + offs_m[:, None] * sqs + offs_d[None, :] * sqd
|
||||||
q = tl.load(q_ptrs, mask=q_mask[:, None], other=0.0) # 保持 fp16,dot 走 Tensor Core
|
q = tl.load(q_ptrs, mask=q_mask[:, None], other=0.0) # 保持 fp16,dot 走 Tensor Core
|
||||||
|
|
||||||
m_i = tl.full([BLOCK_M], -float("inf"), tl.float32)
|
m_i = tl.full([BLOCK_M], -float("inf"), tl.float32)
|
||||||
@@ -69,7 +69,7 @@ if _HAS_TRITON:
|
|||||||
for kn in range(seq_start, kv_end, BLOCK_N):
|
for kn in range(seq_start, kv_end, BLOCK_N):
|
||||||
offs_n = kn + tl.arange(0, BLOCK_N)
|
offs_n = kn + tl.arange(0, BLOCK_N)
|
||||||
k_mask = offs_n < seq_end
|
k_mask = offs_n < seq_end
|
||||||
k_ptrs = K + h * stride_h + offs_n[:, None] * stride_s + offs_d[None, :] * stride_d
|
k_ptrs = K + h * sqh + offs_n[:, None] * sqs + offs_d[None, :] * sqd
|
||||||
k = tl.load(k_ptrs, mask=k_mask[:, None], other=0.0) # fp16
|
k = tl.load(k_ptrs, mask=k_mask[:, None], other=0.0) # fp16
|
||||||
qk = tl.dot(q, tl.trans(k)).to(tl.float32) * scale # fp16 Tensor Core → fp32
|
qk = tl.dot(q, tl.trans(k)).to(tl.float32) * scale # fp16 Tensor Core → fp32
|
||||||
k_pos = offs_n - seq_start
|
k_pos = offs_n - seq_start
|
||||||
@@ -79,13 +79,13 @@ if _HAS_TRITON:
|
|||||||
p = tl.exp(qk - m_new[:, None])
|
p = tl.exp(qk - m_new[:, None])
|
||||||
alpha = tl.exp(m_i - m_new)
|
alpha = tl.exp(m_i - m_new)
|
||||||
l_i = l_i * alpha + tl.sum(p, 1)
|
l_i = l_i * alpha + tl.sum(p, 1)
|
||||||
v_ptrs = V + h * stride_h + offs_n[:, None] * stride_s + offs_d[None, :] * stride_d
|
v_ptrs = V + h * sqh + offs_n[:, None] * sqs + offs_d[None, :] * sqd
|
||||||
v = tl.load(v_ptrs, mask=k_mask[:, None], other=0.0) # fp16
|
v = tl.load(v_ptrs, mask=k_mask[:, None], other=0.0) # fp16
|
||||||
acc = acc * alpha[:, None] + tl.dot(p.to(tl.float16), v) # fp16 Tensor Core → fp32
|
acc = acc * alpha[:, None] + tl.dot(p.to(tl.float16), v) # fp16 Tensor Core → fp32
|
||||||
m_i = m_new
|
m_i = m_new
|
||||||
|
|
||||||
acc = acc / l_i[:, None]
|
acc = acc / l_i[:, None]
|
||||||
o_ptrs = Out + h * stride_h + offs_m[:, None] * stride_s + offs_d[None, :] * stride_d
|
o_ptrs = Out + h * soh + offs_m[:, None] * sos + offs_d[None, :] * sod
|
||||||
tl.store(o_ptrs, acc.to(tl.float16), mask=q_mask[:, None])
|
tl.store(o_ptrs, acc.to(tl.float16), mask=q_mask[:, None])
|
||||||
|
|
||||||
|
|
||||||
@@ -107,13 +107,15 @@ def _triton_varlen_attn(q, k, v, meta):
|
|||||||
_, H, S, Dh = q.shape
|
_, H, S, Dh = q.shape
|
||||||
cu, blk_seq, blk_inseq, total_blocks = meta
|
cu, blk_seq, blk_inseq, total_blocks = meta
|
||||||
BLOCK_M = CONFIG.get("triton_block_m", 64)
|
BLOCK_M = CONFIG.get("triton_block_m", 64)
|
||||||
out = torch.empty_like(q)
|
# 不强制 contiguous:kernel 用实际 stride 读非连续的 q/k/v(来自 qkv split+permute)。
|
||||||
qc = q.contiguous(); kc = k.contiguous(); vc = v.contiguous(); out = out.contiguous()
|
# q,k,v split 同源、stride 相同(k,v 含各自 storage_offset,Triton 用其 data_ptr 自动处理)。
|
||||||
stride_h, stride_s, stride_d = S * Dh, Dh, 1
|
out = torch.empty((1, H, S, Dh), device=q.device, dtype=torch.float16)
|
||||||
|
sqh, sqs, sqd = q.stride(1), q.stride(2), q.stride(3)
|
||||||
|
soh, sos, sod = out.stride(1), out.stride(2), out.stride(3)
|
||||||
grid = (total_blocks, H)
|
grid = (total_blocks, H)
|
||||||
_varlen_flash_fwd[grid](
|
_varlen_flash_fwd[grid](
|
||||||
qc, kc, vc, out, cu, blk_seq, blk_inseq,
|
q, k, v, out, cu, blk_seq, blk_inseq,
|
||||||
stride_h, stride_s, stride_d, 1.0 / math.sqrt(Dh), cu.numel() - 1,
|
sqh, sqs, sqd, soh, sos, sod, 1.0 / math.sqrt(Dh), cu.numel() - 1,
|
||||||
BLOCK_M=BLOCK_M, BLOCK_N=64, D=Dh,
|
BLOCK_M=BLOCK_M, BLOCK_N=64, D=Dh,
|
||||||
)
|
)
|
||||||
return out
|
return out
|
||||||
|
|||||||
Reference in New Issue
Block a user