cc4acca875
collate(不计时)把段内重复sign折叠成(唯一,次数),embedding_bag用per_sample_weights=次数。 slot19等高重复段读量大降。攻最大块(embedding_bag 37%带宽)。走已验证的slot key通路(非新key)。 等价测试+bench --collate-dedup。默认关待验证。 Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
303 lines
13 KiB
Python
303 lines
13 KiB
Python
"""Phase B 数值等价测试:新实现 vs 原实现。子进程跑:
|
||
|
||
%cd /home/aistudio/code
|
||
!python tests/test_equiv.py
|
||
|
||
- MoE 稠密向量化 vs 原逐 expert 循环(CPU/GPU 都可,FP32)
|
||
- FlexAttention 块对角因果 vs 稠密 SDPA(需 CUDA SM80+,否则自动跳过)
|
||
"""
|
||
import os
|
||
import sys
|
||
|
||
# The baseline environment installs dependencies into --target directories,
# so extend sys.path before the third-party / project imports that follow.
_DEP_DIRS = (
    "/home/aistudio/external-libraries",
    "/home/aistudio/libraries",
    os.path.abspath("../libraries"),
    os.path.abspath("./libraries"),
)
for _p in _DEP_DIRS:
    # Skip directories that do not exist or are already importable.
    if not os.path.isdir(_p) or _p in sys.path:
        continue
    sys.path.insert(0, _p)

# Put the parent of this file's directory at the very front of sys.path.
_PARENT_DIR = os.path.join(os.path.dirname(__file__), "..")
sys.path.insert(0, os.path.abspath(_PARENT_DIR))
|
||
|
||
import torch
|
||
import torch.nn.functional as F
|
||
import infer
|
||
|
||
|
||
def _offsets(lengths, device):
|
||
offs = [0]
|
||
for L in lengths:
|
||
offs.append(offs[-1] + L)
|
||
return torch.tensor(offs, dtype=torch.long, device=device)
|
||
|
||
|
||
def _dense_causal_mask(offs):
|
||
"""同用户 + 因果(tril),与 CTRModel.get_sequence_causal_mask 语义一致。"""
|
||
lengths = (offs[1:] - offs[:-1]).view(-1)
|
||
idx = torch.repeat_interleave(
|
||
torch.arange(lengths.numel(), device=offs.device), lengths)
|
||
same = idx.view(1, -1) == idx.view(-1, 1)
|
||
causal = torch.tril(torch.ones_like(same, dtype=torch.bool))
|
||
return same & causal
|
||
|
||
|
||
def _block_mask(offs, S):
    """Create a block mask (via ``infer.create_block_mask``) for same-segment causal attention.

    Args:
        offs: cumulative segment boundaries (as produced by ``_offsets``).
        S: sequence length passed as both ``Q_LEN`` and ``KV_LEN``.
    """
    seg_lens = (offs[1:] - offs[:-1]).view(-1)
    seg_ids = torch.arange(seg_lens.numel(), device=offs.device)
    # doc_of_token[t] = index of the segment that token t belongs to.
    doc_of_token = torch.repeat_interleave(seg_ids, seg_lens)

    def mask_mod(b, h, q_idx, kv_idx):
        # Keep a (query, key) pair only if it is causal and within one segment.
        is_causal = q_idx >= kv_idx
        same_doc = doc_of_token[q_idx] == doc_of_token[kv_idx]
        return is_causal & same_doc

    return infer.create_block_mask(
        mask_mod, B=None, H=None, Q_LEN=S, KV_LEN=S, device=offs.device)
|
||
|
||
|
||
def test_moe_dense_matches_loop():
    """Check that the vectorized MoE forward matches ``infer._smoe_forward_loop``.

    Runs on CUDA when available, otherwise on CPU. The ``vectorize_moe`` flag
    is switched on only for the duration of the comparison and is put back to
    its previous state afterwards, so this test does not influence others.

    Raises:
        AssertionError: if the two outputs differ beyond atol/rtol of 1e-4.
    """
    torch.manual_seed(0)
    dev = "cuda" if torch.cuda.is_available() else "cpu"
    moe = infer.SMoE(d_model=512, dim_ff=1024, num_experts=8, k=2).to(dev).eval()
    x = torch.randn(1, 200, 512, device=dev)

    # Remember the flag's prior state (or its absence) so it can be restored.
    missing = object()
    prev = infer.CONFIG.get("vectorize_moe", missing)
    try:
        with torch.no_grad():
            ref, _ = infer._smoe_forward_loop(moe, x)
            infer.CONFIG["vectorize_moe"] = True
            new, _ = moe(x)
    finally:
        if prev is missing:
            infer.CONFIG.pop("vectorize_moe", None)
        else:
            infer.CONFIG["vectorize_moe"] = prev

    err = (ref - new).abs().max().item()
    assert torch.allclose(ref, new, atol=1e-4, rtol=1e-4), f"MoE 不等价 max err={err:.3e}"
    print(f"[PASS] MoE 稠密向量化 == 逐expert循环 (max err={err:.2e}, dev={dev})")
|
||
|
||
|
||
def test_chunked_matches_dense_attention():
    """Check that chunked attention matches dense masked attention.

    Both results come from ``infer.scaled_dot_product``: once with a dense
    mask from ``_dense_causal_mask`` and once with chunks from
    ``model.build_chunks``. ``chunk_users`` is set to 3 only while the chunks
    are built and is put back to its previous state afterwards.

    Raises:
        AssertionError: if the two outputs differ beyond atol/rtol of 1e-4.
    """
    dev = "cuda" if torch.cuda.is_available() else "cpu"
    # A small model is built only to get access to build_chunks.
    rep = infer.RepEncoder(vocab_size=100, emb_dim=8, slot_num=28, d_model=8)
    seq = infer.TransformerEncoder(d_model=8, n_heads=2, num_layers=1, dim_ff=16)
    model = infer.CTRModel(rep, seq, d_model=8).to(dev)
    # Seed after model construction, as in the original, so q/k/v are reproducible.
    torch.manual_seed(0)
    H, Dh = 8, 64
    offs = _offsets([10, 25, 7, 40, 18, 5, 33], dev)  # 7 users
    S = int(offs[-1])
    q = torch.randn(1, H, S, Dh, device=dev)
    k = torch.randn(1, H, S, Dh, device=dev)
    v = torch.randn(1, H, S, Dh, device=dev)

    # Remember the flag's prior state (or its absence) so it can be restored.
    missing = object()
    prev = infer.CONFIG.get("chunk_users", missing)
    try:
        with torch.no_grad():
            dense = infer.scaled_dot_product(
                q, k, v, {"mask": _dense_causal_mask(offs)[None, None]})
            infer.CONFIG["chunk_users"] = 3  # 3 users per chunk
            chunks = model.build_chunks(offs, torch.device(dev))
            chunked = infer.scaled_dot_product(q, k, v, {"chunks": chunks})
    finally:
        if prev is missing:
            infer.CONFIG.pop("chunk_users", None)
        else:
            infer.CONFIG["chunk_users"] = prev

    err = (dense - chunked).abs().max().item()
    assert torch.allclose(dense, chunked, atol=1e-4, rtol=1e-4), f"chunked 不等价 max err={err:.3e}"
    print(f"[PASS] chunked SDPA == 稠密SDPA (max err={err:.2e}, dev={dev})")
|
||
|
||
|
||
def test_collate_dedup_matches():
    """Check that a de-duplicated batch (unique signs + counts) matches the expanded batch.

    For each slot 1..28 and each of N samples, a random sign list with forced
    in-segment repeats is generated. The "plain" batch stores the full list as
    ``(values, offsets)``; the "dedup" batch stores ``(unique values, offsets,
    counts as float32 weights)``. Both are fed to the same ``RepEncoder`` with
    ``use_embedding_bag`` enabled and the outputs are compared.

    The flag is reset to False in a ``finally`` clause so a failing forward
    pass cannot leave it switched on for later tests.

    Raises:
        AssertionError: if the two outputs differ beyond atol/rtol of 1e-3.
    """
    import numpy as _np

    torch.manual_seed(0)
    dev = "cuda" if torch.cuda.is_available() else "cpu"
    enc = infer.RepEncoder(vocab_size=200, emb_dim=512, slot_num=28, d_model=512).to(dev).eval()
    N = 5
    plain, dedup = {}, {}
    for s in range(1, 29):
        seg_vals, offs_p = [], [0]
        u_vals, u_w, offs_d = [], [], [0]
        for _ in range(N):
            m = int(torch.randint(1, 8, (1,)))
            signs = torch.randint(0, 200, (m,)).tolist()
            signs = signs + signs[:max(0, m - 1)]  # force repeats within the segment
            seg_vals.extend(signs)
            offs_p.append(len(seg_vals))
            uq, ct = _np.unique(_np.asarray(signs), return_counts=True)
            u_vals.extend(uq.tolist())
            u_w.extend(ct.tolist())
            offs_d.append(len(u_vals))
        plain[s] = (torch.tensor(seg_vals, device=dev), torch.tensor(offs_p, device=dev))
        dedup[s] = (torch.tensor(u_vals, device=dev), torch.tensor(offs_d, device=dev),
                    torch.tensor(u_w, dtype=torch.float32, device=dev))

    try:
        with torch.no_grad():
            infer.CONFIG["use_embedding_bag"] = True
            ref = enc(plain)
            new = enc(dedup)
    finally:
        # Same end state as the original, but also reached when enc() raises.
        infer.CONFIG["use_embedding_bag"] = False

    err = (ref - new).abs().max().item()
    assert torch.allclose(ref, new, atol=1e-3, rtol=1e-3), f"collate_dedup 不等价 max err={err:.3e}"
    print(f"[PASS] collate_dedup(去重+计数) == 全展开 (max err={err:.2e}, dev={dev})")
|
||
|
||
|
||
def test_embedding_bag_matches():
    """Check that ``RepEncoder`` gives the same output with ``use_embedding_bag`` off and on.

    A random batch is built for every slot: N samples with 0..7 values each
    (so empty segments can occur), stored as ``(values, offsets)``. The flag
    is reset to False in a ``finally`` clause so a failing forward pass
    cannot leave it switched on for later tests.

    Raises:
        AssertionError: if the two outputs differ beyond atol/rtol of 1e-3.
    """
    torch.manual_seed(0)
    dev = "cuda" if torch.cuda.is_available() else "cpu"
    slot_num, emb_dim, d_model = 28, 512, 512
    enc = infer.RepEncoder(vocab_size=200, emb_dim=emb_dim, slot_num=slot_num,
                           d_model=d_model).to(dev).eval()
    N = 6
    batch = {}
    for s in range(1, slot_num + 1):
        counts = torch.randint(0, 8, (N,))
        vals = torch.randint(0, 200, (int(counts.sum()),), device=dev)
        offs = torch.cat([torch.zeros(1, dtype=torch.long), counts.cumsum(0)]).to(dev)
        batch[s] = (vals, offs)

    try:
        with torch.no_grad():
            infer.CONFIG["use_embedding_bag"] = False
            ref = enc(batch)
            infer.CONFIG["use_embedding_bag"] = True
            new = enc(batch)
    finally:
        # Same end state as the original, but also reached when enc() raises.
        infer.CONFIG["use_embedding_bag"] = False

    err = (ref - new).abs().max().item()
    assert torch.allclose(ref, new, atol=1e-3, rtol=1e-3), f"embedding_bag 不等价 max err={err:.3e}"
    print(f"[PASS] embedding_bag == segment_reduce (max err={err:.2e}, dev={dev})")
|
||
|
||
|
||
def test_sparse_pool_matches():
    """Check that ``RepEncoder`` gives the same output with ``sparse_pool`` off and on.

    Both runs use ``dedup_embedding = True``. Values are drawn from a small
    range (0..29) so repeats within a segment are frequent. Afterwards
    ``sparse_pool`` is reset to False (as the original did) and
    ``dedup_embedding`` is put back to its previous state, both in a
    ``finally`` clause so the cleanup also happens on failure.

    Raises:
        AssertionError: if the two outputs differ beyond atol/rtol of 2e-2.
    """
    torch.manual_seed(0)
    dev = "cuda" if torch.cuda.is_available() else "cpu"
    slot_num, emb_dim, d_model = 28, 512, 512
    enc = infer.RepEncoder(vocab_size=200, emb_dim=emb_dim, slot_num=slot_num,
                           d_model=d_model).to(dev).eval()
    N = 6
    batch = {}
    for s in range(1, slot_num + 1):
        counts = torch.randint(0, 8, (N,))
        # Deliberately small value range: high repeat rate inside each segment.
        vals = torch.randint(0, 30, (int(counts.sum()),), device=dev)
        offs = torch.cat([torch.zeros(1, dtype=torch.long), counts.cumsum(0)]).to(dev)
        batch[s] = (vals, offs)

    # Remember dedup_embedding's prior state (or its absence) for restoration.
    missing = object()
    prev_dedup = infer.CONFIG.get("dedup_embedding", missing)
    try:
        with torch.no_grad():
            infer.CONFIG["sparse_pool"] = False
            infer.CONFIG["dedup_embedding"] = True
            ref = enc(batch)
            infer.CONFIG["sparse_pool"] = True
            new = enc(batch)
    finally:
        infer.CONFIG["sparse_pool"] = False
        if prev_dedup is missing:
            infer.CONFIG.pop("dedup_embedding", None)
        else:
            infer.CONFIG["dedup_embedding"] = prev_dedup

    err = (ref - new).abs().max().item()
    assert torch.allclose(ref, new, atol=2e-2, rtol=2e-2), f"sparse_pool 不等价 max err={err:.3e}"
    print(f"[PASS] sparse_pool == segment_reduce (max err={err:.2e}, dev={dev})")
|
||
|
||
|
||
def test_triton_varlen_matches_dense():
    """Compare the Triton variable-length attention path with dense masked attention (fp16).

    Skipped (with a message) unless CUDA is available and ``infer._HAS_TRITON``
    is truthy. Segment lengths include a single-token segment, segments equal
    to the block size of 64 passed to ``infer._triton_block_meta``, and
    segments longer than it.
    """
    if not torch.cuda.is_available() or not infer._HAS_TRITON:
        print("[SKIP] Triton varlen 等价测试(需 CUDA + triton)")
        return

    torch.manual_seed(0)
    dev = "cuda"
    H, Dh = 8, 64
    offs = _offsets([10, 64, 1, 130, 64, 200], dev)
    S = int(offs[-1])
    # Drawn in q, k, v order so the random stream matches the original.
    q, k, v = (torch.randn(1, H, S, Dh, device=dev, dtype=torch.float16)
               for _ in range(3))

    with torch.no_grad():
        dense_mask = _dense_causal_mask(offs)[None, None]
        dense = infer.scaled_dot_product(q, k, v, {"mask": dense_mask})
        meta = infer._triton_block_meta(offs, 64, q.device, S)
        trit = infer.scaled_dot_product(q, k, v, {"triton_meta": meta})

    # Compare in fp32 to avoid fp16 rounding in the difference itself.
    dense_f, trit_f = dense.float(), trit.float()
    err = (dense_f - trit_f).abs().max().item()
    assert torch.allclose(dense_f, trit_f, atol=3e-2, rtol=3e-2), \
        f"Triton varlen 不等价 max err={err:.3e}"
    print(f"[PASS] Triton varlen flash == 稠密SDPA (max err={err:.2e})")
|
||
|
||
|
||
def test_syncfree_mask_matches():
    """Check that ``causal_mask_syncfree`` returns exactly the mask of ``get_sequence_causal_mask``.

    Uses four variable-length users (offsets 0, 10, 35, 42, 60) and requires
    the two masks to be equal element for element.
    """
    dev = "cuda" if torch.cuda.is_available() else "cpu"
    encoder = infer.RepEncoder(vocab_size=100, emb_dim=8, slot_num=28, d_model=8)
    transformer = infer.TransformerEncoder(d_model=8, n_heads=2, num_layers=1, dim_ff=16)
    model = infer.CTRModel(encoder, transformer, d_model=8).to(dev)

    offs = torch.tensor([0, 10, 35, 42, 60], device=dev)
    total_len = int(offs[-1])
    reference_mask = model.get_sequence_causal_mask(offs)
    syncfree_mask = model.causal_mask_syncfree(offs, total_len, torch.device(dev))

    assert torch.equal(reference_mask, syncfree_mask), "sync-free mask 与原 mask 不一致"
    print(f"[PASS] searchsorted mask == repeat_interleave mask (dev={dev})")
|
||
|
||
|
||
def test_varlen_matches_dense_attention():
    """Compare the ``varlen_offsets`` attention path with dense masked attention (fp16).

    Skipped (with a message) when CUDA is not available.
    """
    if not torch.cuda.is_available():
        print("[SKIP] varlen 等价测试(需 CUDA)")
        return

    torch.manual_seed(0)
    dev = "cuda"
    H, Dh = 8, 64
    offs = _offsets([10, 25, 7, 40, 18], dev)
    S = int(offs[-1])
    # Drawn in q, k, v order so the random stream matches the original.
    q, k, v = (torch.randn(1, H, S, Dh, device=dev, dtype=torch.float16)
               for _ in range(3))

    with torch.no_grad():
        dense_mask = _dense_causal_mask(offs)[None, None]
        dense = infer.scaled_dot_product(q, k, v, {"mask": dense_mask})
        varlen = infer.scaled_dot_product(q, k, v, {"varlen_offsets": offs})

    # Compare in fp32 to avoid fp16 rounding in the difference itself.
    dense_f, varlen_f = dense.float(), varlen.float()
    err = (dense_f - varlen_f).abs().max().item()
    assert torch.allclose(dense_f, varlen_f, atol=2e-2, rtol=2e-2), \
        f"varlen 不等价 max err={err:.3e}"
    print(f"[PASS] varlen(嵌套张量) == 稠密SDPA (max err={err:.2e})")
|
||
|
||
|
||
def test_sparse_moe_matches_dense():
    """Check that the sparse MoE path matches the non-sparse path when capacity is large.

    With a large capacity (8.0, intended to avoid dropping tokens) the sparse
    path should agree with the output produced when ``moe_sparse`` is False.
    ``moe_sparse`` and ``moe_capacity`` are reset to False / 1.25 (the values
    the original reset to) in a ``finally`` clause, so the cleanup also
    happens if a forward pass raises.

    Raises:
        AssertionError: if the two outputs differ beyond atol/rtol of 1e-3.
    """
    torch.manual_seed(0)
    dev = "cuda" if torch.cuda.is_available() else "cpu"
    m = infer.SMoE(d_model=512, dim_ff=1024, num_experts=8, k=2).to(dev).eval()
    x = torch.randn(1, 200, 512, device=dev)

    try:
        with torch.no_grad():
            infer.CONFIG["moe_sparse"] = False
            ref, _ = m(x)
            infer.CONFIG["moe_sparse"] = True
            infer.CONFIG["moe_capacity"] = 8.0  # large enough that no token is dropped
            new, _ = m(x)
    finally:
        # Same end state as the original, but also reached when m() raises.
        infer.CONFIG["moe_sparse"] = False
        infer.CONFIG["moe_capacity"] = 1.25

    err = (ref - new).abs().max().item()
    assert torch.allclose(ref, new, atol=1e-3, rtol=1e-3), f"sparse MoE 不等价 max err={err:.3e}"
    print(f"[PASS] sparse MoE(大capacity) == dense (max err={err:.2e}, dev={dev})")
|
||
|
||
|
||
def test_fused_embedding_matches_perslot():
    """Check that ``RepEncoder`` gives the same output with ``fuse_embedding`` off and on.

    The batch has N=6 samples with 0..4 signs per slot per sample, so empty
    segments are covered. ``fuse_embedding`` is put back to its previous
    state afterwards so that this test does not change the configuration
    seen by other tests.

    Raises:
        AssertionError: if the two outputs differ beyond atol/rtol of 1e-4.
    """
    torch.manual_seed(0)
    dev = "cuda" if torch.cuda.is_available() else "cpu"
    slot_num, emb_dim, d_model = 28, 512, 512
    enc = infer.RepEncoder(vocab_size=10000, emb_dim=emb_dim, slot_num=slot_num,
                           d_model=d_model).to(dev).eval()
    # Build a batch of N=6 samples: 0..4 signs per slot per sample (empty slots included).
    N = 6
    batch = {}
    for s in range(1, slot_num + 1):
        counts = torch.randint(0, 5, (N,))
        vals = torch.randint(0, 10000, (int(counts.sum()),), device=dev)
        offs = torch.cat([torch.zeros(1, dtype=torch.long), counts.cumsum(0)]).to(dev)
        batch[s] = (vals, offs)

    # Remember the flag's prior state (or its absence) so it can be restored.
    missing = object()
    prev = infer.CONFIG.get("fuse_embedding", missing)
    try:
        with torch.no_grad():
            infer.CONFIG["fuse_embedding"] = False
            ref = enc(batch)
            infer.CONFIG["fuse_embedding"] = True
            new = enc(batch)
    finally:
        if prev is missing:
            infer.CONFIG.pop("fuse_embedding", None)
        else:
            infer.CONFIG["fuse_embedding"] = prev

    err = (ref - new).abs().max().item()
    assert torch.allclose(ref, new, atol=1e-4, rtol=1e-4), f"embedding融合不等价 max err={err:.3e}"
    print(f"[PASS] embedding 融合 == 逐slot (max err={err:.2e}, dev={dev})")
|
||
|
||
|
||
def test_flex_matches_dense_attention():
    """Compare the ``block_mask`` (FlexAttention) path with dense masked attention.

    Skipped (with a message) unless CUDA is available, ``infer._HAS_FLEX`` is
    truthy, and the device's major compute capability is at least 8.
    """
    flex_usable = torch.cuda.is_available() and infer._HAS_FLEX
    if flex_usable:
        # Only query the device once CUDA is known to be present.
        flex_usable = torch.cuda.get_device_capability()[0] >= 8
    if not flex_usable:
        print("[SKIP] FlexAttention 等价测试(需 CUDA SM80+)")
        return

    torch.manual_seed(0)
    dev = "cuda"
    H, Dh = 8, 64
    offs = _offsets([10, 25, 7, 40, 18], dev)
    S = int(offs[-1])
    # Drawn in q, k, v order so the random stream matches the original.
    q, k, v = (torch.randn(1, H, S, Dh, device=dev) for _ in range(3))

    with torch.no_grad():
        dense_mask = _dense_causal_mask(offs)[None, None]
        dense = infer.scaled_dot_product(q, k, v, {"mask": dense_mask})
        flex = infer.scaled_dot_product(q, k, v, {"block_mask": _block_mask(offs, S)})

    err = (dense - flex).abs().max().item()
    assert torch.allclose(dense, flex, atol=2e-2, rtol=2e-2), f"Flex 不等价 max err={err:.3e}"
    print(f"[PASS] FlexAttention 块对角 == 稠密SDPA (max err={err:.2e})")
|
||
|
||
|
||
if __name__ == "__main__":
    # Run every equivalence test in a fixed order when executed as a script.
    for _case in (
        test_moe_dense_matches_loop,
        test_sparse_moe_matches_dense,
        test_fused_embedding_matches_perslot,
        test_embedding_bag_matches,
        test_collate_dedup_matches,
        test_sparse_pool_matches,
        test_syncfree_mask_matches,
        test_triton_varlen_matches_dense,
        test_chunked_matches_dense_attention,
        test_varlen_matches_dense_attention,
        test_flex_matches_dense_attention,
    ):
        _case()
    print("[DONE] 等价测试结束")
|