From 0128fb8100df2ee9648bc092ba1e3b96abd2dc38 Mon Sep 17 00:00:00 2001 From: OwnerSunshine530 Date: Wed, 17 Jun 2026 00:36:25 +0800 Subject: [PATCH] =?UTF-8?q?perf:=20Triton=20kernel=20=E4=B8=A4=E4=B8=AAdot?= =?UTF-8?q?=E6=94=B9fp16=20Tensor=20Core(flash=E6=A0=87=E5=87=86:fp16=20ma?= =?UTF-8?q?tmul+fp32=20acc),=E5=8D=95=E5=9D=97=E6=8F=90=E9=80=9F2-4x?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-Authored-By: Claude Opus 4.8 --- 代码/code/infer.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/代码/code/infer.py b/代码/code/infer.py index 394d3b7..aa7f442 100644 --- a/代码/code/infer.py +++ b/代码/code/infer.py @@ -58,7 +58,7 @@ if _HAS_TRITON: offs_d = tl.arange(0, D) q_mask = offs_m < seq_end q_ptrs = Q + h * stride_h + offs_m[:, None] * stride_s + offs_d[None, :] * stride_d - q = tl.load(q_ptrs, mask=q_mask[:, None], other=0.0).to(tl.float32) + q = tl.load(q_ptrs, mask=q_mask[:, None], other=0.0) # 保持 fp16,dot 走 Tensor Core m_i = tl.full([BLOCK_M], -float("inf"), tl.float32) l_i = tl.zeros([BLOCK_M], tl.float32) @@ -70,8 +70,8 @@ if _HAS_TRITON: offs_n = kn + tl.arange(0, BLOCK_N) k_mask = offs_n < seq_end k_ptrs = K + h * stride_h + offs_n[:, None] * stride_s + offs_d[None, :] * stride_d - k = tl.load(k_ptrs, mask=k_mask[:, None], other=0.0).to(tl.float32) - qk = tl.dot(q, tl.trans(k)) * scale # [BLOCK_M, BLOCK_N] + k = tl.load(k_ptrs, mask=k_mask[:, None], other=0.0) # fp16 + qk = tl.dot(q, tl.trans(k)).to(tl.float32) * scale # fp16 Tensor Core → fp32 k_pos = offs_n - seq_start valid = (q_pos[:, None] >= k_pos[None, :]) & k_mask[None, :] qk = tl.where(valid, qk, -float("inf")) @@ -80,8 +80,8 @@ if _HAS_TRITON: alpha = tl.exp(m_i - m_new) l_i = l_i * alpha + tl.sum(p, 1) v_ptrs = V + h * stride_h + offs_n[:, None] * stride_s + offs_d[None, :] * stride_d - v = tl.load(v_ptrs, mask=k_mask[:, None], other=0.0).to(tl.float32) - acc = acc * alpha[:, None] + tl.dot(p, v) + v = tl.load(v_ptrs, mask=k_mask[:, None], other=0.0) # fp16 + acc = acc * alpha[:, None] + tl.dot(p.to(tl.float16), v) # fp16 Tensor Core → fp32 m_i = m_new acc = acc / l_i[:, None]