Files
CTI-Inference-Opt/代码/code/tests/test_equiv.py
T
OwnerSunshine530 cdc2dd490b feat: Triton varlen因果flash attention(块对角,单kernel,消逐块调用+mask构造开销)
每program处理(用户段query块,head),只遍历段内<=该块的key(因果),在线softmax,
fp16读写fp32累加。CONFIG.attn=triton(默认仍chunked);_triton_block_meta每batch算一次
block→段映射8层复用;_resolve_attn在无triton/CPU时回退chunked。等价测试+bench --attn triton。
数学等价(FlashAttention同类,规则允许),不改组网。

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
2026-06-17 00:14:53 +08:00

227 lines
9.6 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
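"""Phase B numerical-equivalence tests: new implementations vs. the originals. Run in a subprocess:
%cd /home/aistudio/code
!python tests/test_equiv.py
- Dense vectorized MoE vs. the original per-expert loop (CPU or GPU, FP32)
- Fused embedding vs. per-slot embedding, and sparse_pool vs. segment_reduce pooling (CPU or GPU)
- Sync-free (searchsorted) causal mask vs. the original repeat_interleave mask (CPU or GPU)
- Chunked SDPA vs. dense masked SDPA (CPU or GPU)
- Triton varlen causal flash attention vs. dense SDPA (needs CUDA + triton; otherwise skipped)
- Varlen (nested-tensor) SDPA vs. dense SDPA (needs CUDA; otherwise skipped)
- FlexAttention block-diagonal causal vs. dense SDPA (needs CUDA SM80+; otherwise skipped)
"""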
import os
import sys
# The baseline installs its dependencies into a ``pip --target`` directory, so
# any such library directory that exists must be put on sys.path *before*
# torch / infer are imported below. Already-present entries are not duplicated.
for _lib_dir in ("/home/aistudio/external-libraries", "/home/aistudio/libraries",
                 os.path.abspath("../libraries"), os.path.abspath("./libraries")):
    if _lib_dir in sys.path or not os.path.isdir(_lib_dir):
        continue
    sys.path.insert(0, _lib_dir)
# Put the parent of this tests/ directory first on sys.path so `import infer` resolves.
_project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
sys.path.insert(0, _project_root)
import torch
import torch.nn.functional as F
import infer
def _offsets(lengths, device):
offs = [0]
for L in lengths:
offs.append(offs[-1] + L)
return torch.tensor(offs, dtype=torch.long, device=device)
def _dense_causal_mask(offs):
"""同用户 + 因果(tril),与 CTRModel.get_sequence_causal_mask 语义一致。"""
lengths = (offs[1:] - offs[:-1]).view(-1)
idx = torch.repeat_interleave(
torch.arange(lengths.numel(), device=offs.device), lengths)
same = idx.view(1, -1) == idx.view(-1, 1)
causal = torch.tril(torch.ones_like(same, dtype=torch.bool))
return same & causal
def _block_mask(offs, S):
    """Build a FlexAttention block mask for block-diagonal causal attention.

    Encodes the same predicate as ``_dense_causal_mask``: query ``q`` may
    attend to key ``kv`` iff ``q >= kv`` and both lie in the same segment of
    ``offs``. The mask is constructed through ``infer.create_block_mask``,
    with no batch/head dependence (``B=None, H=None``).

    Args:
        offs: 1-D offsets tensor ``[0, l0, l0+l1, ...]``.
        S: total sequence length (used for both Q_LEN and KV_LEN).
    """
    seg_lens = (offs[1:] - offs[:-1]).view(-1)
    # seg_of[t] = segment ("document") index of token t; the mask_mod closure captures it.
    seg_of = torch.arange(seg_lens.numel(), device=offs.device).repeat_interleave(seg_lens)

    def same_doc_causal(b, h, q_idx, kv_idx):
        causal = q_idx >= kv_idx
        same_doc = seg_of[q_idx] == seg_of[kv_idx]
        return causal & same_doc

    return infer.create_block_mask(same_doc_causal, B=None, H=None,
                                   Q_LEN=S, KV_LEN=S, device=offs.device)
def test_moe_dense_matches_loop():
    """Vectorized dense MoE forward must match the per-expert loop reference.

    Runs in FP32 on CUDA when available, otherwise on CPU. The reference
    comes from ``infer._smoe_forward_loop``, computed before the
    vectorization flag is set. The candidate is ``moe(x)`` with
    ``infer.CONFIG["vectorize_moe"] = True``. The flag's previous state is
    restored afterwards (even on error) so later tests are unaffected.
    """
    torch.manual_seed(0)
    dev = "cuda" if torch.cuda.is_available() else "cpu"
    moe = infer.SMoE(d_model=512, dim_ff=1024, num_experts=8, k=2).to(dev).eval()
    x = torch.randn(1, 200, 512, device=dev)
    # Snapshot the global CONFIG keys this test mutates. Without the restore,
    # results depend on test order, and __main__ and pytest's definition order differ.
    keys = ("vectorize_moe",)
    saved = {key: infer.CONFIG[key] for key in keys if key in infer.CONFIG}
    try:
        with torch.no_grad():
            ref, _ = infer._smoe_forward_loop(moe, x)
            infer.CONFIG["vectorize_moe"] = True
            new, _ = moe(x)
    finally:
        for key in keys:
            if key in saved:
                infer.CONFIG[key] = saved[key]
            else:
                infer.CONFIG.pop(key, None)
    err = (ref - new).abs().max().item()
    assert torch.allclose(ref, new, atol=1e-4, rtol=1e-4), f"MoE 不等价 max err={err:.3e}"
    print(f"[PASS] MoE 稠密向量化 == 逐expert循环 (max err={err:.2e}, dev={dev})")
def test_chunked_matches_dense_attention():
    """Chunked block-diagonal causal SDPA must match dense masked SDPA.

    Uses 7 variable-length users. With ``infer.CONFIG["chunk_users"] = 3``,
    ``model.build_chunks`` groups them into chunks of 3 users. Runs in FP32
    on CUDA when available, otherwise on CPU. The ``chunk_users`` override
    is restored afterwards (even on error) so later tests are unaffected.
    """
    dev = "cuda" if torch.cuda.is_available() else "cpu"
    # Tiny model: in this test it is only used for its build_chunks() helper.
    rep = infer.RepEncoder(vocab_size=100, emb_dim=8, slot_num=28, d_model=8)
    seq = infer.TransformerEncoder(d_model=8, n_heads=2, num_layers=1, dim_ff=16)
    model = infer.CTRModel(rep, seq, d_model=8).to(dev)
    torch.manual_seed(0)
    H, Dh = 8, 64
    offs = _offsets([10, 25, 7, 40, 18, 5, 33], dev)  # 7 users
    S = int(offs[-1])
    q = torch.randn(1, H, S, Dh, device=dev)
    k = torch.randn(1, H, S, Dh, device=dev)
    v = torch.randn(1, H, S, Dh, device=dev)
    # Snapshot the global CONFIG keys this test mutates so it stays order-independent.
    keys = ("chunk_users",)
    saved = {key: infer.CONFIG[key] for key in keys if key in infer.CONFIG}
    try:
        with torch.no_grad():
            dense = infer.scaled_dot_product(q, k, v, {"mask": _dense_causal_mask(offs)[None, None]})
            infer.CONFIG["chunk_users"] = 3  # 3 users per chunk
            chunks = model.build_chunks(offs, torch.device(dev))
            chunked = infer.scaled_dot_product(q, k, v, {"chunks": chunks})
    finally:
        for key in keys:
            if key in saved:
                infer.CONFIG[key] = saved[key]
            else:
                infer.CONFIG.pop(key, None)
    err = (dense - chunked).abs().max().item()
    assert torch.allclose(dense, chunked, atol=1e-4, rtol=1e-4), f"chunked 不等价 max err={err:.3e}"
    print(f"[PASS] chunked SDPA == 稠密SDPA (max err={err:.2e}, dev={dev})")
def test_sparse_pool_matches():
    """``sparse_pool`` embedding pooling must match the default pooling path.

    Builds a 6-sample, 28-slot batch, mapping each slot to
    ``(values, offsets)``. Ids are drawn from a tiny range (0..29) so values
    repeat heavily within segments. ``RepEncoder`` output is then compared
    with ``CONFIG["sparse_pool"]`` off vs. on; ``dedup_embedding`` is on in
    both runs. Both CONFIG keys are restored to their previous state
    afterwards, even on error, instead of being left forced.
    """
    torch.manual_seed(0)
    dev = "cuda" if torch.cuda.is_available() else "cpu"
    slot_num, emb_dim, d_model = 28, 512, 512
    enc = infer.RepEncoder(vocab_size=200, emb_dim=emb_dim, slot_num=slot_num,
                           d_model=d_model).to(dev).eval()
    N = 6
    batch = {}
    for s in range(1, slot_num + 1):
        counts = torch.randint(0, 8, (N,))
        # Deliberately create in-segment duplicates: a tiny value range gives a high repeat rate.
        vals = torch.randint(0, 30, (int(counts.sum()),), device=dev)
        offs = torch.cat([torch.zeros(1, dtype=torch.long), counts.cumsum(0)]).to(dev)
        batch[s] = (vals, offs)
    # Snapshot the global CONFIG keys this test mutates so it stays order-independent.
    keys = ("sparse_pool", "dedup_embedding")
    saved = {key: infer.CONFIG[key] for key in keys if key in infer.CONFIG}
    try:
        with torch.no_grad():
            infer.CONFIG["sparse_pool"] = False
            infer.CONFIG["dedup_embedding"] = True
            ref = enc(batch)
            infer.CONFIG["sparse_pool"] = True
            new = enc(batch)
    finally:
        for key in keys:
            if key in saved:
                infer.CONFIG[key] = saved[key]
            else:
                infer.CONFIG.pop(key, None)
    err = (ref - new).abs().max().item()
    assert torch.allclose(ref, new, atol=2e-2, rtol=2e-2), f"sparse_pool 不等价 max err={err:.3e}"
    print(f"[PASS] sparse_pool == segment_reduce (max err={err:.2e}, dev={dev})")
def test_triton_varlen_matches_dense():
    """Triton varlen causal flash attention must match dense masked SDPA.

    Runs in fp16 on CUDA only. When CUDA or triton
    (``infer._HAS_TRITON``) is unavailable, it prints a skip notice and
    returns. The segment lengths cover a short segment, a single-token
    segment, segments of exactly 64 tokens, and segments longer than 64.
    """
    if not torch.cuda.is_available() or not infer._HAS_TRITON:
        print("[SKIP] Triton varlen 等价测试(需 CUDA + triton")
        return
    torch.manual_seed(0)
    dev = "cuda"
    n_heads, head_dim = 8, 64
    offs = _offsets([10, 64, 1, 130, 64, 200], dev)
    total = int(offs[-1])
    # q, k, v are drawn in this order, matching the RNG consumption of separate calls.
    q, k, v = (torch.randn(1, n_heads, total, head_dim, device=dev, dtype=torch.float16)
               for _ in range(3))
    block = 64  # presumably the kernel block size; TODO confirm against infer._triton_block_meta
    with torch.no_grad():
        dense = infer.scaled_dot_product(q, k, v, {"mask": _dense_causal_mask(offs)[None, None]})
        trit = infer.scaled_dot_product(
            q, k, v, {"triton_meta": infer._triton_block_meta(offs, block, q.device)})
    dense32, trit32 = dense.float(), trit.float()
    err = (dense32 - trit32).abs().max().item()
    assert torch.allclose(dense32, trit32, atol=3e-2, rtol=3e-2), \
        f"Triton varlen 不等价 max err={err:.3e}"
    print(f"[PASS] Triton varlen flash == 稠密SDPA (max err={err:.2e})")
def test_syncfree_mask_matches():
    """The sync-free causal mask must equal the original mask exactly.

    Compares ``model.get_sequence_causal_mask(offs)`` with
    ``model.causal_mask_syncfree(offs, S, device)`` on 4 variable-length
    users, using ``torch.equal`` (bit-exact). Runs on CUDA when available,
    otherwise on CPU.
    """
    dev = "cuda" if torch.cuda.is_available() else "cpu"
    model = infer.CTRModel(
        infer.RepEncoder(vocab_size=100, emb_dim=8, slot_num=28, d_model=8),
        infer.TransformerEncoder(d_model=8, n_heads=2, num_layers=1, dim_ff=16),
        d_model=8,
    ).to(dev)
    offs = torch.tensor([0, 10, 35, 42, 60], device=dev)  # 4 users, variable length
    total = int(offs[-1])
    reference = model.get_sequence_causal_mask(offs)
    syncfree = model.causal_mask_syncfree(offs, total, torch.device(dev))
    assert torch.equal(reference, syncfree), "sync-free mask 与原 mask 不一致"
    print(f"[PASS] searchsorted mask == repeat_interleave mask (dev={dev})")
def test_varlen_matches_dense_attention():
    """Varlen SDPA (``varlen_offsets`` path) must match dense masked SDPA.

    Runs in fp16 on CUDA only; without CUDA it prints a skip notice and
    returns. Comparison is done in fp32 with atol = rtol = 2e-2.
    """
    if not torch.cuda.is_available():
        print("[SKIP] varlen 等价测试(需 CUDA")
        return
    torch.manual_seed(0)
    dev = "cuda"
    n_heads, head_dim = 8, 64
    offs = _offsets([10, 25, 7, 40, 18], dev)
    total = int(offs[-1])
    # q, k, v are drawn in this order, matching the RNG consumption of separate calls.
    q, k, v = (torch.randn(1, n_heads, total, head_dim, device=dev, dtype=torch.float16)
               for _ in range(3))
    with torch.no_grad():
        dense = infer.scaled_dot_product(q, k, v, {"mask": _dense_causal_mask(offs)[None, None]})
        varlen = infer.scaled_dot_product(q, k, v, {"varlen_offsets": offs})
    dense32, varlen32 = dense.float(), varlen.float()
    err = (dense32 - varlen32).abs().max().item()
    assert torch.allclose(dense32, varlen32, atol=2e-2, rtol=2e-2), \
        f"varlen 不等价 max err={err:.3e}"
    print(f"[PASS] varlen(嵌套张量) == 稠密SDPA (max err={err:.2e})")
def test_fused_embedding_matches_perslot():
    """Fused embedding lookup must match the per-slot lookup.

    Builds a 6-sample, 28-slot batch, mapping each slot to
    ``(values, offsets)``. Each sample has 0-4 ids per slot, so empty slots
    are exercised. ``RepEncoder`` output is compared with
    ``CONFIG["fuse_embedding"]`` off vs. on. The flag's previous state is
    restored afterwards (even on error) so later tests are unaffected.
    """
    torch.manual_seed(0)
    dev = "cuda" if torch.cuda.is_available() else "cpu"
    slot_num, emb_dim, d_model = 28, 512, 512
    enc = infer.RepEncoder(vocab_size=10000, emb_dim=emb_dim, slot_num=slot_num,
                           d_model=d_model).to(dev).eval()
    # Batch of N=6 samples: each slot has 0-4 signs per sample (covers the empty-slot edge case).
    N = 6
    batch = {}
    for s in range(1, slot_num + 1):
        counts = torch.randint(0, 5, (N,))
        vals = torch.randint(0, 10000, (int(counts.sum()),), device=dev)
        offs = torch.cat([torch.zeros(1, dtype=torch.long), counts.cumsum(0)]).to(dev)
        batch[s] = (vals, offs)
    # Snapshot the global CONFIG keys this test mutates so it stays order-independent.
    keys = ("fuse_embedding",)
    saved = {key: infer.CONFIG[key] for key in keys if key in infer.CONFIG}
    try:
        with torch.no_grad():
            infer.CONFIG["fuse_embedding"] = False
            ref = enc(batch)
            infer.CONFIG["fuse_embedding"] = True
            new = enc(batch)
    finally:
        for key in keys:
            if key in saved:
                infer.CONFIG[key] = saved[key]
            else:
                infer.CONFIG.pop(key, None)
    err = (ref - new).abs().max().item()
    assert torch.allclose(ref, new, atol=1e-4, rtol=1e-4), f"embedding融合不等价 max err={err:.3e}"
    print(f"[PASS] embedding 融合 == 逐slot (max err={err:.2e}, dev={dev})")
def test_flex_matches_dense_attention():
    """FlexAttention with a block-diagonal causal mask must match dense SDPA.

    Requires CUDA, FlexAttention support (``infer._HAS_FLEX``) and compute
    capability major >= 8. Otherwise it prints a skip notice and returns.
    Inputs are FP32; tolerance is atol = rtol = 2e-2.
    """
    # Short-circuit: the capability query only runs once CUDA is known to be present.
    supported = (
        torch.cuda.is_available()
        and infer._HAS_FLEX
        and torch.cuda.get_device_capability()[0] >= 8
    )
    if not supported:
        print("[SKIP] FlexAttention 等价测试(需 CUDA SM80+")
        return
    torch.manual_seed(0)
    dev = "cuda"
    n_heads, head_dim = 8, 64
    offs = _offsets([10, 25, 7, 40, 18], dev)
    total = int(offs[-1])
    # q, k, v are drawn in this order, matching the RNG consumption of separate calls.
    q, k, v = (torch.randn(1, n_heads, total, head_dim, device=dev) for _ in range(3))
    with torch.no_grad():
        dense = infer.scaled_dot_product(q, k, v, {"mask": _dense_causal_mask(offs)[None, None]})
        flex = infer.scaled_dot_product(q, k, v, {"block_mask": _block_mask(offs, total)})
    err = (dense - flex).abs().max().item()
    assert torch.allclose(dense, flex, atol=2e-2, rtol=2e-2), f"Flex 不等价 max err={err:.3e}"
    print(f"[PASS] FlexAttention 块对角 == 稠密SDPA (max err={err:.2e})")
if __name__ == "__main__":
    # Script entry point: run every equivalence check in a fixed order, then report completion.
    # A failing check raises AssertionError and aborts the remaining ones.
    for _check in (
        test_moe_dense_matches_loop,
        test_fused_embedding_matches_perslot,
        test_sparse_pool_matches,
        test_syncfree_mask_matches,
        test_triton_varlen_matches_dense,
        test_chunked_matches_dense_attention,
        test_varlen_matches_dense_attention,
        test_flex_matches_dense_attention,
    ):
        _check()
    print("[DONE] 等价测试结束")