feat: 嵌套张量变长 flash 注意力(--attn varlen),统一 CONFIG.attn 分发

每个用户作为一条独立序列,is_causal 给出块对角因果注意力;单个 flash 内核处理一个 batch 内的
所有用户,无需稠密 mask、没有 padding 浪费,开销远低于 FlexAttention。CONFIG.attn ∈
{sdpa(默认), flex, varlen};bench 支持 --attn varlen;test_equiv 新增 varlen 等价测试。

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
OwnerSunshine530
2026-06-15 09:06:11 +08:00
parent 9eaf5f5511
commit 7791674a32
3 changed files with 74 additions and 28 deletions
+23 -1
View File
@@ -64,11 +64,32 @@ def test_moe_dense_matches_loop():
print(f"[PASS] MoE 稠密向量化 == 逐expert循环 (max err={err:.2e}, dev={dev})")
def test_varlen_matches_dense_attention():
    """Check that the varlen attention path agrees with the dense-mask path.

    Draws random fp16 q/k/v tensors on CUDA and calls
    ``infer.scaled_dot_product`` twice: once with a dense mask built by
    ``_dense_causal_mask`` and once with ``varlen_offsets``. The two outputs
    must match within ``atol=rtol=2e-2``. When CUDA is unavailable the check
    prints a skip notice and returns without running.
    """
    if not torch.cuda.is_available():
        print("[SKIP] varlen 等价测试(需 CUDA")
        return

    torch.manual_seed(0)
    device = "cuda"
    n_heads, head_dim = 8, 64
    # presumably per-sequence lengths turned into cumulative offsets —
    # TODO confirm against _offsets
    boundaries = _offsets([10, 25, 7, 40, 18], device)
    total_len = int(boundaries[-1])

    # Generator unpacking draws q, k, v in that order, so the RNG stream is
    # consumed exactly as with three consecutive randn calls.
    tensor_opts = {"device": device, "dtype": torch.float16}
    q, k, v = (
        torch.randn(1, n_heads, total_len, head_dim, **tensor_opts)
        for _ in range(3)
    )

    with torch.no_grad():
        dense_mask = _dense_causal_mask(boundaries)[None, None]
        reference = infer.scaled_dot_product(q, k, v, {"mask": dense_mask})
        candidate = infer.scaled_dot_product(
            q, k, v, {"varlen_offsets": boundaries}
        )

    # Compare in fp32 so the error metric itself is not limited by fp16.
    reference_f32 = reference.float()
    candidate_f32 = candidate.float()
    max_err = torch.max(torch.abs(reference_f32 - candidate_f32)).item()
    assert torch.allclose(reference_f32, candidate_f32, atol=2e-2, rtol=2e-2), \
        f"varlen 不等价 max err={max_err:.3e}"
    print(f"[PASS] varlen(嵌套张量) == 稠密SDPA (max err={max_err:.2e})")
def test_flex_matches_dense_attention():
ok = (torch.cuda.is_available() and infer._HAS_FLEX
and torch.cuda.get_device_capability()[0] >= 8)
if not ok:
print("[SKIP] FlexAttention 等价测试(需 CUDA SM80+,当前环境不满足")
print("[SKIP] FlexAttention 等价测试(需 CUDA SM80+")
return
torch.manual_seed(0)
dev = "cuda"
@@ -88,5 +109,6 @@ def test_flex_matches_dense_attention():
if __name__ == "__main__":
    # Run every equivalence check in order, then report completion.
    for run_check in (
        test_moe_dense_matches_loop,
        test_varlen_matches_dense_attention,
        test_flex_matches_dense_attention,
    ):
        run_check()
    print("[DONE] 等价测试结束")