feat: sparse_pool 选项 — (段×唯一)稀疏矩阵乘做池化,避免materialize[M,emb]
针对 profile 的 dedup展开(15%)+segment_reduce(6.6%)。段内高重复(slot19)塌缩 为单个带权项。CONFIG.sparse_pool;bench --sparse-pool;等价测试已加。默认关,待验证。 Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
@@ -86,6 +86,32 @@ def test_chunked_matches_dense_attention():
|
||||
print(f"[PASS] chunked SDPA == 稠密SDPA (max err={err:.2e}, dev={dev})")
|
||||
|
||||
|
||||
def test_sparse_pool_matches():
|
||||
torch.manual_seed(0)
|
||||
dev = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
slot_num, emb_dim, d_model = 28, 512, 512
|
||||
enc = infer.RepEncoder(vocab_size=200, emb_dim=emb_dim, slot_num=slot_num,
|
||||
d_model=d_model).to(dev).eval()
|
||||
N = 6
|
||||
batch = {}
|
||||
for s in range(1, slot_num + 1):
|
||||
counts = torch.randint(0, 8, (N,))
|
||||
# 故意制造段内重复:值域很小,重复率高
|
||||
vals = torch.randint(0, 30, (int(counts.sum()),), device=dev)
|
||||
offs = torch.cat([torch.zeros(1, dtype=torch.long), counts.cumsum(0)]).to(dev)
|
||||
batch[s] = (vals, offs)
|
||||
with torch.no_grad():
|
||||
infer.CONFIG["sparse_pool"] = False
|
||||
infer.CONFIG["dedup_embedding"] = True
|
||||
ref = enc(batch)
|
||||
infer.CONFIG["sparse_pool"] = True
|
||||
new = enc(batch)
|
||||
infer.CONFIG["sparse_pool"] = False
|
||||
err = (ref - new).abs().max().item()
|
||||
assert torch.allclose(ref, new, atol=2e-2, rtol=2e-2), f"sparse_pool 不等价 max err={err:.3e}"
|
||||
print(f"[PASS] sparse_pool == segment_reduce (max err={err:.2e}, dev={dev})")
|
||||
|
||||
|
||||
def test_syncfree_mask_matches():
|
||||
dev = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
rep = infer.RepEncoder(vocab_size=100, emb_dim=8, slot_num=28, d_model=8)
|
||||
@@ -169,6 +195,7 @@ def test_flex_matches_dense_attention():
|
||||
if __name__ == "__main__":
|
||||
test_moe_dense_matches_loop()
|
||||
test_fused_embedding_matches_perslot()
|
||||
test_sparse_pool_matches()
|
||||
test_syncfree_mask_matches()
|
||||
test_chunked_matches_dense_attention()
|
||||
test_varlen_matches_dense_attention()
|
||||
|
||||
Reference in New Issue
Block a user