feat(Phase B): FlexAttention 块对角注意力 + MoE 稠密向量化

- scaled_dot_product 分发:block_mask->FlexAttention(每用户仅自身序列内因果,
  避免对~14000长拼接序列做O(S²)稠密注意力);否则SDPA稠密(回退/对照)。
- CTRModel.build_block_mask 构造块对角因果mask;_use_flex 在SM80+自动启用。
- SMoE 稠密向量化(einsum批量算所有expert后按top-k gather),消除Python循环/同步;
  保留 _smoe_forward_loop 作数值等价对照。CONFIG.vectorize_moe 可切。
- load_model 加可选 torch.compile。
- tests/test_equiv.py:MoE稠密vs循环、Flex vs稠密SDPA 数值等价(无pytest依赖)。
- bench.py 加 --attn/--moe/--compile 便于A800上对比测速。

需 A800(SM80) 实测;CPU/V100 自动回退 SDPA。

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
OwnerSunshine530
2026-06-14 23:30:59 +08:00
parent 0a971e67ac
commit c1d8b91fb2
3 changed files with 222 additions and 29 deletions
+11
View File
@@ -250,6 +250,11 @@ def _parse_args():
help="逗号分隔的 keep_fp32_modules,如 linear,rep_encoder.input_norm") help="逗号分隔的 keep_fp32_modules,如 linear,rep_encoder.input_norm")
ap.add_argument("--feasign-none", action="store_true", ap.add_argument("--feasign-none", action="store_true",
help="不截断特征(max_feasign_per_slot=None") help="不截断特征(max_feasign_per_slot=None")
ap.add_argument("--attn", choices=["auto", "flex", "sdpa"], default=None,
help="注意力实现:flex=块对角FlexAttention, sdpa=稠密(原), auto=SM80自动")
ap.add_argument("--moe", choices=["dense", "loop"], default=None,
help="MoE实现:dense=向量化(新), loop=逐expert循环(原)")
ap.add_argument("--compile", action="store_true", help="开启 torch.compile")
ap.add_argument("--rebuild", action="store_true", help="强制重建过滤缓存") ap.add_argument("--rebuild", action="store_true", help="强制重建过滤缓存")
return ap.parse_args() return ap.parse_args()
@@ -273,5 +278,11 @@ if __name__ == "__main__":
cfg["merge_threshold"] = a.merge_th cfg["merge_threshold"] = a.merge_th
if a.keep is not None: if a.keep is not None:
cfg["keep_fp32_modules"] = tuple(x for x in a.keep.split(",") if x) cfg["keep_fp32_modules"] = tuple(x for x in a.keep.split(",") if x)
if a.attn is not None:
cfg["use_flex_attn"] = {"auto": "auto", "flex": True, "sdpa": False}[a.attn]
if a.moe is not None:
cfg["vectorize_moe"] = (a.moe == "dense")
if a.compile:
cfg["compile"] = True
mf = None if a.feasign_none else {1: 2} mf = None if a.feasign_none else {1: 2}
run_once(cfg, batch_size=a.bs, max_batches=a.smoke, max_feasign_per_slot=mf, rebuild=a.rebuild) run_once(cfg, batch_size=a.bs, max_batches=a.smoke, max_feasign_per_slot=mf, rebuild=a.rebuild)
+118 -28
View File
@@ -17,6 +17,15 @@ import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm from tqdm import tqdm
# FlexAttention(块对角因果注意力,需 PyTorch 2.5+ 且 GPU 计算能力 >= 8.0 / Ampere
try:
from torch.nn.attention.flex_attention import flex_attention, create_block_mask
_HAS_FLEX = True
except Exception:
flex_attention = None
create_block_mask = None
_HAS_FLEX = False
# ============================================================ # ============================================================
# 实验配置开关板 # 实验配置开关板
@@ -31,9 +40,28 @@ CONFIG = {
"signid_mode": "clamp", # "clamp" 或 "modulo":处理超界 sign id 的方式 "signid_mode": "clamp", # "clamp" 或 "modulo":处理超界 sign id 的方式
"sync_timing": False, # bench 里设 True,做 torch.cuda.synchronize 真实计时 "sync_timing": False, # bench 里设 True,做 torch.cuda.synchronize 真实计时
"filter_test_users": True, # 只处理含测试样本的用户(跳过会被丢弃的用户,省算力) "filter_test_users": True, # 只处理含测试样本的用户(跳过会被丢弃的用户,省算力)
"use_flex_attn": "auto", # "auto"(SM80+用flex,否则SDPA) / True / False
"vectorize_moe": True, # True=稠密向量化MoE(无Python循环/同步);False=原逐expert循环
"compile": False, # 是否 torch.compile(图理干净后再开)
} }
def _use_flex(device):
"""决定是否用 FlexAttentionauto 模式下仅在 SM80+Ampere/A800)启用。"""
mode = CONFIG.get("use_flex_attn", "auto")
if not _HAS_FLEX or mode is False:
return False
if mode is True:
return True
if device is not None and device.type == "cuda":
try:
major, _ = torch.cuda.get_device_capability(device)
return major >= 8
except Exception:
return False
return False
def _force_fp32_io(module): def _force_fp32_io(module):
"""让某个模块在 FP16 模型里以 FP32 计算:输入转 FP32、输出转回 FP16。 """让某个模块在 FP16 模型里以 FP32 计算:输入转 FP32、输出转回 FP16。
用于 keep_fp32_modules 指定的精度敏感层(如最终输出头、LayerNorm)。""" 用于 keep_fp32_modules 指定的精度敏感层(如最终输出头、LayerNorm)。"""
@@ -324,7 +352,14 @@ class RepEncoder(nn.Module):
def scaled_dot_product(q, k, v, extension): def scaled_dot_product(q, k, v, extension):
"""使用 PyTorch SDPA 后端(自动启用 Flash Attention / Memory Efficient Attention""" """注意力分发:
- 若 extension 带 block_mask → FlexAttention 块对角因果(每用户只在自己序列内
做因果注意力,避免对 ~14000 长拼接序列做 O(S²) 稠密注意力,计算量砍数十倍)。
- 否则 → 标准 SDPA(稠密 mask,数学等价、用于回退/对照)。
"""
if extension is not None and extension.get("block_mask") is not None:
return flex_attention(q, k, v, block_mask=extension["block_mask"])
if extension is not None and "mask" in extension: if extension is not None and "mask" in extension:
attn_mask = extension["mask"].to(device=q.device) attn_mask = extension["mask"].to(device=q.device)
else: else:
@@ -369,6 +404,29 @@ class TopKGate(nn.Module):
return topk_idx, topk_score, probs return topk_idx, topk_score, probs
def _smoe_forward_loop(moe, x):
"""原始逐 expert 循环实现(保留作数值等价对照/回退)。"""
B, S, D = x.shape
topk_idx, topk_score, probs = moe.gate(x)
out = torch.zeros_like(x)
x_flat = x.reshape(-1, D)
idx_flat = topk_idx.reshape(-1, moe.k)
score_flat = topk_score.reshape(-1, moe.k)
out_flat = out.reshape(-1, D)
for i in range(moe.num_experts):
mask = (idx_flat == i)
token_idx, k_idx = mask.nonzero(as_tuple=True)
if token_idx.numel() == 0:
continue
selected_x = x_flat[token_idx]
expert_out = moe.experts[i](selected_x)
weight = score_flat[token_idx, k_idx].unsqueeze(-1)
out_flat[token_idx] += expert_out * weight
importance = probs.sum(dim=(0, 1))
moe_loss = (importance.std() / (importance.mean() + 1e-6))
return out, moe_loss
class SMoE(nn.Module): class SMoE(nn.Module):
def __init__(self, d_model, dim_ff, num_experts, k=2): def __init__(self, d_model, dim_ff, num_experts, k=2):
super().__init__() super().__init__()
@@ -380,37 +438,43 @@ class SMoE(nn.Module):
]) ])
self.gate = TopKGate(d_model, num_experts, k=k) self.gate = TopKGate(d_model, num_experts, k=k)
self._stacked = False
def _stack_weights(self):
"""把各 expert 的 fc1/fc2 权重堆叠成单一张量,供批量 matmul。
延迟到首次 forward 调用:此时已完成 expert 合并与 half()/to(device)。"""
self.register_buffer("W1", torch.stack([e.fc1.weight for e in self.experts]).contiguous()) # [E,F,D]
self.register_buffer("b1", torch.stack([e.fc1.bias for e in self.experts]).contiguous()) # [E,F]
self.register_buffer("W2", torch.stack([e.fc2.weight for e in self.experts]).contiguous()) # [E,D,F]
self.register_buffer("b2", torch.stack([e.fc2.bias for e in self.experts]).contiguous()) # [E,D]
self._stacked = True
def forward(self, x): def forward(self, x):
# x: [B,S,D] # x: [B,S,D]
B, S, D = x.shape if not CONFIG.get("vectorize_moe", True):
return _smoe_forward_loop(self, x)
if not self._stacked:
self._stack_weights()
B, S, D = x.shape
topk_idx, topk_score, probs = self.gate(x) topk_idx, topk_score, probs = self.gate(x)
out = torch.zeros_like(x) xf = x.reshape(-1, D) # [N, D]
# 稠密计算所有 expertGPU 友好、无 Python 循环/同步/gather-scatter):
h = torch.einsum("nd,efd->enf", xf, self.W1) + self.b1.unsqueeze(1) # [E,N,F]
h = F.relu(h)
o = torch.einsum("enf,edf->end", h, self.W2) + self.b2.unsqueeze(1) # [E,N,D]
# flatten # 按每个 token 的 top-k 选取并加权(与逐 expert 循环数学等价)
x_flat = x.reshape(-1, D) # [B*S, D] o = o.permute(1, 0, 2) # [N, E, D]
idx_flat = topk_idx.reshape(-1, self.k) # [B*S, k] idx = topk_idx.reshape(-1, self.k) # [N, k]
score_flat = topk_score.reshape(-1, self.k) sc = topk_score.reshape(-1, self.k) # [N, k]
out_flat = out.reshape(-1, D) # 提前 reshape,避免循环内重复 sel = torch.gather(o, 1, idx.unsqueeze(-1).expand(-1, -1, D)) # [N, k, D]
out = (sel * sc.unsqueeze(-1)).sum(dim=1).reshape(B, S, D)
for i in range(self.num_experts):
# 找到被路由到 expert i 的 token
mask = (idx_flat == i) # [B*S, k]
token_idx, k_idx = mask.nonzero(as_tuple=True)
if token_idx.numel() == 0:
continue
selected_x = x_flat[token_idx] # [N, D]
expert_out = self.experts[i](selected_x) # [N, D]
weight = score_flat[token_idx, k_idx].unsqueeze(-1)
out_flat[token_idx] += expert_out * weight
importance = probs.sum(dim=(0, 1)) # [E] importance = probs.sum(dim=(0, 1)) # [E]
moe_loss = (importance.std() / (importance.mean() + 1e-6)) moe_loss = (importance.std() / (importance.mean() + 1e-6))
return out, moe_loss return out, moe_loss
@@ -481,13 +545,28 @@ class CTRModel(nn.Module):
out_mask = torch.tril((a == 0).to(torch.int32)).bool() out_mask = torch.tril((a == 0).to(torch.int32)).bool()
return out_mask return out_mask
def build_block_mask(self, user_offsets, S):
"""FlexAttention 块对角因果 maskq 只能 attend 同一用户且 kv<=q 的位置。"""
lengths = (user_offsets[1:] - user_offsets[:-1]).view(-1)
device = user_offsets.device
doc_id = torch.repeat_interleave(
torch.arange(lengths.numel(), device=device), lengths)
def mask_mod(b, h, q_idx, kv_idx):
return (q_idx >= kv_idx) & (doc_id[q_idx] == doc_id[kv_idx])
return create_block_mask(mask_mod, B=None, H=None, Q_LEN=S, KV_LEN=S, device=device)
def forward(self, batch): def forward(self, batch):
seq_input = self.rep_encoder(batch) seq_input = self.rep_encoder(batch)
seq_mask = self.get_sequence_causal_mask(batch["user_offsets"]) user_offsets = batch["user_offsets"]
encoder_output, moe_loss = self.seq_encoder( if _use_flex(seq_input.device):
x=seq_input, S = seq_input.shape[0] # rep_encoder 输出 [S, D]S=总 token 数
extension={"mask": seq_mask.unsqueeze(0).unsqueeze(0)}, extension = {"block_mask": self.build_block_mask(user_offsets, S)}
) else:
seq_mask = self.get_sequence_causal_mask(user_offsets)
extension = {"mask": seq_mask.unsqueeze(0).unsqueeze(0)}
encoder_output, moe_loss = self.seq_encoder(x=seq_input, extension=extension)
encoder_output = encoder_output.squeeze(0) encoder_output = encoder_output.squeeze(0)
pred = self.linear(encoder_output) pred = self.linear(encoder_output)
pred_logits = torch.clamp(pred, min=-15.0, max=15.0) pred_logits = torch.clamp(pred, min=-15.0, max=15.0)
@@ -570,8 +649,19 @@ def load_model(ckpt_path, device='cuda:0'):
model.to(dev) model.to(dev)
model.eval() model.eval()
print(f"[INFO] Model ready. Device: {dev}")
use_flex = _use_flex(dev)
print(f"[INFO] attention={'FlexAttention(block-causal)' if use_flex else 'SDPA(dense)'}, "
f"moe={'dense' if CONFIG.get('vectorize_moe', True) else 'loop'}")
if CONFIG.get("compile", False):
try:
model = torch.compile(model, dynamic=True)
print("[INFO] torch.compile enabled (dynamic=True)")
except Exception as e:
print(f"[WARNING] torch.compile failed ({e}), running eager")
print(f"[INFO] Model ready. Device: {dev}")
return model, dev return model, dev
+92
View File
@@ -0,0 +1,92 @@
"""Phase B 数值等价测试:新实现 vs 原实现。子进程跑:
%cd /home/aistudio/code
!python tests/test_equiv.py
- MoE 稠密向量化 vs 原逐 expert 循环(CPU/GPU 都可,FP32
- FlexAttention 块对角因果 vs 稠密 SDPA(需 CUDA SM80+,否则自动跳过)
"""
import os
import sys
# baseline 把依赖装在 --target 目录;import 前补 sys.path
for _p in ("/home/aistudio/external-libraries", "/home/aistudio/libraries",
os.path.abspath("../libraries"), os.path.abspath("./libraries")):
if os.path.isdir(_p) and _p not in sys.path:
sys.path.insert(0, _p)
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
import torch
import torch.nn.functional as F
import infer
def _offsets(lengths, device):
offs = [0]
for L in lengths:
offs.append(offs[-1] + L)
return torch.tensor(offs, dtype=torch.long, device=device)
def _dense_causal_mask(offs):
"""同用户 + 因果(tril),与 CTRModel.get_sequence_causal_mask 语义一致。"""
lengths = (offs[1:] - offs[:-1]).view(-1)
idx = torch.repeat_interleave(
torch.arange(lengths.numel(), device=offs.device), lengths)
same = idx.view(1, -1) == idx.view(-1, 1)
causal = torch.tril(torch.ones_like(same, dtype=torch.bool))
return same & causal
def _block_mask(offs, S):
lengths = (offs[1:] - offs[:-1]).view(-1)
doc_id = torch.repeat_interleave(
torch.arange(lengths.numel(), device=offs.device), lengths)
def mask_mod(b, h, q_idx, kv_idx):
return (q_idx >= kv_idx) & (doc_id[q_idx] == doc_id[kv_idx])
return infer.create_block_mask(mask_mod, B=None, H=None, Q_LEN=S, KV_LEN=S,
device=offs.device)
def test_moe_dense_matches_loop():
torch.manual_seed(0)
dev = "cuda" if torch.cuda.is_available() else "cpu"
moe = infer.SMoE(d_model=512, dim_ff=1024, num_experts=8, k=2).to(dev).eval()
x = torch.randn(1, 200, 512, device=dev)
with torch.no_grad():
ref, _ = infer._smoe_forward_loop(moe, x)
infer.CONFIG["vectorize_moe"] = True
new, _ = moe(x)
err = (ref - new).abs().max().item()
assert torch.allclose(ref, new, atol=1e-4, rtol=1e-4), f"MoE 不等价 max err={err:.3e}"
print(f"[PASS] MoE 稠密向量化 == 逐expert循环 (max err={err:.2e}, dev={dev})")
def test_flex_matches_dense_attention():
ok = (torch.cuda.is_available() and infer._HAS_FLEX
and torch.cuda.get_device_capability()[0] >= 8)
if not ok:
print("[SKIP] FlexAttention 等价测试(需 CUDA SM80+,当前环境不满足)")
return
torch.manual_seed(0)
dev = "cuda"
H, Dh = 8, 64
offs = _offsets([10, 25, 7, 40, 18], dev)
S = int(offs[-1])
q = torch.randn(1, H, S, Dh, device=dev)
k = torch.randn(1, H, S, Dh, device=dev)
v = torch.randn(1, H, S, Dh, device=dev)
with torch.no_grad():
dense = infer.scaled_dot_product(q, k, v, {"mask": _dense_causal_mask(offs)[None, None]})
flex = infer.scaled_dot_product(q, k, v, {"block_mask": _block_mask(offs, S)})
err = (dense - flex).abs().max().item()
assert torch.allclose(dense, flex, atol=2e-2, rtol=2e-2), f"Flex 不等价 max err={err:.3e}"
print(f"[PASS] FlexAttention 块对角 == 稠密SDPA (max err={err:.2e})")
if __name__ == "__main__":
test_moe_dense_matches_loop()
test_flex_matches_dense_attention()
print("[DONE] 等价测试结束")