feat(Phase B): FlexAttention 块对角注意力 + MoE 稠密向量化

- scaled_dot_product 分发:block_mask->FlexAttention(每用户仅自身序列内因果,
  避免对~14000长拼接序列做O(S²)稠密注意力);否则SDPA稠密(回退/对照)。
- CTRModel.build_block_mask 构造块对角因果mask;_use_flex 在SM80+自动启用。
- SMoE 稠密向量化(einsum批量算所有expert后按top-k gather),消除Python循环/同步;
  保留 _smoe_forward_loop 作数值等价对照。CONFIG.vectorize_moe 可切。
- load_model 加可选 torch.compile。
- tests/test_equiv.py:MoE稠密vs循环、Flex vs稠密SDPA 数值等价(无pytest依赖)。
- bench.py 加 --attn/--moe/--compile 便于A800上对比测速。

需 A800(SM80) 实测;CPU/V100 自动回退 SDPA。

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
OwnerSunshine530
2026-06-14 23:30:59 +08:00
parent 0a971e67ac
commit c1d8b91fb2
3 changed files with 222 additions and 29 deletions
+92
View File
@@ -0,0 +1,92 @@
"""Phase B 数值等价测试:新实现 vs 原实现。子进程跑:
%cd /home/aistudio/code
!python tests/test_equiv.py
- MoE 稠密向量化 vs 原逐 expert 循环(CPU/GPU 都可,FP32
- FlexAttention 块对角因果 vs 稠密 SDPA(需 CUDA SM80+,否则自动跳过)
"""
import os
import sys
# baseline 把依赖装在 --target 目录;import 前补 sys.path
for _p in ("/home/aistudio/external-libraries", "/home/aistudio/libraries",
os.path.abspath("../libraries"), os.path.abspath("./libraries")):
if os.path.isdir(_p) and _p not in sys.path:
sys.path.insert(0, _p)
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
import torch
import torch.nn.functional as F
import infer
def _offsets(lengths, device):
offs = [0]
for L in lengths:
offs.append(offs[-1] + L)
return torch.tensor(offs, dtype=torch.long, device=device)
def _dense_causal_mask(offs):
"""同用户 + 因果(tril),与 CTRModel.get_sequence_causal_mask 语义一致。"""
lengths = (offs[1:] - offs[:-1]).view(-1)
idx = torch.repeat_interleave(
torch.arange(lengths.numel(), device=offs.device), lengths)
same = idx.view(1, -1) == idx.view(-1, 1)
causal = torch.tril(torch.ones_like(same, dtype=torch.bool))
return same & causal
def _block_mask(offs, S):
lengths = (offs[1:] - offs[:-1]).view(-1)
doc_id = torch.repeat_interleave(
torch.arange(lengths.numel(), device=offs.device), lengths)
def mask_mod(b, h, q_idx, kv_idx):
return (q_idx >= kv_idx) & (doc_id[q_idx] == doc_id[kv_idx])
return infer.create_block_mask(mask_mod, B=None, H=None, Q_LEN=S, KV_LEN=S,
device=offs.device)
def test_moe_dense_matches_loop():
torch.manual_seed(0)
dev = "cuda" if torch.cuda.is_available() else "cpu"
moe = infer.SMoE(d_model=512, dim_ff=1024, num_experts=8, k=2).to(dev).eval()
x = torch.randn(1, 200, 512, device=dev)
with torch.no_grad():
ref, _ = infer._smoe_forward_loop(moe, x)
infer.CONFIG["vectorize_moe"] = True
new, _ = moe(x)
err = (ref - new).abs().max().item()
assert torch.allclose(ref, new, atol=1e-4, rtol=1e-4), f"MoE 不等价 max err={err:.3e}"
print(f"[PASS] MoE 稠密向量化 == 逐expert循环 (max err={err:.2e}, dev={dev})")
def test_flex_matches_dense_attention():
ok = (torch.cuda.is_available() and infer._HAS_FLEX
and torch.cuda.get_device_capability()[0] >= 8)
if not ok:
print("[SKIP] FlexAttention 等价测试(需 CUDA SM80+,当前环境不满足)")
return
torch.manual_seed(0)
dev = "cuda"
H, Dh = 8, 64
offs = _offsets([10, 25, 7, 40, 18], dev)
S = int(offs[-1])
q = torch.randn(1, H, S, Dh, device=dev)
k = torch.randn(1, H, S, Dh, device=dev)
v = torch.randn(1, H, S, Dh, device=dev)
with torch.no_grad():
dense = infer.scaled_dot_product(q, k, v, {"mask": _dense_causal_mask(offs)[None, None]})
flex = infer.scaled_dot_product(q, k, v, {"block_mask": _block_mask(offs, S)})
err = (dense - flex).abs().max().item()
assert torch.allclose(dense, flex, atol=2e-2, rtol=2e-2), f"Flex 不等价 max err={err:.3e}"
print(f"[PASS] FlexAttention 块对角 == 稠密SDPA (max err={err:.2e})")
if __name__ == "__main__":
test_moe_dense_matches_loop()
test_flex_matches_dense_attention()
print("[DONE] 等价测试结束")