feat: 分块SDPA注意力(--attn chunked),按用户边界切块降O(S²)

每块~chunk_users个用户、块内因果SDPA(评测端已验证、无嵌套开销),sum(块S²)
远小于总S²。仅1次同步读切分边界。之前本地bs=16快13%被MoE同步吃掉,现MoE
同步已消除,切块红利应全露出。CONFIG.attn=chunked/chunk_users;等价测试已加。

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
OwnerSunshine530
2026-06-15 13:13:13 +08:00
parent 1249bbdbbc
commit 3d28f61a98
3 changed files with 60 additions and 5 deletions
+23
View File
@@ -64,6 +64,28 @@ def test_moe_dense_matches_loop():
print(f"[PASS] MoE 稠密向量化 == 逐expert循环 (max err={err:.2e}, dev={dev})")
def test_chunked_matches_dense_attention():
dev = "cuda" if torch.cuda.is_available() else "cpu"
rep = infer.RepEncoder(vocab_size=100, emb_dim=8, slot_num=28, d_model=8)
seq = infer.TransformerEncoder(d_model=8, n_heads=2, num_layers=1, dim_ff=16)
model = infer.CTRModel(rep, seq, d_model=8).to(dev)
torch.manual_seed(0)
H, Dh = 8, 64
offs = _offsets([10, 25, 7, 40, 18, 5, 33], dev) # 7 个用户
S = int(offs[-1])
q = torch.randn(1, H, S, Dh, device=dev)
k = torch.randn(1, H, S, Dh, device=dev)
v = torch.randn(1, H, S, Dh, device=dev)
with torch.no_grad():
dense = infer.scaled_dot_product(q, k, v, {"mask": _dense_causal_mask(offs)[None, None]})
infer.CONFIG["chunk_users"] = 3 # 每块 3 个用户
chunks = model.build_chunks(offs, torch.device(dev))
chunked = infer.scaled_dot_product(q, k, v, {"chunks": chunks})
err = (dense - chunked).abs().max().item()
assert torch.allclose(dense, chunked, atol=1e-4, rtol=1e-4), f"chunked 不等价 max err={err:.3e}"
print(f"[PASS] chunked SDPA == 稠密SDPA (max err={err:.2e}, dev={dev})")
def test_syncfree_mask_matches():
dev = "cuda" if torch.cuda.is_available() else "cpu"
rep = infer.RepEncoder(vocab_size=100, emb_dim=8, slot_num=28, d_model=8)
@@ -148,6 +170,7 @@ if __name__ == "__main__":
test_moe_dense_matches_loop()
test_fused_embedding_matches_perslot()
test_syncfree_mask_matches()
test_chunked_matches_dense_attention()
test_varlen_matches_dense_attention()
test_flex_matches_dense_attention()
print("[DONE] 等价测试结束")