Files
CTI-Inference-Opt/代码/code/tests/test_equiv.py
T
OwnerSunshine530 c1d8b91fb2 feat(Phase B): FlexAttention 块对角注意力 + MoE 稠密向量化
- scaled_dot_product 分发:block_mask->FlexAttention(每用户仅自身序列内因果,
  避免对~14000长拼接序列做O(S²)稠密注意力);否则SDPA稠密(回退/对照)。
- CTRModel.build_block_mask 构造块对角因果mask;_use_flex 在SM80+自动启用。
- SMoE 稠密向量化(einsum批量算所有expert后按top-k gather),消除Python循环/同步;
  保留 _smoe_forward_loop 作数值等价对照。CONFIG.vectorize_moe 可切。
- load_model 加可选 torch.compile。
- tests/test_equiv.py:MoE稠密vs循环、Flex vs稠密SDPA 数值等价(无pytest依赖)。
- bench.py 加 --attn/--moe/--compile 便于A800上对比测速。

需 A800(SM80) 实测;CPU/V100 自动回退 SDPA。

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
2026-06-14 23:30:59 +08:00

93 lines
3.4 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""Phase B 数值等价测试:新实现 vs 原实现。子进程跑:
%cd /home/aistudio/code
!python tests/test_equiv.py
- MoE 稠密向量化 vs 原逐 expert 循环(CPU/GPU 都可,FP32)
- FlexAttention 块对角因果 vs 稠密 SDPA(需 CUDA SM80+,否则自动跳过)
"""
import os
import sys
# The baseline environment installs its dependencies into ``--target``
# directories, so those directories must be on sys.path before the
# third-party / project imports that follow.
for _p in (
    "/home/aistudio/external-libraries",
    "/home/aistudio/libraries",
    os.path.abspath("../libraries"),
    os.path.abspath("./libraries"),
):
    # Skip candidates that do not exist or are already importable.
    if not os.path.isdir(_p) or _p in sys.path:
        continue
    sys.path.insert(0, _p)
# Make the parent directory of this file's directory importable as well
# (``import infer`` below relies on it).
sys.path.insert(
    0,
    os.path.abspath(os.path.join(os.path.dirname(__file__), os.pardir)),
)
import torch
import torch.nn.functional as F
import infer
def _offsets(lengths, device):
offs = [0]
for L in lengths:
offs.append(offs[-1] + L)
return torch.tensor(offs, dtype=torch.long, device=device)
def _dense_causal_mask(offs):
"""同用户 + 因果(tril),与 CTRModel.get_sequence_causal_mask 语义一致。"""
lengths = (offs[1:] - offs[:-1]).view(-1)
idx = torch.repeat_interleave(
torch.arange(lengths.numel(), device=offs.device), lengths)
same = idx.view(1, -1) == idx.view(-1, 1)
causal = torch.tril(torch.ones_like(same, dtype=torch.bool))
return same & causal
def _block_mask(offs, S):
    """Build the block-diagonal causal mask object used on the Flex path.

    Args:
        offs: 1-D tensor of cumulative sequence boundaries.
        S: Length passed as both ``Q_LEN`` and ``KV_LEN``.

    Returns:
        Whatever ``infer.create_block_mask`` returns for the mask predicate
        below (presumably PyTorch's FlexAttention ``BlockMask`` -- confirm
        against ``infer``).
    """
    seg_lens = (offs[1:] - offs[:-1]).view(-1)
    seg_ids = torch.arange(seg_lens.numel(), device=offs.device)
    # doc_of[p] is the index of the segment that position p belongs to.
    doc_of = torch.repeat_interleave(seg_ids, seg_lens)

    def mask_mod(b, h, q_idx, kv_idx):
        # A query may attend to a key only if the key is not in the future
        # and both positions belong to the same segment.
        not_future = q_idx >= kv_idx
        same_doc = doc_of[q_idx] == doc_of[kv_idx]
        return not_future & same_doc

    return infer.create_block_mask(
        mask_mod,
        B=None,
        H=None,
        Q_LEN=S,
        KV_LEN=S,
        device=offs.device,
    )
def test_moe_dense_matches_loop():
    """Check the SMoE forward against the per-expert loop reference.

    Builds one ``infer.SMoE`` module in eval mode, runs the same random
    input through ``infer._smoe_forward_loop`` (reference) and through the
    module's own forward with ``CONFIG["vectorize_moe"]`` set to True, and
    asserts the first returned tensors agree within ``atol=rtol=1e-4``.
    Runs on CUDA when available, otherwise on CPU.

    The ``vectorize_moe`` flag lives in the process-global ``infer.CONFIG``,
    so its previous state is restored afterwards; otherwise the forced value
    would leak into whatever runs next in the same process.

    Raises:
        AssertionError: If the two outputs differ beyond the tolerance.
    """
    torch.manual_seed(0)
    dev = "cuda" if torch.cuda.is_available() else "cpu"
    moe = infer.SMoE(d_model=512, dim_ff=1024, num_experts=8, k=2).to(dev).eval()
    x = torch.randn(1, 200, 512, device=dev)

    flag = "vectorize_moe"
    # Remember whether the flag existed and what it held, so the finally
    # block can put the global config back exactly as it was.
    had_flag = flag in infer.CONFIG
    saved = infer.CONFIG[flag] if had_flag else None
    try:
        with torch.no_grad():
            ref, _ = infer._smoe_forward_loop(moe, x)
            infer.CONFIG[flag] = True
            new, _ = moe(x)
    finally:
        if had_flag:
            infer.CONFIG[flag] = saved
        elif flag in infer.CONFIG:
            del infer.CONFIG[flag]

    err = (ref - new).abs().max().item()
    assert torch.allclose(ref, new, atol=1e-4, rtol=1e-4), f"MoE 不等价 max err={err:.3e}"
    print(f"[PASS] MoE 稠密向量化 == 逐expert循环 (max err={err:.2e}, dev={dev})")
def test_flex_matches_dense_attention():
    """Check the ``block_mask`` attention path against the dense-mask path.

    Skipped (with a printed notice) unless CUDA is available,
    ``infer._HAS_FLEX`` is truthy and the device's major compute capability
    is at least 8. Otherwise runs ``infer.scaled_dot_product`` twice on the
    same random q/k/v -- once with a dense boolean mask, once with the block
    mask -- and asserts the outputs agree within ``atol=rtol=2e-2``.

    Raises:
        AssertionError: If the two outputs differ beyond the tolerance.
    """
    skip_msg = "[SKIP] FlexAttention 等价测试(需 CUDA SM80+,当前环境不满足)"
    # Same short-circuit order as a single ``and`` chain: the capability
    # query is reached only when CUDA and the Flex flag are both present.
    if not (torch.cuda.is_available() and infer._HAS_FLEX):
        print(skip_msg)
        return
    major, _minor = torch.cuda.get_device_capability()
    if major < 8:
        print(skip_msg)
        return

    torch.manual_seed(0)
    dev = "cuda"
    n_heads, head_dim = 8, 64
    offs = _offsets([10, 25, 7, 40, 18], dev)
    total = int(offs[-1])
    # Drawn in q, k, v order so the random stream is consumed identically.
    q, k, v = (
        torch.randn(1, n_heads, total, head_dim, device=dev) for _ in range(3)
    )

    with torch.no_grad():
        dense_mask = _dense_causal_mask(offs)[None, None]
        dense = infer.scaled_dot_product(q, k, v, {"mask": dense_mask})
        blk_mask = _block_mask(offs, total)
        flex = infer.scaled_dot_product(q, k, v, {"block_mask": blk_mask})

    err = (dense - flex).abs().max().item()
    assert torch.allclose(dense, flex, atol=2e-2, rtol=2e-2), f"Flex 不等价 max err={err:.3e}"
    print(f"[PASS] FlexAttention 块对角 == 稠密SDPA (max err={err:.2e})")
if __name__ == "__main__":
    # Run the equivalence checks in order when executed as a script
    # (no pytest dependency), then print a completion marker.
    for _case in (test_moe_dense_matches_loop, test_flex_matches_dense_attention):
        _case()
    print("[DONE] 等价测试结束")