perf: Triton kernel 两个dot改fp16 Tensor Core(flash标准:fp16 matmul+fp32 acc),单块提速2-4x
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
+5
-5
@@ -58,7 +58,7 @@ if _HAS_TRITON:
|
|||||||
offs_d = tl.arange(0, D)
|
offs_d = tl.arange(0, D)
|
||||||
q_mask = offs_m < seq_end
|
q_mask = offs_m < seq_end
|
||||||
q_ptrs = Q + h * stride_h + offs_m[:, None] * stride_s + offs_d[None, :] * stride_d
|
q_ptrs = Q + h * stride_h + offs_m[:, None] * stride_s + offs_d[None, :] * stride_d
|
||||||
q = tl.load(q_ptrs, mask=q_mask[:, None], other=0.0).to(tl.float32)
|
q = tl.load(q_ptrs, mask=q_mask[:, None], other=0.0) # 保持 fp16,dot 走 Tensor Core
|
||||||
|
|
||||||
m_i = tl.full([BLOCK_M], -float("inf"), tl.float32)
|
m_i = tl.full([BLOCK_M], -float("inf"), tl.float32)
|
||||||
l_i = tl.zeros([BLOCK_M], tl.float32)
|
l_i = tl.zeros([BLOCK_M], tl.float32)
|
||||||
@@ -70,8 +70,8 @@ if _HAS_TRITON:
|
|||||||
offs_n = kn + tl.arange(0, BLOCK_N)
|
offs_n = kn + tl.arange(0, BLOCK_N)
|
||||||
k_mask = offs_n < seq_end
|
k_mask = offs_n < seq_end
|
||||||
k_ptrs = K + h * stride_h + offs_n[:, None] * stride_s + offs_d[None, :] * stride_d
|
k_ptrs = K + h * stride_h + offs_n[:, None] * stride_s + offs_d[None, :] * stride_d
|
||||||
k = tl.load(k_ptrs, mask=k_mask[:, None], other=0.0).to(tl.float32)
|
k = tl.load(k_ptrs, mask=k_mask[:, None], other=0.0) # fp16
|
||||||
qk = tl.dot(q, tl.trans(k)) * scale # [BLOCK_M, BLOCK_N]
|
qk = tl.dot(q, tl.trans(k)).to(tl.float32) * scale # fp16 Tensor Core → fp32
|
||||||
k_pos = offs_n - seq_start
|
k_pos = offs_n - seq_start
|
||||||
valid = (q_pos[:, None] >= k_pos[None, :]) & k_mask[None, :]
|
valid = (q_pos[:, None] >= k_pos[None, :]) & k_mask[None, :]
|
||||||
qk = tl.where(valid, qk, -float("inf"))
|
qk = tl.where(valid, qk, -float("inf"))
|
||||||
@@ -80,8 +80,8 @@ if _HAS_TRITON:
|
|||||||
alpha = tl.exp(m_i - m_new)
|
alpha = tl.exp(m_i - m_new)
|
||||||
l_i = l_i * alpha + tl.sum(p, 1)
|
l_i = l_i * alpha + tl.sum(p, 1)
|
||||||
v_ptrs = V + h * stride_h + offs_n[:, None] * stride_s + offs_d[None, :] * stride_d
|
v_ptrs = V + h * stride_h + offs_n[:, None] * stride_s + offs_d[None, :] * stride_d
|
||||||
v = tl.load(v_ptrs, mask=k_mask[:, None], other=0.0).to(tl.float32)
|
v = tl.load(v_ptrs, mask=k_mask[:, None], other=0.0) # fp16
|
||||||
acc = acc * alpha[:, None] + tl.dot(p, v)
|
acc = acc * alpha[:, None] + tl.dot(p.to(tl.float16), v) # fp16 Tensor Core → fp32
|
||||||
m_i = m_new
|
m_i = m_new
|
||||||
|
|
||||||
acc = acc / l_i[:, None]
|
acc = acc / l_i[:, None]
|
||||||
|
|||||||
Reference in New Issue
Block a user