feat: Triton varlen因果flash attention(块对角,单kernel,消逐块调用+mask构造开销)
每program处理(用户段query块,head),只遍历段内<=该块的key(因果),在线softmax, fp16读写fp32累加。CONFIG.attn=triton(默认仍chunked);_triton_block_meta每batch算一次 block→段映射8层复用;_resolve_attn在无triton/CPU时回退chunked。等价测试+bench --attn triton。 数学等价(FlashAttention同类,规则允许),不改组网。 Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
+106
-3
@@ -26,6 +26,98 @@ except Exception:
|
||||
create_block_mask = None
|
||||
_HAS_FLEX = False
|
||||
|
||||
# Triton varlen 因果 flash attention(块对角,单 kernel,消除逐块调用/mask 构造开销)
|
||||
try:
|
||||
import triton
|
||||
import triton.language as tl
|
||||
_HAS_TRITON = True
|
||||
except Exception:
|
||||
triton = None
|
||||
tl = None
|
||||
_HAS_TRITON = False
|
||||
|
||||
|
||||
if _HAS_TRITON:
|
||||
@triton.jit
|
||||
def _varlen_flash_fwd(
|
||||
Q, K, V, Out,
|
||||
cu_seqlens, blk_seq, blk_inseq,
|
||||
stride_h, stride_s, stride_d,
|
||||
scale, n_seq,
|
||||
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, D: tl.constexpr,
|
||||
):
|
||||
pid = tl.program_id(0) # 全局 query 块
|
||||
h = tl.program_id(1) # head
|
||||
s = tl.load(blk_seq + pid)
|
||||
bis = tl.load(blk_inseq + pid)
|
||||
seq_start = tl.load(cu_seqlens + s)
|
||||
seq_end = tl.load(cu_seqlens + s + 1)
|
||||
|
||||
q_row0 = seq_start + bis * BLOCK_M
|
||||
offs_m = q_row0 + tl.arange(0, BLOCK_M) # query token 全局行号
|
||||
offs_d = tl.arange(0, D)
|
||||
q_mask = offs_m < seq_end
|
||||
q_ptrs = Q + h * stride_h + offs_m[:, None] * stride_s + offs_d[None, :] * stride_d
|
||||
q = tl.load(q_ptrs, mask=q_mask[:, None], other=0.0).to(tl.float32)
|
||||
|
||||
m_i = tl.full([BLOCK_M], -float("inf"), tl.float32)
|
||||
l_i = tl.zeros([BLOCK_M], tl.float32)
|
||||
acc = tl.zeros([BLOCK_M, D], tl.float32)
|
||||
|
||||
q_pos = offs_m - seq_start # query 段内位置
|
||||
kv_end = q_row0 + BLOCK_M # 因果:key 不超过本 query 块末尾
|
||||
for kn in range(seq_start, kv_end, BLOCK_N):
|
||||
offs_n = kn + tl.arange(0, BLOCK_N)
|
||||
k_mask = offs_n < seq_end
|
||||
k_ptrs = K + h * stride_h + offs_n[:, None] * stride_s + offs_d[None, :] * stride_d
|
||||
k = tl.load(k_ptrs, mask=k_mask[:, None], other=0.0).to(tl.float32)
|
||||
qk = tl.dot(q, tl.trans(k)) * scale # [BLOCK_M, BLOCK_N]
|
||||
k_pos = offs_n - seq_start
|
||||
valid = (q_pos[:, None] >= k_pos[None, :]) & k_mask[None, :]
|
||||
qk = tl.where(valid, qk, -float("inf"))
|
||||
m_new = tl.maximum(m_i, tl.max(qk, 1))
|
||||
p = tl.exp(qk - m_new[:, None])
|
||||
alpha = tl.exp(m_i - m_new)
|
||||
l_i = l_i * alpha + tl.sum(p, 1)
|
||||
v_ptrs = V + h * stride_h + offs_n[:, None] * stride_s + offs_d[None, :] * stride_d
|
||||
v = tl.load(v_ptrs, mask=k_mask[:, None], other=0.0).to(tl.float32)
|
||||
acc = acc * alpha[:, None] + tl.dot(p, v)
|
||||
m_i = m_new
|
||||
|
||||
acc = acc / l_i[:, None]
|
||||
o_ptrs = Out + h * stride_h + offs_m[:, None] * stride_s + offs_d[None, :] * stride_d
|
||||
tl.store(o_ptrs, acc.to(tl.float16), mask=q_mask[:, None])
|
||||
|
||||
|
||||
def _triton_block_meta(user_offsets, BLOCK_M, device):
|
||||
"""从 user_offsets 算 block→段映射(每 batch 一次、8 层复用;含 1 次同步读 total_blocks)。"""
|
||||
cu = user_offsets.to(torch.int32)
|
||||
seqlens = (cu[1:] - cu[:-1]).to(torch.int64)
|
||||
blocks_per = (seqlens + BLOCK_M - 1) // BLOCK_M
|
||||
n_seq = seqlens.numel()
|
||||
blk_seq = torch.repeat_interleave(torch.arange(n_seq, device=device), blocks_per)
|
||||
total_blocks = blk_seq.numel()
|
||||
starts = torch.cumsum(blocks_per, 0) - blocks_per
|
||||
blk_inseq = torch.arange(total_blocks, device=device) - starts[blk_seq]
|
||||
return cu.contiguous(), blk_seq.to(torch.int32).contiguous(), blk_inseq.to(torch.int32).contiguous(), total_blocks
|
||||
|
||||
|
||||
def _triton_varlen_attn(q, k, v, meta):
|
||||
"""q,k,v: [1, H, S, Dh](contiguous)。meta=(cu, blk_seq, blk_inseq, total_blocks)。返回 [1,H,S,Dh]。"""
|
||||
_, H, S, Dh = q.shape
|
||||
cu, blk_seq, blk_inseq, total_blocks = meta
|
||||
BLOCK_M = 64
|
||||
out = torch.empty_like(q)
|
||||
qc = q.contiguous(); kc = k.contiguous(); vc = v.contiguous(); out = out.contiguous()
|
||||
stride_h, stride_s, stride_d = S * Dh, Dh, 1
|
||||
grid = (total_blocks, H)
|
||||
_varlen_flash_fwd[grid](
|
||||
qc, kc, vc, out, cu, blk_seq, blk_inseq,
|
||||
stride_h, stride_s, stride_d, 1.0 / math.sqrt(Dh), cu.numel() - 1,
|
||||
BLOCK_M=BLOCK_M, BLOCK_N=64, D=Dh,
|
||||
)
|
||||
return out
|
||||
|
||||
|
||||
# ============================================================
|
||||
# 实验配置开关板
|
||||
@@ -65,12 +157,17 @@ CONFIG = {
|
||||
|
||||
|
||||
def _resolve_attn(device):
|
||||
"""解析实际使用的注意力实现。flex 需 SM80+ 且可用,否则回退 sdpa。"""
|
||||
"""解析实际使用的注意力实现。triton/flex 需 CUDA(SM80+ for flex),否则回退 chunked/sdpa。"""
|
||||
attn = CONFIG.get("attn", "sdpa")
|
||||
is_cuda = device is not None and device.type == "cuda"
|
||||
if attn == "triton":
|
||||
if not (_HAS_TRITON and is_cuda):
|
||||
return "chunked" # Triton 不可用 → 回退已验证的 chunked
|
||||
return "triton"
|
||||
if attn == "flex":
|
||||
if not _HAS_FLEX:
|
||||
return "sdpa"
|
||||
if device is not None and device.type == "cuda":
|
||||
if is_cuda:
|
||||
try:
|
||||
if torch.cuda.get_device_capability(device)[0] < 8:
|
||||
return "sdpa"
|
||||
@@ -475,6 +572,9 @@ def scaled_dot_product(q, k, v, extension):
|
||||
- block_mask → FlexAttention 块对角因果。
|
||||
- mask(默认) → 标准 SDPA 稠密 mask(数学等价、已验证最快)。
|
||||
"""
|
||||
if extension is not None and extension.get("triton_meta") is not None:
|
||||
return _triton_varlen_attn(q, k, v, extension["triton_meta"])
|
||||
|
||||
if extension is not None and extension.get("chunks") is not None:
|
||||
outs = []
|
||||
for s0, s1, m in extension["chunks"]:
|
||||
@@ -734,7 +834,10 @@ class CTRModel(nn.Module):
|
||||
seq_input = self.rep_encoder(batch)
|
||||
user_offsets = batch["user_offsets"]
|
||||
attn = _resolve_attn(seq_input.device)
|
||||
if attn == "chunked":
|
||||
if attn == "triton":
|
||||
meta = _triton_block_meta(user_offsets, 64, seq_input.device)
|
||||
extension = {"triton_meta": meta}
|
||||
elif attn == "chunked":
|
||||
extension = {"chunks": self.build_chunks(user_offsets, seq_input.device)}
|
||||
elif attn == "varlen":
|
||||
extension = {"varlen_offsets": user_offsets}
|
||||
|
||||
Reference in New Issue
Block a user