diff --git a/代码/code/infer.py b/代码/code/infer.py index 394d3b7..aa7f442 100644 --- a/代码/code/infer.py +++ b/代码/code/infer.py @@ -58,7 +58,7 @@ if _HAS_TRITON: offs_d = tl.arange(0, D) q_mask = offs_m < seq_end q_ptrs = Q + h * stride_h + offs_m[:, None] * stride_s + offs_d[None, :] * stride_d - q = tl.load(q_ptrs, mask=q_mask[:, None], other=0.0).to(tl.float32) + q = tl.load(q_ptrs, mask=q_mask[:, None], other=0.0) # 保持 fp16,dot 走 Tensor Core m_i = tl.full([BLOCK_M], -float("inf"), tl.float32) l_i = tl.zeros([BLOCK_M], tl.float32) @@ -70,8 +70,8 @@ if _HAS_TRITON: offs_n = kn + tl.arange(0, BLOCK_N) k_mask = offs_n < seq_end k_ptrs = K + h * stride_h + offs_n[:, None] * stride_s + offs_d[None, :] * stride_d - k = tl.load(k_ptrs, mask=k_mask[:, None], other=0.0).to(tl.float32) - qk = tl.dot(q, tl.trans(k)) * scale # [BLOCK_M, BLOCK_N] + k = tl.load(k_ptrs, mask=k_mask[:, None], other=0.0) # fp16 + qk = tl.dot(q, tl.trans(k)).to(tl.float32) * scale # fp16 Tensor Core → fp32 k_pos = offs_n - seq_start valid = (q_pos[:, None] >= k_pos[None, :]) & k_mask[None, :] qk = tl.where(valid, qk, -float("inf")) @@ -80,8 +80,8 @@ if _HAS_TRITON: alpha = tl.exp(m_i - m_new) l_i = l_i * alpha + tl.sum(p, 1) v_ptrs = V + h * stride_h + offs_n[:, None] * stride_s + offs_d[None, :] * stride_d - v = tl.load(v_ptrs, mask=k_mask[:, None], other=0.0).to(tl.float32) - acc = acc * alpha[:, None] + tl.dot(p, v) + v = tl.load(v_ptrs, mask=k_mask[:, None], other=0.0) # fp16 + acc = acc * alpha[:, None] + tl.dot(p.to(tl.float16), v) # fp16 Tensor Core → fp32 m_i = m_new acc = acc / l_i[:, None]