Files
CTI-Inference-Opt/代码/code/tests/test_equiv.py
T
OwnerSunshine530 cc4acca875 feat: collate段内去重+计数 → embedding_bag per_sample_weights(减查表带宽,数学等价)
collate(不计时)把段内重复sign折叠成(唯一,次数),embedding_bag用per_sample_weights=次数。
slot19等高重复段读量大降。攻最大块(embedding_bag 37%带宽)。走已验证的slot key通路(非新key)。
等价测试+bench --collate-dedup。默认关待验证。

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
2026-06-20 14:46:48 +08:00

303 lines
13 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""Phase B 数值等价测试:新实现 vs 原实现。子进程跑:
%cd /home/aistudio/code
!python tests/test_equiv.py
- MoE 稠密向量化 vs 原逐 expert 循环(CPU/GPU 都可,FP32
- FlexAttention 块对角因果 vs 稠密 SDPA(需 CUDA SM80+,否则自动跳过)
"""
import os
import sys
# baseline 把依赖装在 --target 目录;import 前补 sys.path
for _p in ("/home/aistudio/external-libraries", "/home/aistudio/libraries",
os.path.abspath("../libraries"), os.path.abspath("./libraries")):
if os.path.isdir(_p) and _p not in sys.path:
sys.path.insert(0, _p)
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
import torch
import torch.nn.functional as F
import infer
def _offsets(lengths, device):
offs = [0]
for L in lengths:
offs.append(offs[-1] + L)
return torch.tensor(offs, dtype=torch.long, device=device)
def _dense_causal_mask(offs):
"""同用户 + 因果(tril),与 CTRModel.get_sequence_causal_mask 语义一致。"""
lengths = (offs[1:] - offs[:-1]).view(-1)
idx = torch.repeat_interleave(
torch.arange(lengths.numel(), device=offs.device), lengths)
same = idx.view(1, -1) == idx.view(-1, 1)
causal = torch.tril(torch.ones_like(same, dtype=torch.bool))
return same & causal
def _block_mask(offs, S):
    """Build a block mask for causal attention restricted to one segment.

    Args:
        offs: 1-D tensor of cumulative segment boundaries, starting at 0.
        S: total sequence length, passed as both Q_LEN and KV_LEN.

    Returns:
        Whatever ``infer.create_block_mask`` returns for the mask function
        "query index >= key index AND both indices are in the same segment".
    """
    seg_lengths = (offs[1:] - offs[:-1]).view(-1)
    seg_ids = torch.arange(seg_lengths.numel(), device=offs.device)
    # doc_id[t] is the index of the segment that position t belongs to.
    doc_id = torch.repeat_interleave(seg_ids, seg_lengths)

    def mask_mod(b, h, q_idx, kv_idx):
        # b and h are unused: the mask is the same for every batch and head.
        is_causal = q_idx >= kv_idx
        return is_causal & (doc_id[q_idx] == doc_id[kv_idx])

    return infer.create_block_mask(mask_mod, B=None, H=None, Q_LEN=S, KV_LEN=S,
                                   device=offs.device)
def test_moe_dense_matches_loop():
    """Check that the vectorized MoE forward matches the per-expert loop.

    The reference comes from ``infer._smoe_forward_loop``; the candidate is
    the module's own forward with ``CONFIG["vectorize_moe"]`` switched on.
    Both return a tuple whose first element is compared.

    Raises:
        AssertionError: if the outputs differ beyond atol=rtol=1e-4.
    """
    from unittest import mock  # local import: only used to snapshot CONFIG

    torch.manual_seed(0)
    dev = "cuda" if torch.cuda.is_available() else "cpu"
    moe = infer.SMoE(d_model=512, dim_ff=1024, num_experts=8, k=2).to(dev).eval()
    x = torch.randn(1, 200, 512, device=dev)
    # patch.dict with no overrides snapshots CONFIG and restores it on exit
    # (even on error), so the flag flipped below cannot leak into later tests.
    with mock.patch.dict(infer.CONFIG), torch.no_grad():
        ref, _ = infer._smoe_forward_loop(moe, x)
        infer.CONFIG["vectorize_moe"] = True
        new, _ = moe(x)
    err = (ref - new).abs().max().item()
    assert torch.allclose(ref, new, atol=1e-4, rtol=1e-4), f"MoE 不等价 max err={err:.3e}"
    print(f"[PASS] MoE 稠密向量化 == 逐expert循环 (max err={err:.2e}, dev={dev})")
def test_chunked_matches_dense_attention():
    """Check chunked attention against dense masked attention.

    A small ``CTRModel`` is built only to call ``build_chunks``; the random
    q/k/v tensors are fed to ``infer.scaled_dot_product`` once with a dense
    same-segment causal mask and once with the chunk description.

    Raises:
        AssertionError: if the outputs differ beyond atol=rtol=1e-4.
    """
    from unittest import mock  # local import: only used to snapshot CONFIG

    dev = "cuda" if torch.cuda.is_available() else "cpu"
    rep = infer.RepEncoder(vocab_size=100, emb_dim=8, slot_num=28, d_model=8)
    seq = infer.TransformerEncoder(d_model=8, n_heads=2, num_layers=1, dim_ff=16)
    model = infer.CTRModel(rep, seq, d_model=8).to(dev)
    # Seed after model construction, exactly as before, so q/k/v are unchanged.
    torch.manual_seed(0)
    H, Dh = 8, 64
    offs = _offsets([10, 25, 7, 40, 18, 5, 33], dev)  # 7 users
    S = int(offs[-1])
    q = torch.randn(1, H, S, Dh, device=dev)
    k = torch.randn(1, H, S, Dh, device=dev)
    v = torch.randn(1, H, S, Dh, device=dev)
    # Snapshot CONFIG so that chunk_users=3 does not stay set after this test.
    with mock.patch.dict(infer.CONFIG), torch.no_grad():
        dense = infer.scaled_dot_product(q, k, v, {"mask": _dense_causal_mask(offs)[None, None]})
        infer.CONFIG["chunk_users"] = 3  # 3 users per chunk
        chunks = model.build_chunks(offs, torch.device(dev))
        chunked = infer.scaled_dot_product(q, k, v, {"chunks": chunks})
    err = (dense - chunked).abs().max().item()
    assert torch.allclose(dense, chunked, atol=1e-4, rtol=1e-4), f"chunked 不等价 max err={err:.3e}"
    print(f"[PASS] chunked SDPA == 稠密SDPA (max err={err:.2e}, dev={dev})")
def test_collate_dedup_matches():
    """Check that a deduplicated (unique, count) batch matches the expanded one.

    For each of the 28 slots and each of N samples, a sign list with forced
    in-segment repeats is generated. The plain batch stores
    ``(values, offsets)``; the dedup batch stores
    ``(unique values, offsets, float32 counts)``. Both go through the same
    encoder with ``CONFIG["use_embedding_bag"]`` enabled.

    Raises:
        AssertionError: if the outputs differ beyond atol=rtol=1e-3.
    """
    import numpy as _np
    from unittest import mock  # local import: only used to snapshot CONFIG

    torch.manual_seed(0)
    dev = "cuda" if torch.cuda.is_available() else "cpu"
    enc = infer.RepEncoder(vocab_size=200, emb_dim=512, slot_num=28, d_model=512).to(dev).eval()
    N = 5
    plain, dedup = {}, {}
    for s in range(1, 29):
        seg_vals, offs_p = [], [0]
        u_vals, u_w, offs_d = [], [], [0]
        for _ in range(N):
            m = int(torch.randint(1, 8, (1,)))
            signs = torch.randint(0, 200, (m,)).tolist()
            signs = signs + signs[:max(0, m - 1)]  # force repeats inside the segment
            seg_vals.extend(signs)
            offs_p.append(len(seg_vals))
            uq, ct = _np.unique(_np.asarray(signs), return_counts=True)
            u_vals.extend(uq.tolist())
            u_w.extend(ct.tolist())
            offs_d.append(len(u_vals))
        plain[s] = (torch.tensor(seg_vals, device=dev), torch.tensor(offs_p, device=dev))
        dedup[s] = (torch.tensor(u_vals, device=dev), torch.tensor(offs_d, device=dev),
                    torch.tensor(u_w, dtype=torch.float32, device=dev))
    # Snapshot CONFIG: the previous value of use_embedding_bag is restored on
    # exit even if enc(...) raises, instead of being forced to False.
    with mock.patch.dict(infer.CONFIG), torch.no_grad():
        infer.CONFIG["use_embedding_bag"] = True
        ref = enc(plain)
        new = enc(dedup)
    err = (ref - new).abs().max().item()
    assert torch.allclose(ref, new, atol=1e-3, rtol=1e-3), f"collate_dedup 不等价 max err={err:.3e}"
    print(f"[PASS] collate_dedup(去重+计数) == 全展开 (max err={err:.2e}, dev={dev})")
def test_embedding_bag_matches():
    """Check the encoder output with ``use_embedding_bag`` off vs. on.

    The batch maps each slot id (1..slot_num) to ``(values, offsets)`` for
    N samples with 0-7 values each, so empty segments are included.

    Raises:
        AssertionError: if the outputs differ beyond atol=rtol=1e-3.
    """
    from unittest import mock  # local import: only used to snapshot CONFIG

    torch.manual_seed(0)
    dev = "cuda" if torch.cuda.is_available() else "cpu"
    slot_num, emb_dim, d_model = 28, 512, 512
    enc = infer.RepEncoder(vocab_size=200, emb_dim=emb_dim, slot_num=slot_num,
                           d_model=d_model).to(dev).eval()
    N = 6
    batch = {}
    for s in range(1, slot_num + 1):
        counts = torch.randint(0, 8, (N,))
        vals = torch.randint(0, 200, (int(counts.sum()),), device=dev)
        offs = torch.cat([torch.zeros(1, dtype=torch.long), counts.cumsum(0)]).to(dev)
        batch[s] = (vals, offs)
    # Snapshot CONFIG: the flag's previous value is restored on exit even if
    # enc(...) raises, instead of being forced to False afterwards.
    with mock.patch.dict(infer.CONFIG), torch.no_grad():
        infer.CONFIG["use_embedding_bag"] = False
        ref = enc(batch)
        infer.CONFIG["use_embedding_bag"] = True
        new = enc(batch)
    err = (ref - new).abs().max().item()
    assert torch.allclose(ref, new, atol=1e-3, rtol=1e-3), f"embedding_bag 不等价 max err={err:.3e}"
    print(f"[PASS] embedding_bag == segment_reduce (max err={err:.2e}, dev={dev})")
def test_sparse_pool_matches():
    """Check the encoder output with ``sparse_pool`` off vs. on.

    ``dedup_embedding`` is enabled for both runs. Values are drawn from a
    small range (0..29) so that repeats inside a segment are frequent.

    Raises:
        AssertionError: if the outputs differ beyond atol=rtol=2e-2.
    """
    from unittest import mock  # local import: only used to snapshot CONFIG

    torch.manual_seed(0)
    dev = "cuda" if torch.cuda.is_available() else "cpu"
    slot_num, emb_dim, d_model = 28, 512, 512
    enc = infer.RepEncoder(vocab_size=200, emb_dim=emb_dim, slot_num=slot_num,
                           d_model=d_model).to(dev).eval()
    N = 6
    batch = {}
    for s in range(1, slot_num + 1):
        counts = torch.randint(0, 8, (N,))
        # Deliberately small value range: high repeat rate inside each segment.
        vals = torch.randint(0, 30, (int(counts.sum()),), device=dev)
        offs = torch.cat([torch.zeros(1, dtype=torch.long), counts.cumsum(0)]).to(dev)
        batch[s] = (vals, offs)
    # Snapshot CONFIG so that neither sparse_pool nor dedup_embedding keeps the
    # value set here once the test is over (also when enc(...) raises).
    with mock.patch.dict(infer.CONFIG), torch.no_grad():
        infer.CONFIG["sparse_pool"] = False
        infer.CONFIG["dedup_embedding"] = True
        ref = enc(batch)
        infer.CONFIG["sparse_pool"] = True
        new = enc(batch)
    err = (ref - new).abs().max().item()
    assert torch.allclose(ref, new, atol=2e-2, rtol=2e-2), f"sparse_pool 不等价 max err={err:.3e}"
    print(f"[PASS] sparse_pool == segment_reduce (max err={err:.2e}, dev={dev})")
def test_triton_varlen_matches_dense():
    """Check the ``triton_meta`` attention path against dense masked attention.

    Prints a skip message and returns when CUDA is unavailable or
    ``infer._HAS_TRITON`` is falsy. Inputs are float16; the comparison is done
    in float32.

    Raises:
        AssertionError: if the outputs differ beyond atol=rtol=3e-2.
    """
    triton_ready = torch.cuda.is_available() and infer._HAS_TRITON
    if not triton_ready:
        print("[SKIP] Triton varlen 等价测试(需 CUDA + triton")
        return

    torch.manual_seed(0)
    dev = "cuda"
    H, Dh = 8, 64
    # Segments that span several blocks, hold a single token, or fill a block
    # exactly (the 64 passed to _triton_block_meta is presumably the block size).
    offs = _offsets([10, 64, 1, 130, 64, 200], dev)
    S = int(offs[-1])
    # The generator is consumed in order, so q, k, v get the same draws as
    # three consecutive randn calls.
    q, k, v = (torch.randn(1, H, S, Dh, device=dev, dtype=torch.float16)
               for _ in range(3))
    with torch.no_grad():
        dense_mask = _dense_causal_mask(offs)[None, None]
        dense = infer.scaled_dot_product(q, k, v, {"mask": dense_mask})
        meta = infer._triton_block_meta(offs, 64, q.device, S)
        trit = infer.scaled_dot_product(q, k, v, {"triton_meta": meta})
    dense32, trit32 = dense.float(), trit.float()
    err = (dense32 - trit32).abs().max().item()
    assert torch.allclose(dense32, trit32, atol=3e-2, rtol=3e-2), \
        f"Triton varlen 不等价 max err={err:.3e}"
    print(f"[PASS] Triton varlen flash == 稠密SDPA (max err={err:.2e})")
def test_syncfree_mask_matches():
    """Check that ``causal_mask_syncfree`` equals ``get_sequence_causal_mask``.

    A tiny ``CTRModel`` is built only to call the two mask builders on the
    same offsets; the results must be exactly equal.

    Raises:
        AssertionError: if the two masks are not identical.
    """
    dev = "cuda" if torch.cuda.is_available() else "cpu"
    model = infer.CTRModel(
        infer.RepEncoder(vocab_size=100, emb_dim=8, slot_num=28, d_model=8),
        infer.TransformerEncoder(d_model=8, n_heads=2, num_layers=1, dim_ff=16),
        d_model=8,
    ).to(dev)
    offs = torch.tensor([0, 10, 35, 42, 60], device=dev)  # 4 users, variable length
    total_len = int(offs[-1])
    reference_mask = model.get_sequence_causal_mask(offs)
    syncfree_mask = model.causal_mask_syncfree(offs, total_len, torch.device(dev))
    assert torch.equal(reference_mask, syncfree_mask), "sync-free mask 与原 mask 不一致"
    print(f"[PASS] searchsorted mask == repeat_interleave mask (dev={dev})")
def test_varlen_matches_dense_attention():
    """Check the ``varlen_offsets`` attention path against dense masked attention.

    Prints a skip message and returns when CUDA is unavailable. Inputs are
    float16; the comparison is done in float32.

    Raises:
        AssertionError: if the outputs differ beyond atol=rtol=2e-2.
    """
    if not torch.cuda.is_available():
        print("[SKIP] varlen 等价测试(需 CUDA")
        return

    torch.manual_seed(0)
    dev = "cuda"
    H, Dh = 8, 64
    offs = _offsets([10, 25, 7, 40, 18], dev)
    S = int(offs[-1])
    # The generator is consumed in order, so q, k, v get the same draws as
    # three consecutive randn calls.
    q, k, v = (torch.randn(1, H, S, Dh, device=dev, dtype=torch.float16)
               for _ in range(3))
    with torch.no_grad():
        dense_mask = _dense_causal_mask(offs)[None, None]
        dense = infer.scaled_dot_product(q, k, v, {"mask": dense_mask})
        varlen = infer.scaled_dot_product(q, k, v, {"varlen_offsets": offs})
    dense32, varlen32 = dense.float(), varlen.float()
    err = (dense32 - varlen32).abs().max().item()
    assert torch.allclose(dense32, varlen32, atol=2e-2, rtol=2e-2), \
        f"varlen 不等价 max err={err:.3e}"
    print(f"[PASS] varlen(嵌套张量) == 稠密SDPA (max err={err:.2e})")
def test_sparse_moe_matches_dense():
    """Check the sparse MoE path against the dense path at a large capacity.

    With a capacity large enough that no token is dropped, the sparse MoE is
    expected to be mathematically equivalent to the dense one. The module's
    forward returns a tuple whose first element is compared.

    Raises:
        AssertionError: if the outputs differ beyond atol=rtol=1e-3.
    """
    from unittest import mock  # local import: only used to snapshot CONFIG

    torch.manual_seed(0)
    dev = "cuda" if torch.cuda.is_available() else "cpu"
    m = infer.SMoE(d_model=512, dim_ff=1024, num_experts=8, k=2).to(dev).eval()
    x = torch.randn(1, 200, 512, device=dev)
    # Snapshot CONFIG: moe_sparse and moe_capacity get their real previous
    # values back on exit (also on error) rather than hard-coded guesses.
    with mock.patch.dict(infer.CONFIG), torch.no_grad():
        infer.CONFIG["moe_sparse"] = False
        ref, _ = m(x)
        infer.CONFIG["moe_sparse"] = True
        infer.CONFIG["moe_capacity"] = 8.0  # large enough that no token is dropped
        new, _ = m(x)
    err = (ref - new).abs().max().item()
    assert torch.allclose(ref, new, atol=1e-3, rtol=1e-3), f"sparse MoE 不等价 max err={err:.3e}"
    print(f"[PASS] sparse MoE(大capacity) == dense (max err={err:.2e}, dev={dev})")
def test_fused_embedding_matches_perslot():
    """Check the encoder output with ``fuse_embedding`` off vs. on.

    The batch holds N=6 samples; every slot has 0-4 signs per sample, so the
    empty-slot boundary case is covered.

    Raises:
        AssertionError: if the outputs differ beyond atol=rtol=1e-4.
    """
    from unittest import mock  # local import: only used to snapshot CONFIG

    torch.manual_seed(0)
    dev = "cuda" if torch.cuda.is_available() else "cpu"
    slot_num, emb_dim, d_model = 28, 512, 512
    enc = infer.RepEncoder(vocab_size=10000, emb_dim=emb_dim, slot_num=slot_num,
                           d_model=d_model).to(dev).eval()
    N = 6
    batch = {}
    for s in range(1, slot_num + 1):
        counts = torch.randint(0, 5, (N,))
        vals = torch.randint(0, 10000, (int(counts.sum()),), device=dev)
        offs = torch.cat([torch.zeros(1, dtype=torch.long), counts.cumsum(0)]).to(dev)
        batch[s] = (vals, offs)
    # Snapshot CONFIG so fuse_embedding=True does not stay set after this
    # test. NOTE(review): tests that ran after this one used to inherit
    # fuse_embedding=True; confirm none of them relied on that.
    with mock.patch.dict(infer.CONFIG), torch.no_grad():
        infer.CONFIG["fuse_embedding"] = False
        ref = enc(batch)
        infer.CONFIG["fuse_embedding"] = True
        new = enc(batch)
    err = (ref - new).abs().max().item()
    assert torch.allclose(ref, new, atol=1e-4, rtol=1e-4), f"embedding融合不等价 max err={err:.3e}"
    print(f"[PASS] embedding 融合 == 逐slot (max err={err:.2e}, dev={dev})")
def test_flex_matches_dense_attention():
    """Check the ``block_mask`` attention path against dense masked attention.

    Prints a skip message and returns unless CUDA is available,
    ``infer._HAS_FLEX`` is truthy and the device's major compute capability
    is at least 8.

    Raises:
        AssertionError: if the outputs differ beyond atol=rtol=2e-2.
    """
    flex_ready = torch.cuda.is_available() and infer._HAS_FLEX
    if flex_ready:
        # Only query the capability once CUDA is known to be usable.
        flex_ready = torch.cuda.get_device_capability()[0] >= 8
    if not flex_ready:
        print("[SKIP] FlexAttention 等价测试(需 CUDA SM80+")
        return

    torch.manual_seed(0)
    dev = "cuda"
    H, Dh = 8, 64
    offs = _offsets([10, 25, 7, 40, 18], dev)
    S = int(offs[-1])
    # The generator is consumed in order, so q, k, v get the same draws as
    # three consecutive randn calls.
    q, k, v = (torch.randn(1, H, S, Dh, device=dev) for _ in range(3))
    with torch.no_grad():
        dense_mask = _dense_causal_mask(offs)[None, None]
        dense = infer.scaled_dot_product(q, k, v, {"mask": dense_mask})
        flex_mask = _block_mask(offs, S)
        flex = infer.scaled_dot_product(q, k, v, {"block_mask": flex_mask})
    err = (dense - flex).abs().max().item()
    assert torch.allclose(dense, flex, atol=2e-2, rtol=2e-2), f"Flex 不等价 max err={err:.3e}"
    print(f"[PASS] FlexAttention 块对角 == 稠密SDPA (max err={err:.2e})")
if __name__ == "__main__":
    # Run every equivalence check in a fixed order when executed as a script;
    # a failing assertion stops the run at that check.
    for _check in (
        test_moe_dense_matches_loop,
        test_sparse_moe_matches_dense,
        test_fused_embedding_matches_perslot,
        test_embedding_bag_matches,
        test_collate_dedup_matches,
        test_sparse_pool_matches,
        test_syncfree_mask_matches,
        test_triton_varlen_matches_dense,
        test_chunked_matches_dense_attention,
        test_varlen_matches_dense_attention,
        test_flex_matches_dense_attention,
    ):
        _check()
    print("[DONE] 等价测试结束")