feat: 分块SDPA注意力(--attn chunked),按用户边界切块降O(S²)
每块~chunk_users个用户、块内因果SDPA(评测端已验证、无嵌套开销),sum(块S²)远小于总S²。仅1次同步读切分边界。之前本地bs=16快13%被MoE同步吃掉,现MoE同步已消除,切块红利应全露出。CONFIG.attn=chunked/chunk_users;等价测试已加。 Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
@@ -64,6 +64,28 @@ def test_moe_dense_matches_loop():
|
||||
print(f"[PASS] MoE 稠密向量化 == 逐expert循环 (max err={err:.2e}, dev={dev})")
|
||||
|
||||
|
||||
def test_chunked_matches_dense_attention():
    """Check that chunked SDPA matches dense masked SDPA on random inputs.

    Builds a tiny ``infer.CTRModel`` (needed only for ``build_chunks``),
    draws random q/k/v for a packed batch of 7 users, and compares
    ``infer.scaled_dot_product`` run with a dense causal mask against the
    same call run with the chunk description returned by
    ``model.build_chunks``.

    ``infer.CONFIG["chunk_users"]`` is overridden to 3 for the duration of
    the test and restored afterwards, so tests that run later in the same
    process do not inherit the override.

    Raises:
        AssertionError: if the two outputs differ beyond atol/rtol of 1e-4.
    """
    dev = "cuda" if torch.cuda.is_available() else "cpu"
    rep = infer.RepEncoder(vocab_size=100, emb_dim=8, slot_num=28, d_model=8)
    seq = infer.TransformerEncoder(d_model=8, n_heads=2, num_layers=1, dim_ff=16)
    model = infer.CTRModel(rep, seq, d_model=8).to(dev)

    # Seed after model construction so q/k/v are reproducible regardless of
    # how many random numbers the model initialisation consumed.
    torch.manual_seed(0)
    H, Dh = 8, 64
    offs = _offsets([10, 25, 7, 40, 18, 5, 33], dev)  # 7 users
    # The last offset is used as the total packed sequence length.
    S = int(offs[-1])
    q = torch.randn(1, H, S, Dh, device=dev)
    k = torch.randn(1, H, S, Dh, device=dev)
    v = torch.randn(1, H, S, Dh, device=dev)

    # Remember the global setting so the override below cannot leak into
    # other tests (the __main__ runner executes more tests after this one).
    # NOTE(review): assumes infer.CONFIG is a plain dict — confirm.
    had_chunk_users = "chunk_users" in infer.CONFIG
    prev_chunk_users = infer.CONFIG.get("chunk_users")
    try:
        with torch.no_grad():
            dense = infer.scaled_dot_product(
                q, k, v, {"mask": _dense_causal_mask(offs)[None, None]}
            )
            infer.CONFIG["chunk_users"] = 3  # 3 users per chunk
            chunks = model.build_chunks(offs, torch.device(dev))
            chunked = infer.scaled_dot_product(q, k, v, {"chunks": chunks})
    finally:
        if had_chunk_users:
            infer.CONFIG["chunk_users"] = prev_chunk_users
        else:
            infer.CONFIG.pop("chunk_users", None)

    err = (dense - chunked).abs().max().item()
    assert torch.allclose(dense, chunked, atol=1e-4, rtol=1e-4), f"chunked 不等价 max err={err:.3e}"
    print(f"[PASS] chunked SDPA == 稠密SDPA (max err={err:.2e}, dev={dev})")
|
||||
|
||||
|
||||
def test_syncfree_mask_matches():
|
||||
dev = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
rep = infer.RepEncoder(vocab_size=100, emb_dim=8, slot_num=28, d_model=8)
|
||||
@@ -148,6 +170,7 @@ if __name__ == "__main__":
|
||||
test_moe_dense_matches_loop()
|
||||
test_fused_embedding_matches_perslot()
|
||||
test_syncfree_mask_matches()
|
||||
test_chunked_matches_dense_attention()
|
||||
test_varlen_matches_dense_attention()
|
||||
test_flex_matches_dense_attention()
|
||||
print("[DONE] 等价测试结束")
|
||||
|
||||
Reference in New Issue
Block a user