perf: Triton kernel 两个dot改fp16 Tensor Core(flash标准:fp16 matmul+fp32 acc),单块提速2-4x
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
+5
-5
@@ -58,7 +58,7 @@ if _HAS_TRITON:
|
||||
offs_d = tl.arange(0, D)
|
||||
q_mask = offs_m < seq_end
|
||||
q_ptrs = Q + h * stride_h + offs_m[:, None] * stride_s + offs_d[None, :] * stride_d
|
||||
q = tl.load(q_ptrs, mask=q_mask[:, None], other=0.0).to(tl.float32)
|
||||
q = tl.load(q_ptrs, mask=q_mask[:, None], other=0.0) # 保持 fp16,dot 走 Tensor Core
|
||||
|
||||
m_i = tl.full([BLOCK_M], -float("inf"), tl.float32)
|
||||
l_i = tl.zeros([BLOCK_M], tl.float32)
|
||||
@@ -70,8 +70,8 @@ if _HAS_TRITON:
|
||||
offs_n = kn + tl.arange(0, BLOCK_N)
|
||||
k_mask = offs_n < seq_end
|
||||
k_ptrs = K + h * stride_h + offs_n[:, None] * stride_s + offs_d[None, :] * stride_d
|
||||
k = tl.load(k_ptrs, mask=k_mask[:, None], other=0.0).to(tl.float32)
|
||||
qk = tl.dot(q, tl.trans(k)) * scale # [BLOCK_M, BLOCK_N]
|
||||
k = tl.load(k_ptrs, mask=k_mask[:, None], other=0.0) # fp16
|
||||
qk = tl.dot(q, tl.trans(k)).to(tl.float32) * scale # fp16 Tensor Core → fp32
|
||||
k_pos = offs_n - seq_start
|
||||
valid = (q_pos[:, None] >= k_pos[None, :]) & k_mask[None, :]
|
||||
qk = tl.where(valid, qk, -float("inf"))
|
||||
@@ -80,8 +80,8 @@ if _HAS_TRITON:
|
||||
alpha = tl.exp(m_i - m_new)
|
||||
l_i = l_i * alpha + tl.sum(p, 1)
|
||||
v_ptrs = V + h * stride_h + offs_n[:, None] * stride_s + offs_d[None, :] * stride_d
|
||||
v = tl.load(v_ptrs, mask=k_mask[:, None], other=0.0).to(tl.float32)
|
||||
acc = acc * alpha[:, None] + tl.dot(p, v)
|
||||
v = tl.load(v_ptrs, mask=k_mask[:, None], other=0.0) # fp16
|
||||
acc = acc * alpha[:, None] + tl.dot(p.to(tl.float16), v) # fp16 Tensor Core → fp32
|
||||
m_i = m_new
|
||||
|
||||
acc = acc / l_i[:, None]
|
||||
|
||||
Reference in New Issue
Block a user