feat: Triton varlen因果flash attention(块对角,单kernel,消逐块调用+mask构造开销)

每program处理(用户段query块,head),只遍历段内<=该块的key(因果),在线softmax,
fp16读写fp32累加。CONFIG.attn=triton(默认仍chunked);_triton_block_meta每batch算一次
block→段映射8层复用;_resolve_attn在无triton/CPU时回退chunked。等价测试+bench --attn triton。
数学等价(FlashAttention同类,规则允许),不改组网。

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
OwnerSunshine530
2026-06-17 00:14:53 +08:00
parent a5ee660523
commit cdc2dd490b
3 changed files with 131 additions and 5 deletions
+2 -2
View File
@@ -315,8 +315,8 @@ def _parse_args():
help="逗号分隔的 keep_fp32_modules,如 linear,rep_encoder.input_norm")
ap.add_argument("--feasign-none", action="store_true",
help="不截断特征(max_feasign_per_slot=None")
ap.add_argument("--attn", choices=["sdpa", "chunked", "flex", "varlen"], default=None,
help="注意力:sdpa=稠密, chunked=按用户分块SDPA, flex/varlen=对照")
ap.add_argument("--attn", choices=["sdpa", "chunked", "triton", "flex", "varlen"], default=None,
help="注意力:sdpa=稠密, chunked=分块SDPA, triton=varlen flash kernel, flex/varlen=对照")
ap.add_argument("--chunk-users", type=int, default=None, help="chunked 模式每块用户数")
ap.add_argument("--moe", choices=["dense", "loop"], default=None,
help="MoE实现:dense=向量化(新), loop=逐expert循环(原)")
+106 -3
View File
@@ -26,6 +26,98 @@ except Exception:
create_block_mask = None
_HAS_FLEX = False
# Triton varlen 因果 flash attention(块对角,单 kernel,消除逐块调用/mask 构造开销)
try:
import triton
import triton.language as tl
_HAS_TRITON = True
except Exception:
triton = None
tl = None
_HAS_TRITON = False
if _HAS_TRITON:
@triton.jit
def _varlen_flash_fwd(
Q, K, V, Out,
cu_seqlens, blk_seq, blk_inseq,
stride_h, stride_s, stride_d,
scale, n_seq,
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, D: tl.constexpr,
):
pid = tl.program_id(0) # 全局 query 块
h = tl.program_id(1) # head
s = tl.load(blk_seq + pid)
bis = tl.load(blk_inseq + pid)
seq_start = tl.load(cu_seqlens + s)
seq_end = tl.load(cu_seqlens + s + 1)
q_row0 = seq_start + bis * BLOCK_M
offs_m = q_row0 + tl.arange(0, BLOCK_M) # query token 全局行号
offs_d = tl.arange(0, D)
q_mask = offs_m < seq_end
q_ptrs = Q + h * stride_h + offs_m[:, None] * stride_s + offs_d[None, :] * stride_d
q = tl.load(q_ptrs, mask=q_mask[:, None], other=0.0).to(tl.float32)
m_i = tl.full([BLOCK_M], -float("inf"), tl.float32)
l_i = tl.zeros([BLOCK_M], tl.float32)
acc = tl.zeros([BLOCK_M, D], tl.float32)
q_pos = offs_m - seq_start # query 段内位置
kv_end = q_row0 + BLOCK_M # 因果:key 不超过本 query 块末尾
for kn in range(seq_start, kv_end, BLOCK_N):
offs_n = kn + tl.arange(0, BLOCK_N)
k_mask = offs_n < seq_end
k_ptrs = K + h * stride_h + offs_n[:, None] * stride_s + offs_d[None, :] * stride_d
k = tl.load(k_ptrs, mask=k_mask[:, None], other=0.0).to(tl.float32)
qk = tl.dot(q, tl.trans(k)) * scale # [BLOCK_M, BLOCK_N]
k_pos = offs_n - seq_start
valid = (q_pos[:, None] >= k_pos[None, :]) & k_mask[None, :]
qk = tl.where(valid, qk, -float("inf"))
m_new = tl.maximum(m_i, tl.max(qk, 1))
p = tl.exp(qk - m_new[:, None])
alpha = tl.exp(m_i - m_new)
l_i = l_i * alpha + tl.sum(p, 1)
v_ptrs = V + h * stride_h + offs_n[:, None] * stride_s + offs_d[None, :] * stride_d
v = tl.load(v_ptrs, mask=k_mask[:, None], other=0.0).to(tl.float32)
acc = acc * alpha[:, None] + tl.dot(p, v)
m_i = m_new
acc = acc / l_i[:, None]
o_ptrs = Out + h * stride_h + offs_m[:, None] * stride_s + offs_d[None, :] * stride_d
tl.store(o_ptrs, acc.to(tl.float16), mask=q_mask[:, None])
def _triton_block_meta(user_offsets, BLOCK_M, device):
"""从 user_offsets 算 block→段映射(每 batch 一次、8 层复用;含 1 次同步读 total_blocks)。"""
cu = user_offsets.to(torch.int32)
seqlens = (cu[1:] - cu[:-1]).to(torch.int64)
blocks_per = (seqlens + BLOCK_M - 1) // BLOCK_M
n_seq = seqlens.numel()
blk_seq = torch.repeat_interleave(torch.arange(n_seq, device=device), blocks_per)
total_blocks = blk_seq.numel()
starts = torch.cumsum(blocks_per, 0) - blocks_per
blk_inseq = torch.arange(total_blocks, device=device) - starts[blk_seq]
return cu.contiguous(), blk_seq.to(torch.int32).contiguous(), blk_inseq.to(torch.int32).contiguous(), total_blocks
def _triton_varlen_attn(q, k, v, meta):
"""q,k,v: [1, H, S, Dh]contiguous)。meta=(cu, blk_seq, blk_inseq, total_blocks)。返回 [1,H,S,Dh]。"""
_, H, S, Dh = q.shape
cu, blk_seq, blk_inseq, total_blocks = meta
BLOCK_M = 64
out = torch.empty_like(q)
qc = q.contiguous(); kc = k.contiguous(); vc = v.contiguous(); out = out.contiguous()
stride_h, stride_s, stride_d = S * Dh, Dh, 1
grid = (total_blocks, H)
_varlen_flash_fwd[grid](
qc, kc, vc, out, cu, blk_seq, blk_inseq,
stride_h, stride_s, stride_d, 1.0 / math.sqrt(Dh), cu.numel() - 1,
BLOCK_M=BLOCK_M, BLOCK_N=64, D=Dh,
)
return out
# ============================================================
# 实验配置开关板
@@ -65,12 +157,17 @@ CONFIG = {
def _resolve_attn(device):
"""解析实际使用的注意力实现。flex 需 SM80+ 且可用,否则回退 sdpa。"""
"""解析实际使用的注意力实现。triton/flex 需 CUDA(SM80+ for flex),否则回退 chunked/sdpa。"""
attn = CONFIG.get("attn", "sdpa")
is_cuda = device is not None and device.type == "cuda"
if attn == "triton":
if not (_HAS_TRITON and is_cuda):
return "chunked" # Triton 不可用 → 回退已验证的 chunked
return "triton"
if attn == "flex":
if not _HAS_FLEX:
return "sdpa"
if device is not None and device.type == "cuda":
if is_cuda:
try:
if torch.cuda.get_device_capability(device)[0] < 8:
return "sdpa"
@@ -475,6 +572,9 @@ def scaled_dot_product(q, k, v, extension):
- block_mask → FlexAttention 块对角因果。
- mask(默认) → 标准 SDPA 稠密 mask(数学等价、已验证最快)。
"""
if extension is not None and extension.get("triton_meta") is not None:
return _triton_varlen_attn(q, k, v, extension["triton_meta"])
if extension is not None and extension.get("chunks") is not None:
outs = []
for s0, s1, m in extension["chunks"]:
@@ -734,7 +834,10 @@ class CTRModel(nn.Module):
seq_input = self.rep_encoder(batch)
user_offsets = batch["user_offsets"]
attn = _resolve_attn(seq_input.device)
if attn == "chunked":
if attn == "triton":
meta = _triton_block_meta(user_offsets, 64, seq_input.device)
extension = {"triton_meta": meta}
elif attn == "chunked":
extension = {"chunks": self.build_chunks(user_offsets, seq_input.device)}
elif attn == "varlen":
extension = {"varlen_offsets": user_offsets}
+23
View File
@@ -112,6 +112,28 @@ def test_sparse_pool_matches():
print(f"[PASS] sparse_pool == segment_reduce (max err={err:.2e}, dev={dev})")
def test_triton_varlen_matches_dense():
if not (torch.cuda.is_available() and infer._HAS_TRITON):
print("[SKIP] Triton varlen 等价测试(需 CUDA + triton")
return
torch.manual_seed(0)
dev = "cuda"
H, Dh = 8, 64
offs = _offsets([10, 64, 1, 130, 64, 200], dev) # 含跨多块/单token/正好整块的段
S = int(offs[-1])
q = torch.randn(1, H, S, Dh, device=dev, dtype=torch.float16)
k = torch.randn(1, H, S, Dh, device=dev, dtype=torch.float16)
v = torch.randn(1, H, S, Dh, device=dev, dtype=torch.float16)
with torch.no_grad():
dense = infer.scaled_dot_product(q, k, v, {"mask": _dense_causal_mask(offs)[None, None]})
meta = infer._triton_block_meta(offs, 64, q.device)
trit = infer.scaled_dot_product(q, k, v, {"triton_meta": meta})
err = (dense.float() - trit.float()).abs().max().item()
assert torch.allclose(dense.float(), trit.float(), atol=3e-2, rtol=3e-2), \
f"Triton varlen 不等价 max err={err:.3e}"
print(f"[PASS] Triton varlen flash == 稠密SDPA (max err={err:.2e})")
def test_syncfree_mask_matches():
dev = "cuda" if torch.cuda.is_available() else "cpu"
rep = infer.RepEncoder(vocab_size=100, emb_dim=8, slot_num=28, d_model=8)
@@ -197,6 +219,7 @@ if __name__ == "__main__":
test_fused_embedding_matches_perslot()
test_sparse_pool_matches()
test_syncfree_mask_matches()
test_triton_varlen_matches_dense()
test_chunked_matches_dense_attention()
test_varlen_matches_dense_attention()
test_flex_matches_dense_attention()