"""Phase B 数值等价测试:新实现 vs 原实现。子进程跑: %cd /home/aistudio/code !python tests/test_equiv.py - MoE 稠密向量化 vs 原逐 expert 循环(CPU/GPU 都可,FP32) - FlexAttention 块对角因果 vs 稠密 SDPA(需 CUDA SM80+,否则自动跳过) """ import os import sys # baseline 把依赖装在 --target 目录;import 前补 sys.path for _p in ("/home/aistudio/external-libraries", "/home/aistudio/libraries", os.path.abspath("../libraries"), os.path.abspath("./libraries")): if os.path.isdir(_p) and _p not in sys.path: sys.path.insert(0, _p) sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) import torch import torch.nn.functional as F import infer def _offsets(lengths, device): offs = [0] for L in lengths: offs.append(offs[-1] + L) return torch.tensor(offs, dtype=torch.long, device=device) def _dense_causal_mask(offs): """同用户 + 因果(tril),与 CTRModel.get_sequence_causal_mask 语义一致。""" lengths = (offs[1:] - offs[:-1]).view(-1) idx = torch.repeat_interleave( torch.arange(lengths.numel(), device=offs.device), lengths) same = idx.view(1, -1) == idx.view(-1, 1) causal = torch.tril(torch.ones_like(same, dtype=torch.bool)) return same & causal def _block_mask(offs, S): lengths = (offs[1:] - offs[:-1]).view(-1) doc_id = torch.repeat_interleave( torch.arange(lengths.numel(), device=offs.device), lengths) def mask_mod(b, h, q_idx, kv_idx): return (q_idx >= kv_idx) & (doc_id[q_idx] == doc_id[kv_idx]) return infer.create_block_mask(mask_mod, B=None, H=None, Q_LEN=S, KV_LEN=S, device=offs.device) def test_moe_dense_matches_loop(): torch.manual_seed(0) dev = "cuda" if torch.cuda.is_available() else "cpu" moe = infer.SMoE(d_model=512, dim_ff=1024, num_experts=8, k=2).to(dev).eval() x = torch.randn(1, 200, 512, device=dev) with torch.no_grad(): ref, _ = infer._smoe_forward_loop(moe, x) infer.CONFIG["vectorize_moe"] = True new, _ = moe(x) err = (ref - new).abs().max().item() assert torch.allclose(ref, new, atol=1e-4, rtol=1e-4), f"MoE 不等价 max err={err:.3e}" print(f"[PASS] MoE 稠密向量化 == 逐expert循环 (max err={err:.2e}, dev={dev})") def test_syncfree_mask_matches(): dev = "cuda" if torch.cuda.is_available() else "cpu" rep = infer.RepEncoder(vocab_size=100, emb_dim=8, slot_num=28, d_model=8) seq = infer.TransformerEncoder(d_model=8, n_heads=2, num_layers=1, dim_ff=16) model = infer.CTRModel(rep, seq, d_model=8).to(dev) offs = torch.tensor([0, 10, 35, 42, 60], device=dev) # 4 个用户,变长 S = int(offs[-1]) m1 = model.get_sequence_causal_mask(offs) m2 = model.causal_mask_syncfree(offs, S, torch.device(dev)) assert torch.equal(m1, m2), "sync-free mask 与原 mask 不一致" print(f"[PASS] searchsorted mask == repeat_interleave mask (dev={dev})") def test_varlen_matches_dense_attention(): if not torch.cuda.is_available(): print("[SKIP] varlen 等价测试(需 CUDA)") return torch.manual_seed(0) dev = "cuda" H, Dh = 8, 64 offs = _offsets([10, 25, 7, 40, 18], dev) S = int(offs[-1]) q = torch.randn(1, H, S, Dh, device=dev, dtype=torch.float16) k = torch.randn(1, H, S, Dh, device=dev, dtype=torch.float16) v = torch.randn(1, H, S, Dh, device=dev, dtype=torch.float16) with torch.no_grad(): dense = infer.scaled_dot_product(q, k, v, {"mask": _dense_causal_mask(offs)[None, None]}) varlen = infer.scaled_dot_product(q, k, v, {"varlen_offsets": offs}) err = (dense.float() - varlen.float()).abs().max().item() assert torch.allclose(dense.float(), varlen.float(), atol=2e-2, rtol=2e-2), \ f"varlen 不等价 max err={err:.3e}" print(f"[PASS] varlen(嵌套张量) == 稠密SDPA (max err={err:.2e})") def test_fused_embedding_matches_perslot(): torch.manual_seed(0) dev = "cuda" if torch.cuda.is_available() else "cpu" slot_num, emb_dim, d_model = 28, 512, 512 enc = infer.RepEncoder(vocab_size=10000, emb_dim=emb_dim, slot_num=slot_num, d_model=d_model).to(dev).eval() # 造一个 N=6 样本的 batch:每 slot 每样本 0~4 个 sign(含空 slot 边界) N = 6 batch = {} for s in range(1, slot_num + 1): counts = torch.randint(0, 5, (N,)) vals = torch.randint(0, 10000, (int(counts.sum()),), device=dev) offs = torch.cat([torch.zeros(1, dtype=torch.long), counts.cumsum(0)]).to(dev) batch[s] = (vals, offs) with torch.no_grad(): infer.CONFIG["fuse_embedding"] = False ref = enc(batch) infer.CONFIG["fuse_embedding"] = True new = enc(batch) err = (ref - new).abs().max().item() assert torch.allclose(ref, new, atol=1e-4, rtol=1e-4), f"embedding融合不等价 max err={err:.3e}" print(f"[PASS] embedding 融合 == 逐slot (max err={err:.2e}, dev={dev})") def test_flex_matches_dense_attention(): ok = (torch.cuda.is_available() and infer._HAS_FLEX and torch.cuda.get_device_capability()[0] >= 8) if not ok: print("[SKIP] FlexAttention 等价测试(需 CUDA SM80+)") return torch.manual_seed(0) dev = "cuda" H, Dh = 8, 64 offs = _offsets([10, 25, 7, 40, 18], dev) S = int(offs[-1]) q = torch.randn(1, H, S, Dh, device=dev) k = torch.randn(1, H, S, Dh, device=dev) v = torch.randn(1, H, S, Dh, device=dev) with torch.no_grad(): dense = infer.scaled_dot_product(q, k, v, {"mask": _dense_causal_mask(offs)[None, None]}) flex = infer.scaled_dot_product(q, k, v, {"block_mask": _block_mask(offs, S)}) err = (dense - flex).abs().max().item() assert torch.allclose(dense, flex, atol=2e-2, rtol=2e-2), f"Flex 不等价 max err={err:.3e}" print(f"[PASS] FlexAttention 块对角 == 稠密SDPA (max err={err:.2e})") if __name__ == "__main__": test_moe_dense_matches_loop() test_fused_embedding_matches_perslot() test_syncfree_mask_matches() test_varlen_matches_dense_attention() test_flex_matches_dense_attention() print("[DONE] 等价测试结束")