feat: collate段内去重+计数 → embedding_bag per_sample_weights(减查表带宽,数学等价)
collate(不计时)把段内重复sign折叠成(唯一,次数),embedding_bag用per_sample_weights=次数。 slot19等高重复段读量大降。攻最大块(embedding_bag 37%带宽)。走已验证的slot key通路(非新key)。 等价测试+bench --collate-dedup。默认关待验证。 Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
@@ -86,6 +86,36 @@ def test_chunked_matches_dense_attention():
|
||||
print(f"[PASS] chunked SDPA == 稠密SDPA (max err={err:.2e}, dev={dev})")
|
||||
|
||||
|
||||
def test_collate_dedup_matches():
|
||||
import numpy as _np
|
||||
torch.manual_seed(0)
|
||||
dev = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
enc = infer.RepEncoder(vocab_size=200, emb_dim=512, slot_num=28, d_model=512).to(dev).eval()
|
||||
N = 5
|
||||
plain, dedup = {}, {}
|
||||
for s in range(1, 29):
|
||||
seg_vals, offs_p = [], [0]
|
||||
u_vals, u_w, offs_d = [], [], [0]
|
||||
for _ in range(N):
|
||||
m = int(torch.randint(1, 8, (1,)))
|
||||
signs = torch.randint(0, 200, (m,)).tolist()
|
||||
signs = signs + signs[:max(0, m - 1)] # 制造段内重复
|
||||
seg_vals.extend(signs); offs_p.append(len(seg_vals))
|
||||
uq, ct = _np.unique(_np.asarray(signs), return_counts=True)
|
||||
u_vals.extend(uq.tolist()); u_w.extend(ct.tolist()); offs_d.append(len(u_vals))
|
||||
plain[s] = (torch.tensor(seg_vals, device=dev), torch.tensor(offs_p, device=dev))
|
||||
dedup[s] = (torch.tensor(u_vals, device=dev), torch.tensor(offs_d, device=dev),
|
||||
torch.tensor(u_w, dtype=torch.float32, device=dev))
|
||||
with torch.no_grad():
|
||||
infer.CONFIG["use_embedding_bag"] = True
|
||||
ref = enc(plain)
|
||||
new = enc(dedup)
|
||||
infer.CONFIG["use_embedding_bag"] = False
|
||||
err = (ref - new).abs().max().item()
|
||||
assert torch.allclose(ref, new, atol=1e-3, rtol=1e-3), f"collate_dedup 不等价 max err={err:.3e}"
|
||||
print(f"[PASS] collate_dedup(去重+计数) == 全展开 (max err={err:.2e}, dev={dev})")
|
||||
|
||||
|
||||
def test_embedding_bag_matches():
|
||||
torch.manual_seed(0)
|
||||
dev = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
@@ -262,6 +292,7 @@ if __name__ == "__main__":
|
||||
test_sparse_moe_matches_dense()
|
||||
test_fused_embedding_matches_perslot()
|
||||
test_embedding_bag_matches()
|
||||
test_collate_dedup_matches()
|
||||
test_sparse_pool_matches()
|
||||
test_syncfree_mask_matches()
|
||||
test_triton_varlen_matches_dense()
|
||||
|
||||
Reference in New Issue
Block a user