feat: 嵌套张量变长 flash 注意力(--attn varlen),统一 CONFIG.attn 分发
每用户当独立序列、is_causal 块对角因果,一个 flash 内核处理一 batch 内所有
用户,无稠密mask/无padding浪费/开销远低于FlexAttention。CONFIG.attn∈
{sdpa(默认),flex,varlen};bench --attn varlen;test_equiv 加 varlen 等价测试。
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
@@ -64,11 +64,32 @@ def test_moe_dense_matches_loop():
|
||||
print(f"[PASS] MoE 稠密向量化 == 逐expert循环 (max err={err:.2e}, dev={dev})")
|
||||
|
||||
|
||||
def test_varlen_matches_dense_attention():
|
||||
if not torch.cuda.is_available():
|
||||
print("[SKIP] varlen 等价测试(需 CUDA)")
|
||||
return
|
||||
torch.manual_seed(0)
|
||||
dev = "cuda"
|
||||
H, Dh = 8, 64
|
||||
offs = _offsets([10, 25, 7, 40, 18], dev)
|
||||
S = int(offs[-1])
|
||||
q = torch.randn(1, H, S, Dh, device=dev, dtype=torch.float16)
|
||||
k = torch.randn(1, H, S, Dh, device=dev, dtype=torch.float16)
|
||||
v = torch.randn(1, H, S, Dh, device=dev, dtype=torch.float16)
|
||||
with torch.no_grad():
|
||||
dense = infer.scaled_dot_product(q, k, v, {"mask": _dense_causal_mask(offs)[None, None]})
|
||||
varlen = infer.scaled_dot_product(q, k, v, {"varlen_offsets": offs})
|
||||
err = (dense.float() - varlen.float()).abs().max().item()
|
||||
assert torch.allclose(dense.float(), varlen.float(), atol=2e-2, rtol=2e-2), \
|
||||
f"varlen 不等价 max err={err:.3e}"
|
||||
print(f"[PASS] varlen(嵌套张量) == 稠密SDPA (max err={err:.2e})")
|
||||
|
||||
|
||||
def test_flex_matches_dense_attention():
|
||||
ok = (torch.cuda.is_available() and infer._HAS_FLEX
|
||||
and torch.cuda.get_device_capability()[0] >= 8)
|
||||
if not ok:
|
||||
print("[SKIP] FlexAttention 等价测试(需 CUDA SM80+,当前环境不满足)")
|
||||
print("[SKIP] FlexAttention 等价测试(需 CUDA SM80+)")
|
||||
return
|
||||
torch.manual_seed(0)
|
||||
dev = "cuda"
|
||||
@@ -88,5 +109,6 @@ def test_flex_matches_dense_attention():
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_moe_dense_matches_loop()
|
||||
test_varlen_matches_dense_attention()
|
||||
test_flex_matches_dense_attention()
|
||||
print("[DONE] 等价测试结束")
|
||||
|
||||
Reference in New Issue
Block a user