From c1d8b91fb21a43aeb2439c1b12d5ef05720a4e00 Mon Sep 17 00:00:00 2001 From: OwnerSunshine530 Date: Sun, 14 Jun 2026 23:30:59 +0800 Subject: [PATCH] =?UTF-8?q?feat(Phase=20B):=20FlexAttention=20=E5=9D=97?= =?UTF-8?q?=E5=AF=B9=E8=A7=92=E6=B3=A8=E6=84=8F=E5=8A=9B=20+=20MoE=20?= =?UTF-8?q?=E7=A8=A0=E5=AF=86=E5=90=91=E9=87=8F=E5=8C=96?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - scaled_dot_product 分发:block_mask->FlexAttention(每用户仅自身序列内因果, 避免对~14000长拼接序列做O(S²)稠密注意力);否则SDPA稠密(回退/对照)。 - CTRModel.build_block_mask 构造块对角因果mask;_use_flex 在SM80+自动启用。 - SMoE 稠密向量化(einsum批量算所有expert后按top-k gather),消除Python循环/同步; 保留 _smoe_forward_loop 作数值等价对照。CONFIG.vectorize_moe 可切。 - load_model 加可选 torch.compile。 - tests/test_equiv.py:MoE稠密vs循环、Flex vs稠密SDPA 数值等价(无pytest依赖)。 - bench.py 加 --attn/--moe/--compile 便于A800上对比测速。 需 A800(SM80) 实测;CPU/V100 自动回退 SDPA。 Co-Authored-By: Claude Opus 4.8 --- 代码/code/bench.py | 11 +++ 代码/code/infer.py | 148 +++++++++++++++++++++++++++------- 代码/code/tests/test_equiv.py | 92 +++++++++++++++++++++ 3 files changed, 222 insertions(+), 29 deletions(-) create mode 100644 代码/code/tests/test_equiv.py diff --git a/代码/code/bench.py b/代码/code/bench.py index 8bbbb1d..75479bf 100644 --- a/代码/code/bench.py +++ b/代码/code/bench.py @@ -250,6 +250,11 @@ def _parse_args(): help="逗号分隔的 keep_fp32_modules,如 linear,rep_encoder.input_norm") ap.add_argument("--feasign-none", action="store_true", help="不截断特征(max_feasign_per_slot=None)") + ap.add_argument("--attn", choices=["auto", "flex", "sdpa"], default=None, + help="注意力实现:flex=块对角FlexAttention, sdpa=稠密(原), auto=SM80自动") + ap.add_argument("--moe", choices=["dense", "loop"], default=None, + help="MoE实现:dense=向量化(新), loop=逐expert循环(原)") + ap.add_argument("--compile", action="store_true", help="开启 torch.compile") ap.add_argument("--rebuild", action="store_true", help="强制重建过滤缓存") return ap.parse_args() @@ -273,5 +278,11 @@ if __name__ == "__main__": cfg["merge_threshold"] = a.merge_th if a.keep is not None: cfg["keep_fp32_modules"] = tuple(x for x in a.keep.split(",") if x) + if a.attn is not None: + cfg["use_flex_attn"] = {"auto": "auto", "flex": True, "sdpa": False}[a.attn] + if a.moe is not None: + cfg["vectorize_moe"] = (a.moe == "dense") + if a.compile: + cfg["compile"] = True mf = None if a.feasign_none else {1: 2} run_once(cfg, batch_size=a.bs, max_batches=a.smoke, max_feasign_per_slot=mf, rebuild=a.rebuild) diff --git a/代码/code/infer.py b/代码/code/infer.py index ccd3d45..af8377e 100644 --- a/代码/code/infer.py +++ b/代码/code/infer.py @@ -17,6 +17,15 @@ import torch.nn.functional as F from torch.utils.data import Dataset, DataLoader from tqdm import tqdm +# FlexAttention(块对角因果注意力,需 PyTorch 2.5+ 且 GPU 计算能力 >= 8.0 / Ampere) +try: + from torch.nn.attention.flex_attention import flex_attention, create_block_mask + _HAS_FLEX = True +except Exception: + flex_attention = None + create_block_mask = None + _HAS_FLEX = False + # ============================================================ # 实验配置开关板 @@ -31,9 +40,28 @@ CONFIG = { "signid_mode": "clamp", # "clamp" 或 "modulo":处理超界 sign id 的方式 "sync_timing": False, # bench 里设 True,做 torch.cuda.synchronize 真实计时 "filter_test_users": True, # 只处理含测试样本的用户(跳过会被丢弃的用户,省算力) + "use_flex_attn": "auto", # "auto"(SM80+用flex,否则SDPA) / True / False + "vectorize_moe": True, # True=稠密向量化MoE(无Python循环/同步);False=原逐expert循环 + "compile": False, # 是否 torch.compile(图理干净后再开) } +def _use_flex(device): + """决定是否用 FlexAttention:auto 模式下仅在 SM80+(Ampere/A800)启用。""" + mode = CONFIG.get("use_flex_attn", "auto") + if not _HAS_FLEX or mode is False: + return False + if mode is True: + return True + if device is not None and device.type == "cuda": + try: + major, _ = torch.cuda.get_device_capability(device) + return major >= 8 + except Exception: + return False + return False + + def _force_fp32_io(module): """让某个模块在 FP16 模型里以 FP32 计算:输入转 FP32、输出转回 FP16。 用于 keep_fp32_modules 指定的精度敏感层(如最终输出头、LayerNorm)。""" @@ -324,7 +352,14 @@ class RepEncoder(nn.Module): def scaled_dot_product(q, k, v, extension): - """使用 PyTorch SDPA 后端(自动启用 Flash Attention / Memory Efficient Attention)""" + """注意力分发: + - 若 extension 带 block_mask → FlexAttention 块对角因果(每用户只在自己序列内 + 做因果注意力,避免对 ~14000 长拼接序列做 O(S²) 稠密注意力,计算量砍数十倍)。 + - 否则 → 标准 SDPA(稠密 mask,数学等价、用于回退/对照)。 + """ + if extension is not None and extension.get("block_mask") is not None: + return flex_attention(q, k, v, block_mask=extension["block_mask"]) + if extension is not None and "mask" in extension: attn_mask = extension["mask"].to(device=q.device) else: @@ -369,6 +404,29 @@ class TopKGate(nn.Module): return topk_idx, topk_score, probs +def _smoe_forward_loop(moe, x): + """原始逐 expert 循环实现(保留作数值等价对照/回退)。""" + B, S, D = x.shape + topk_idx, topk_score, probs = moe.gate(x) + out = torch.zeros_like(x) + x_flat = x.reshape(-1, D) + idx_flat = topk_idx.reshape(-1, moe.k) + score_flat = topk_score.reshape(-1, moe.k) + out_flat = out.reshape(-1, D) + for i in range(moe.num_experts): + mask = (idx_flat == i) + token_idx, k_idx = mask.nonzero(as_tuple=True) + if token_idx.numel() == 0: + continue + selected_x = x_flat[token_idx] + expert_out = moe.experts[i](selected_x) + weight = score_flat[token_idx, k_idx].unsqueeze(-1) + out_flat[token_idx] += expert_out * weight + importance = probs.sum(dim=(0, 1)) + moe_loss = (importance.std() / (importance.mean() + 1e-6)) + return out, moe_loss + + class SMoE(nn.Module): def __init__(self, d_model, dim_ff, num_experts, k=2): super().__init__() @@ -380,37 +438,43 @@ class SMoE(nn.Module): ]) self.gate = TopKGate(d_model, num_experts, k=k) + self._stacked = False + + def _stack_weights(self): + """把各 expert 的 fc1/fc2 权重堆叠成单一张量,供批量 matmul。 + 延迟到首次 forward 调用:此时已完成 expert 合并与 half()/to(device)。""" + self.register_buffer("W1", torch.stack([e.fc1.weight for e in self.experts]).contiguous()) # [E,F,D] + self.register_buffer("b1", torch.stack([e.fc1.bias for e in self.experts]).contiguous()) # [E,F] + self.register_buffer("W2", torch.stack([e.fc2.weight for e in self.experts]).contiguous()) # [E,D,F] + self.register_buffer("b2", torch.stack([e.fc2.bias for e in self.experts]).contiguous()) # [E,D] + self._stacked = True def forward(self, x): # x: [B,S,D] - B, S, D = x.shape + if not CONFIG.get("vectorize_moe", True): + return _smoe_forward_loop(self, x) + if not self._stacked: + self._stack_weights() + + B, S, D = x.shape topk_idx, topk_score, probs = self.gate(x) - out = torch.zeros_like(x) + xf = x.reshape(-1, D) # [N, D] + # 稠密计算所有 expert(GPU 友好、无 Python 循环/同步/gather-scatter): + h = torch.einsum("nd,efd->enf", xf, self.W1) + self.b1.unsqueeze(1) # [E,N,F] + h = F.relu(h) + o = torch.einsum("enf,edf->end", h, self.W2) + self.b2.unsqueeze(1) # [E,N,D] - # flatten - x_flat = x.reshape(-1, D) # [B*S, D] - idx_flat = topk_idx.reshape(-1, self.k) # [B*S, k] - score_flat = topk_score.reshape(-1, self.k) - out_flat = out.reshape(-1, D) # 提前 reshape,避免循环内重复 + # 按每个 token 的 top-k 选取并加权(与逐 expert 循环数学等价) + o = o.permute(1, 0, 2) # [N, E, D] + idx = topk_idx.reshape(-1, self.k) # [N, k] + sc = topk_score.reshape(-1, self.k) # [N, k] + sel = torch.gather(o, 1, idx.unsqueeze(-1).expand(-1, -1, D)) # [N, k, D] + out = (sel * sc.unsqueeze(-1)).sum(dim=1).reshape(B, S, D) - for i in range(self.num_experts): - # 找到被路由到 expert i 的 token - mask = (idx_flat == i) # [B*S, k] - - token_idx, k_idx = mask.nonzero(as_tuple=True) - if token_idx.numel() == 0: - continue - - selected_x = x_flat[token_idx] # [N, D] - expert_out = self.experts[i](selected_x) # [N, D] - weight = score_flat[token_idx, k_idx].unsqueeze(-1) - out_flat[token_idx] += expert_out * weight - - importance = probs.sum(dim=(0,1)) # [E] + importance = probs.sum(dim=(0, 1)) # [E] moe_loss = (importance.std() / (importance.mean() + 1e-6)) - return out, moe_loss @@ -481,13 +545,28 @@ class CTRModel(nn.Module): out_mask = torch.tril((a == 0).to(torch.int32)).bool() return out_mask + def build_block_mask(self, user_offsets, S): + """FlexAttention 块对角因果 mask:q 只能 attend 同一用户且 kv<=q 的位置。""" + lengths = (user_offsets[1:] - user_offsets[:-1]).view(-1) + device = user_offsets.device + doc_id = torch.repeat_interleave( + torch.arange(lengths.numel(), device=device), lengths) + + def mask_mod(b, h, q_idx, kv_idx): + return (q_idx >= kv_idx) & (doc_id[q_idx] == doc_id[kv_idx]) + + return create_block_mask(mask_mod, B=None, H=None, Q_LEN=S, KV_LEN=S, device=device) + def forward(self, batch): seq_input = self.rep_encoder(batch) - seq_mask = self.get_sequence_causal_mask(batch["user_offsets"]) - encoder_output, moe_loss = self.seq_encoder( - x=seq_input, - extension={"mask": seq_mask.unsqueeze(0).unsqueeze(0)}, - ) + user_offsets = batch["user_offsets"] + if _use_flex(seq_input.device): + S = seq_input.shape[0] # rep_encoder 输出 [S, D],S=总 token 数 + extension = {"block_mask": self.build_block_mask(user_offsets, S)} + else: + seq_mask = self.get_sequence_causal_mask(user_offsets) + extension = {"mask": seq_mask.unsqueeze(0).unsqueeze(0)} + encoder_output, moe_loss = self.seq_encoder(x=seq_input, extension=extension) encoder_output = encoder_output.squeeze(0) pred = self.linear(encoder_output) pred_logits = torch.clamp(pred, min=-15.0, max=15.0) @@ -570,8 +649,19 @@ def load_model(ckpt_path, device='cuda:0'): model.to(dev) model.eval() - print(f"[INFO] Model ready. Device: {dev}") + use_flex = _use_flex(dev) + print(f"[INFO] attention={'FlexAttention(block-causal)' if use_flex else 'SDPA(dense)'}, " + f"moe={'dense' if CONFIG.get('vectorize_moe', True) else 'loop'}") + + if CONFIG.get("compile", False): + try: + model = torch.compile(model, dynamic=True) + print("[INFO] torch.compile enabled (dynamic=True)") + except Exception as e: + print(f"[WARNING] torch.compile failed ({e}), running eager") + + print(f"[INFO] Model ready. Device: {dev}") return model, dev diff --git a/代码/code/tests/test_equiv.py b/代码/code/tests/test_equiv.py new file mode 100644 index 0000000..522d3e6 --- /dev/null +++ b/代码/code/tests/test_equiv.py @@ -0,0 +1,92 @@ +"""Phase B 数值等价测试:新实现 vs 原实现。子进程跑: + + %cd /home/aistudio/code + !python tests/test_equiv.py + +- MoE 稠密向量化 vs 原逐 expert 循环(CPU/GPU 都可,FP32) +- FlexAttention 块对角因果 vs 稠密 SDPA(需 CUDA SM80+,否则自动跳过) +""" +import os +import sys + +# baseline 把依赖装在 --target 目录;import 前补 sys.path +for _p in ("/home/aistudio/external-libraries", "/home/aistudio/libraries", + os.path.abspath("../libraries"), os.path.abspath("./libraries")): + if os.path.isdir(_p) and _p not in sys.path: + sys.path.insert(0, _p) +sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) + +import torch +import torch.nn.functional as F +import infer + + +def _offsets(lengths, device): + offs = [0] + for L in lengths: + offs.append(offs[-1] + L) + return torch.tensor(offs, dtype=torch.long, device=device) + + +def _dense_causal_mask(offs): + """同用户 + 因果(tril),与 CTRModel.get_sequence_causal_mask 语义一致。""" + lengths = (offs[1:] - offs[:-1]).view(-1) + idx = torch.repeat_interleave( + torch.arange(lengths.numel(), device=offs.device), lengths) + same = idx.view(1, -1) == idx.view(-1, 1) + causal = torch.tril(torch.ones_like(same, dtype=torch.bool)) + return same & causal + + +def _block_mask(offs, S): + lengths = (offs[1:] - offs[:-1]).view(-1) + doc_id = torch.repeat_interleave( + torch.arange(lengths.numel(), device=offs.device), lengths) + + def mask_mod(b, h, q_idx, kv_idx): + return (q_idx >= kv_idx) & (doc_id[q_idx] == doc_id[kv_idx]) + + return infer.create_block_mask(mask_mod, B=None, H=None, Q_LEN=S, KV_LEN=S, + device=offs.device) + + +def test_moe_dense_matches_loop(): + torch.manual_seed(0) + dev = "cuda" if torch.cuda.is_available() else "cpu" + moe = infer.SMoE(d_model=512, dim_ff=1024, num_experts=8, k=2).to(dev).eval() + x = torch.randn(1, 200, 512, device=dev) + with torch.no_grad(): + ref, _ = infer._smoe_forward_loop(moe, x) + infer.CONFIG["vectorize_moe"] = True + new, _ = moe(x) + err = (ref - new).abs().max().item() + assert torch.allclose(ref, new, atol=1e-4, rtol=1e-4), f"MoE 不等价 max err={err:.3e}" + print(f"[PASS] MoE 稠密向量化 == 逐expert循环 (max err={err:.2e}, dev={dev})") + + +def test_flex_matches_dense_attention(): + ok = (torch.cuda.is_available() and infer._HAS_FLEX + and torch.cuda.get_device_capability()[0] >= 8) + if not ok: + print("[SKIP] FlexAttention 等价测试(需 CUDA SM80+,当前环境不满足)") + return + torch.manual_seed(0) + dev = "cuda" + H, Dh = 8, 64 + offs = _offsets([10, 25, 7, 40, 18], dev) + S = int(offs[-1]) + q = torch.randn(1, H, S, Dh, device=dev) + k = torch.randn(1, H, S, Dh, device=dev) + v = torch.randn(1, H, S, Dh, device=dev) + with torch.no_grad(): + dense = infer.scaled_dot_product(q, k, v, {"mask": _dense_causal_mask(offs)[None, None]}) + flex = infer.scaled_dot_product(q, k, v, {"block_mask": _block_mask(offs, S)}) + err = (dense - flex).abs().max().item() + assert torch.allclose(dense, flex, atol=2e-2, rtol=2e-2), f"Flex 不等价 max err={err:.3e}" + print(f"[PASS] FlexAttention 块对角 == 稠密SDPA (max err={err:.2e})") + + +if __name__ == "__main__": + test_moe_dense_matches_loop() + test_flex_matches_dense_attention() + print("[DONE] 等价测试结束")