cdc2dd490b
每program处理(用户段query块,head),只遍历段内<=该块的key(因果),在线softmax, fp16读写fp32累加。CONFIG.attn=triton(默认仍chunked);_triton_block_meta每batch算一次 block→段映射8层复用;_resolve_attn在无triton/CPU时回退chunked。等价测试+bench --attn triton。 数学等价(FlashAttention同类,规则允许),不改组网。 Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
227 lines
9.6 KiB
Python
227 lines
9.6 KiB
Python
"""Phase B 数值等价测试:新实现 vs 原实现。子进程跑:
|
||
|
||
%cd /home/aistudio/code
|
||
!python tests/test_equiv.py
|
||
|
||
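- MoE dense vectorization vs. the original per-expert loop (CPU/GPU, FP32)
- Fused embedding vs. per-slot embedding; sparse_pool vs. segment_reduce (CPU/GPU)
- searchsorted (sync-free) causal mask vs. repeat_interleave mask (CPU/GPU)
- Chunked SDPA vs. dense SDPA (CPU/GPU)
- Nested-tensor varlen SDPA vs. dense SDPA (needs CUDA; skipped otherwise)
- Triton varlen flash attention vs. dense SDPA (needs CUDA + triton; skipped otherwise)
- FlexAttention block-diagonal causal vs. dense SDPA (needs CUDA SM80+; skipped otherwise)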
|
||
"""
|
||
import os
|
||
import sys
|
||
|
||
# The baseline installs dependencies into `--target` directories; any of these
# that exist are prepended to sys.path before the third-party imports below.
for _p in (
    "/home/aistudio/external-libraries",
    "/home/aistudio/libraries",
    os.path.abspath("../libraries"),  # relative to the current working directory
    os.path.abspath("./libraries"),
):
    if _p in sys.path or not os.path.isdir(_p):
        continue
    sys.path.insert(0, _p)
# Put the parent of this file's directory first on sys.path, presumably so that
# `import infer` resolves to the repository's own module (verify layout).
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
|
||
|
||
import torch
|
||
import torch.nn.functional as F
|
||
import infer
|
||
|
||
|
||
def _offsets(lengths, device):
|
||
offs = [0]
|
||
for L in lengths:
|
||
offs.append(offs[-1] + L)
|
||
return torch.tensor(offs, dtype=torch.long, device=device)
|
||
|
||
|
||
def _dense_causal_mask(offs):
|
||
"""同用户 + 因果(tril),与 CTRModel.get_sequence_causal_mask 语义一致。"""
|
||
lengths = (offs[1:] - offs[:-1]).view(-1)
|
||
idx = torch.repeat_interleave(
|
||
torch.arange(lengths.numel(), device=offs.device), lengths)
|
||
same = idx.view(1, -1) == idx.view(-1, 1)
|
||
causal = torch.tril(torch.ones_like(same, dtype=torch.bool))
|
||
return same & causal
|
||
|
||
|
||
def _block_mask(offs, S):
    """FlexAttention block mask with the same pattern as ``_dense_causal_mask``.

    Args:
        offs: 1-D tensor of segment boundaries (e.g. from ``_offsets``).
        S: query and key/value sequence length.

    Returns:
        Whatever ``infer.create_block_mask`` returns for a mask that keeps
        (q, kv) pairs in the same segment with ``kv <= q``. B=None/H=None are
        passed, presumably sharing the mask across batch and heads as in
        torch's ``create_block_mask`` — confirm against infer's re-export.
    """
    seg_len = (offs[1:] - offs[:-1]).view(-1)
    doc_id = torch.repeat_interleave(
        torch.arange(seg_len.numel(), device=offs.device), seg_len)

    def mask_mod(b, h, q_idx, kv_idx):
        # Keep only causal pairs whose query and key come from the same user.
        same_doc = doc_id[q_idx] == doc_id[kv_idx]
        return same_doc & (kv_idx <= q_idx)

    return infer.create_block_mask(
        mask_mod, B=None, H=None, Q_LEN=S, KV_LEN=S, device=offs.device)
|
||
|
||
|
||
def test_moe_dense_matches_loop():
    """Vectorized dense MoE forward must match the original per-expert loop.

    The reference is ``infer._smoe_forward_loop(moe, x)``; the candidate is
    ``moe(x)`` with ``CONFIG["vectorize_moe"] = True``. Both run in FP32 and
    are compared with atol = rtol = 1e-4. The touched global CONFIG entry is
    restored afterwards (even on failure) so it cannot leak into other tests.
    """
    torch.manual_seed(0)
    dev = "cuda" if torch.cuda.is_available() else "cpu"
    moe = infer.SMoE(d_model=512, dim_ff=1024, num_experts=8, k=2).to(dev).eval()
    x = torch.randn(1, 200, 512, device=dev)
    names = ("vectorize_moe",)
    saved = {name: infer.CONFIG[name] for name in names if name in infer.CONFIG}
    try:
        with torch.no_grad():
            ref, _ = infer._smoe_forward_loop(moe, x)
            infer.CONFIG["vectorize_moe"] = True
            new, _ = moe(x)
    finally:
        # Restore the pre-test global state (re-set or remove each key).
        for name in names:
            if name in saved:
                infer.CONFIG[name] = saved[name]
            else:
                infer.CONFIG.pop(name, None)
    err = (ref - new).abs().max().item()
    assert torch.allclose(ref, new, atol=1e-4, rtol=1e-4), f"MoE 不等价 max err={err:.3e}"
    print(f"[PASS] MoE 稠密向量化 == 逐expert循环 (max err={err:.2e}, dev={dev})")
|
||
|
||
|
||
def test_chunked_matches_dense_attention():
    """Chunked SDPA must match SDPA with the dense same-user causal mask.

    Chunks come from ``CTRModel.build_chunks`` with ``CONFIG["chunk_users"] = 3``
    (3 users per chunk) over 7 variable-length users. The comparison is FP32
    with atol = rtol = 1e-4. ``chunk_users`` is restored afterwards (even on
    failure) so it cannot leak into other tests.
    """
    dev = "cuda" if torch.cuda.is_available() else "cpu"
    # Tiny model: only its build_chunks helper is used in this test.
    rep = infer.RepEncoder(vocab_size=100, emb_dim=8, slot_num=28, d_model=8)
    seq = infer.TransformerEncoder(d_model=8, n_heads=2, num_layers=1, dim_ff=16)
    model = infer.CTRModel(rep, seq, d_model=8).to(dev)
    torch.manual_seed(0)
    H, Dh = 8, 64
    offs = _offsets([10, 25, 7, 40, 18, 5, 33], dev)  # 7 users
    S = int(offs[-1])
    q = torch.randn(1, H, S, Dh, device=dev)
    k = torch.randn(1, H, S, Dh, device=dev)
    v = torch.randn(1, H, S, Dh, device=dev)
    names = ("chunk_users",)
    saved = {name: infer.CONFIG[name] for name in names if name in infer.CONFIG}
    try:
        with torch.no_grad():
            dense = infer.scaled_dot_product(
                q, k, v, {"mask": _dense_causal_mask(offs)[None, None]})
            infer.CONFIG["chunk_users"] = 3  # 3 users per chunk
            chunks = model.build_chunks(offs, torch.device(dev))
            chunked = infer.scaled_dot_product(q, k, v, {"chunks": chunks})
    finally:
        # Restore the pre-test global state (re-set or remove each key).
        for name in names:
            if name in saved:
                infer.CONFIG[name] = saved[name]
            else:
                infer.CONFIG.pop(name, None)
    err = (dense - chunked).abs().max().item()
    assert torch.allclose(dense, chunked, atol=1e-4, rtol=1e-4), f"chunked 不等价 max err={err:.3e}"
    print(f"[PASS] chunked SDPA == 稠密SDPA (max err={err:.2e}, dev={dev})")
|
||
|
||
|
||
def test_sparse_pool_matches():
    """The sparse_pool embedding path must match the segment_reduce path.

    ``RepEncoder`` runs on a random 6-sample batch with ``dedup_embedding`` on.
    The first pass has ``sparse_pool`` off (reference) and the second has it
    on; the outputs are compared with atol = rtol = 2e-2. Both CONFIG entries
    are restored to their pre-test values afterwards (even on failure).
    """
    torch.manual_seed(0)
    dev = "cuda" if torch.cuda.is_available() else "cpu"
    slot_num, emb_dim, d_model = 28, 512, 512
    enc = infer.RepEncoder(vocab_size=200, emb_dim=emb_dim, slot_num=slot_num,
                           d_model=d_model).to(dev).eval()
    N = 6
    # batch[slot] = (values, offsets); each sample has 0..7 ids per slot,
    # so empty segments are included.
    batch = {}
    for s in range(1, slot_num + 1):
        counts = torch.randint(0, 8, (N,))
        # Deliberately tiny value range -> many duplicate ids within a segment.
        vals = torch.randint(0, 30, (int(counts.sum()),), device=dev)
        offs = torch.cat([torch.zeros(1, dtype=torch.long), counts.cumsum(0)]).to(dev)
        batch[s] = (vals, offs)
    names = ("sparse_pool", "dedup_embedding")
    saved = {name: infer.CONFIG[name] for name in names if name in infer.CONFIG}
    try:
        with torch.no_grad():
            infer.CONFIG["sparse_pool"] = False
            infer.CONFIG["dedup_embedding"] = True
            ref = enc(batch)
            infer.CONFIG["sparse_pool"] = True
            new = enc(batch)
    finally:
        # Restore the pre-test global state (re-set or remove each key) instead
        # of forcing sparse_pool=False and leaking dedup_embedding=True.
        for name in names:
            if name in saved:
                infer.CONFIG[name] = saved[name]
            else:
                infer.CONFIG.pop(name, None)
    err = (ref - new).abs().max().item()
    assert torch.allclose(ref, new, atol=2e-2, rtol=2e-2), f"sparse_pool 不等价 max err={err:.3e}"
    print(f"[PASS] sparse_pool == segment_reduce (max err={err:.2e}, dev={dev})")
|
||
|
||
|
||
def test_triton_varlen_matches_dense():
    """Triton varlen flash attention must match dense masked SDPA in FP16.

    Skips (prints a message and returns) unless CUDA is available and
    ``infer._HAS_TRITON`` is truthy. Comparison tolerance: atol = rtol = 3e-2.
    """
    if not torch.cuda.is_available() or not infer._HAS_TRITON:
        print("[SKIP] Triton varlen 等价测试(需 CUDA + triton)")
        return
    torch.manual_seed(0)
    dev = "cuda"
    H, Dh = 8, 64
    # Segments spanning several blocks, a single token, and exactly one block.
    offs = _offsets([10, 64, 1, 130, 64, 200], dev)
    S = int(offs[-1])
    q, k, v = (torch.randn(1, H, S, Dh, device=dev, dtype=torch.float16)
               for _ in range(3))
    dense_mask = _dense_causal_mask(offs)[None, None]
    with torch.no_grad():
        dense = infer.scaled_dot_product(q, k, v, {"mask": dense_mask})
        # 64: presumably the kernel block size the segment lengths are chosen
        # around — confirm against _triton_block_meta's signature.
        meta = infer._triton_block_meta(offs, 64, q.device)
        trit = infer.scaled_dot_product(q, k, v, {"triton_meta": meta})
    dense32, trit32 = dense.float(), trit.float()
    err = (dense32 - trit32).abs().max().item()
    assert torch.allclose(dense32, trit32, atol=3e-2, rtol=3e-2), \
        f"Triton varlen 不等价 max err={err:.3e}"
    print(f"[PASS] Triton varlen flash == 稠密SDPA (max err={err:.2e})")
|
||
|
||
|
||
def test_syncfree_mask_matches():
    """The sync-free mask must equal the original mask exactly.

    Compares ``CTRModel.get_sequence_causal_mask`` (repeat_interleave-based)
    with ``CTRModel.causal_mask_syncfree`` (searchsorted-based) on 4
    variable-length users, using ``torch.equal``.
    """
    dev = "cuda" if torch.cuda.is_available() else "cpu"
    model = infer.CTRModel(
        infer.RepEncoder(vocab_size=100, emb_dim=8, slot_num=28, d_model=8),
        infer.TransformerEncoder(d_model=8, n_heads=2, num_layers=1, dim_ff=16),
        d_model=8,
    ).to(dev)
    offs = torch.tensor([0, 10, 35, 42, 60], device=dev)  # 4 users, variable lengths
    total = int(offs[-1])
    reference = model.get_sequence_causal_mask(offs)
    candidate = model.causal_mask_syncfree(offs, total, torch.device(dev))
    assert torch.equal(reference, candidate), "sync-free mask 与原 mask 不一致"
    print(f"[PASS] searchsorted mask == repeat_interleave mask (dev={dev})")
|
||
|
||
|
||
def test_varlen_matches_dense_attention():
    """The varlen (nested-tensor) SDPA path must match dense masked SDPA.

    FP16 inputs over 5 variable-length users, compared with atol = rtol = 2e-2.
    Skips (prints a message and returns) when CUDA is unavailable.
    """
    if not torch.cuda.is_available():
        print("[SKIP] varlen 等价测试(需 CUDA)")
        return
    torch.manual_seed(0)
    dev = "cuda"
    H, Dh = 8, 64
    offs = _offsets([10, 25, 7, 40, 18], dev)
    S = int(offs[-1])
    q, k, v = (torch.randn(1, H, S, Dh, device=dev, dtype=torch.float16)
               for _ in range(3))
    dense_mask = _dense_causal_mask(offs)[None, None]
    with torch.no_grad():
        dense = infer.scaled_dot_product(q, k, v, {"mask": dense_mask})
        varlen = infer.scaled_dot_product(q, k, v, {"varlen_offsets": offs})
    dense32, varlen32 = dense.float(), varlen.float()
    err = (dense32 - varlen32).abs().max().item()
    assert torch.allclose(dense32, varlen32, atol=2e-2, rtol=2e-2), \
        f"varlen 不等价 max err={err:.3e}"
    print(f"[PASS] varlen(嵌套张量) == 稠密SDPA (max err={err:.2e})")
|
||
|
||
|
||
def test_fused_embedding_matches_perslot():
    """The fused embedding lookup must match the per-slot lookup.

    ``RepEncoder`` runs on a random 6-sample batch, first with
    ``CONFIG["fuse_embedding"] = False`` (reference) and then with it True.
    The outputs are compared with atol = rtol = 1e-4. ``fuse_embedding`` is
    restored afterwards (even on failure) so it cannot leak into other tests.
    """
    torch.manual_seed(0)
    dev = "cuda" if torch.cuda.is_available() else "cpu"
    slot_num, emb_dim, d_model = 28, 512, 512
    enc = infer.RepEncoder(vocab_size=10000, emb_dim=emb_dim, slot_num=slot_num,
                           d_model=d_model).to(dev).eval()
    # N=6 samples; each sample has 0..4 signs per slot (covers empty slots).
    N = 6
    batch = {}
    for s in range(1, slot_num + 1):
        counts = torch.randint(0, 5, (N,))
        vals = torch.randint(0, 10000, (int(counts.sum()),), device=dev)
        offs = torch.cat([torch.zeros(1, dtype=torch.long), counts.cumsum(0)]).to(dev)
        batch[s] = (vals, offs)
    names = ("fuse_embedding",)
    saved = {name: infer.CONFIG[name] for name in names if name in infer.CONFIG}
    try:
        with torch.no_grad():
            infer.CONFIG["fuse_embedding"] = False
            ref = enc(batch)
            infer.CONFIG["fuse_embedding"] = True
            new = enc(batch)
    finally:
        # Restore the pre-test global state (re-set or remove each key).
        for name in names:
            if name in saved:
                infer.CONFIG[name] = saved[name]
            else:
                infer.CONFIG.pop(name, None)
    err = (ref - new).abs().max().item()
    assert torch.allclose(ref, new, atol=1e-4, rtol=1e-4), f"embedding融合不等价 max err={err:.3e}"
    print(f"[PASS] embedding 融合 == 逐slot (max err={err:.2e}, dev={dev})")
|
||
|
||
|
||
def test_flex_matches_dense_attention():
    """FlexAttention with the block-diagonal causal mask must match dense SDPA.

    FP32 inputs over 5 variable-length users, compared with atol = rtol = 2e-2.
    Skips (prints a message and returns) unless CUDA is available,
    ``infer._HAS_FLEX`` is truthy, and the GPU's major compute capability is >= 8.
    """
    if (not torch.cuda.is_available() or not infer._HAS_FLEX
            or torch.cuda.get_device_capability()[0] < 8):
        print("[SKIP] FlexAttention 等价测试(需 CUDA SM80+)")
        return
    torch.manual_seed(0)
    dev = "cuda"
    H, Dh = 8, 64
    offs = _offsets([10, 25, 7, 40, 18], dev)
    S = int(offs[-1])
    q, k, v = (torch.randn(1, H, S, Dh, device=dev) for _ in range(3))
    with torch.no_grad():
        dense = infer.scaled_dot_product(
            q, k, v, {"mask": _dense_causal_mask(offs)[None, None]})
        flex = infer.scaled_dot_product(q, k, v, {"block_mask": _block_mask(offs, S)})
    err = (dense - flex).abs().max().item()
    assert torch.allclose(dense, flex, atol=2e-2, rtol=2e-2), f"Flex 不等价 max err={err:.3e}"
    print(f"[PASS] FlexAttention 块对角 == 稠密SDPA (max err={err:.2e})")
|
||
|
||
|
||
if __name__ == "__main__":
    # Script entry point: run every equivalence check in a fixed order. The
    # first failing assertion aborts the run.
    for _check in (
        test_moe_dense_matches_loop,
        test_fused_embedding_matches_perslot,
        test_sparse_pool_matches,
        test_syncfree_mask_matches,
        test_triton_varlen_matches_dense,
        test_chunked_matches_dense_attention,
        test_varlen_matches_dense_attention,
        test_flex_matches_dense_attention,
    ):
        _check()
    print("[DONE] 等价测试结束")
|