feat: Triton varlen因果flash attention(块对角,单kernel,消逐块调用+mask构造开销)
每program处理(用户段query块,head),只遍历段内<=该块的key(因果),在线softmax, fp16读写fp32累加。CONFIG.attn=triton(默认仍chunked);_triton_block_meta每batch算一次 block→段映射8层复用;_resolve_attn在无triton/CPU时回退chunked。等价测试+bench --attn triton。 数学等价(FlashAttention同类,规则允许),不改组网。 Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
@@ -112,6 +112,28 @@ def test_sparse_pool_matches():
|
||||
print(f"[PASS] sparse_pool == segment_reduce (max err={err:.2e}, dev={dev})")
|
||||
|
||||
|
||||
def test_triton_varlen_matches_dense():
|
||||
if not (torch.cuda.is_available() and infer._HAS_TRITON):
|
||||
print("[SKIP] Triton varlen 等价测试(需 CUDA + triton)")
|
||||
return
|
||||
torch.manual_seed(0)
|
||||
dev = "cuda"
|
||||
H, Dh = 8, 64
|
||||
offs = _offsets([10, 64, 1, 130, 64, 200], dev) # 含跨多块/单token/正好整块的段
|
||||
S = int(offs[-1])
|
||||
q = torch.randn(1, H, S, Dh, device=dev, dtype=torch.float16)
|
||||
k = torch.randn(1, H, S, Dh, device=dev, dtype=torch.float16)
|
||||
v = torch.randn(1, H, S, Dh, device=dev, dtype=torch.float16)
|
||||
with torch.no_grad():
|
||||
dense = infer.scaled_dot_product(q, k, v, {"mask": _dense_causal_mask(offs)[None, None]})
|
||||
meta = infer._triton_block_meta(offs, 64, q.device)
|
||||
trit = infer.scaled_dot_product(q, k, v, {"triton_meta": meta})
|
||||
err = (dense.float() - trit.float()).abs().max().item()
|
||||
assert torch.allclose(dense.float(), trit.float(), atol=3e-2, rtol=3e-2), \
|
||||
f"Triton varlen 不等价 max err={err:.3e}"
|
||||
print(f"[PASS] Triton varlen flash == 稠密SDPA (max err={err:.2e})")
|
||||
|
||||
|
||||
def test_syncfree_mask_matches():
|
||||
dev = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
rep = infer.RepEncoder(vocab_size=100, emb_dim=8, slot_num=28, d_model=8)
|
||||
@@ -197,6 +219,7 @@ if __name__ == "__main__":
|
||||
test_fused_embedding_matches_perslot()
|
||||
test_sparse_pool_matches()
|
||||
test_syncfree_mask_matches()
|
||||
test_triton_varlen_matches_dense()
|
||||
test_chunked_matches_dense_attention()
|
||||
test_varlen_matches_dense_attention()
|
||||
test_flex_matches_dense_attention()
|
||||
|
||||
Reference in New Issue
Block a user