perf: Triton kernel 两个dot改fp16 Tensor Core(flash标准:fp16 matmul+fp32 acc),单块提速2-4x

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
OwnerSunshine530
2026-06-17 00:36:25 +08:00
parent cdc2dd490b
commit 0128fb8100
+5 -5
View File
@@ -58,7 +58,7 @@ if _HAS_TRITON:
offs_d = tl.arange(0, D)
q_mask = offs_m < seq_end
q_ptrs = Q + h * stride_h + offs_m[:, None] * stride_s + offs_d[None, :] * stride_d
q = tl.load(q_ptrs, mask=q_mask[:, None], other=0.0).to(tl.float32)
q = tl.load(q_ptrs, mask=q_mask[:, None], other=0.0) # 保持 fp16dot 走 Tensor Core
m_i = tl.full([BLOCK_M], -float("inf"), tl.float32)
l_i = tl.zeros([BLOCK_M], tl.float32)
@@ -70,8 +70,8 @@ if _HAS_TRITON:
offs_n = kn + tl.arange(0, BLOCK_N)
k_mask = offs_n < seq_end
k_ptrs = K + h * stride_h + offs_n[:, None] * stride_s + offs_d[None, :] * stride_d
k = tl.load(k_ptrs, mask=k_mask[:, None], other=0.0).to(tl.float32)
qk = tl.dot(q, tl.trans(k)) * scale # [BLOCK_M, BLOCK_N]
k = tl.load(k_ptrs, mask=k_mask[:, None], other=0.0) # fp16
qk = tl.dot(q, tl.trans(k)).to(tl.float32) * scale # fp16 Tensor Core → fp32
k_pos = offs_n - seq_start
valid = (q_pos[:, None] >= k_pos[None, :]) & k_mask[None, :]
qk = tl.where(valid, qk, -float("inf"))
@@ -80,8 +80,8 @@ if _HAS_TRITON:
alpha = tl.exp(m_i - m_new)
l_i = l_i * alpha + tl.sum(p, 1)
v_ptrs = V + h * stride_h + offs_n[:, None] * stride_s + offs_d[None, :] * stride_d
v = tl.load(v_ptrs, mask=k_mask[:, None], other=0.0).to(tl.float32)
acc = acc * alpha[:, None] + tl.dot(p, v)
v = tl.load(v_ptrs, mask=k_mask[:, None], other=0.0) # fp16
acc = acc * alpha[:, None] + tl.dot(p.to(tl.float16), v) # fp16 Tensor Core → fp32
m_i = m_new
acc = acc / l_i[:, None]