feat(Phase B): FlexAttention 块对角注意力 + MoE 稠密向量化

- scaled_dot_product 分发:block_mask->FlexAttention(每用户仅自身序列内因果,
  避免对~14000长拼接序列做O(S²)稠密注意力);否则SDPA稠密(回退/对照)。
- CTRModel.build_block_mask 构造块对角因果mask;_use_flex 在SM80+自动启用。
- SMoE 稠密向量化(einsum批量算所有expert后按top-k gather),消除Python循环/同步;
  保留 _smoe_forward_loop 作数值等价对照。CONFIG.vectorize_moe 可切。
- load_model 加可选 torch.compile。
- tests/test_equiv.py:MoE稠密vs循环、Flex vs稠密SDPA 数值等价(无pytest依赖)。
- bench.py 加 --attn/--moe/--compile 便于A800上对比测速。

需 A800(SM80) 实测;CPU/V100 自动回退 SDPA。

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
OwnerSunshine530
2026-06-14 23:30:59 +08:00
parent 0a971e67ac
commit c1d8b91fb2
3 changed files with 222 additions and 29 deletions
+119 -29
View File
@@ -17,6 +17,15 @@ import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
# FlexAttention(块对角因果注意力,需 PyTorch 2.5+ 且 GPU 计算能力 >= 8.0 / Ampere)
try:
from torch.nn.attention.flex_attention import flex_attention, create_block_mask
_HAS_FLEX = True
except Exception:
flex_attention = None
create_block_mask = None
_HAS_FLEX = False
# ============================================================
# 实验配置开关板
@@ -31,9 +40,28 @@ CONFIG = {
"signid_mode": "clamp", # "clamp" 或 "modulo":处理超界 sign id 的方式
"sync_timing": False, # bench 里设 True,做 torch.cuda.synchronize 真实计时
"filter_test_users": True, # 只处理含测试样本的用户(跳过会被丢弃的用户,省算力)
"use_flex_attn": "auto", # "auto"(SM80+用flex,否则SDPA) / True / False
"vectorize_moe": True, # True=稠密向量化MoE(无Python循环/同步);False=原逐expert循环
"compile": False, # 是否 torch.compile(图理干净后再开)
}
def _use_flex(device):
    """Decide whether attention should run through FlexAttention.

    The decision is driven by ``CONFIG["use_flex_attn"]`` (default ``"auto"``):

    * FlexAttention failed to import, or the setting is ``False`` -> ``False``.
    * The setting is ``True`` -> ``True``.
    * Anything else (the "auto" behaviour) -> ``True`` only when ``device`` is
      a CUDA device whose compute-capability major version is at least 8.

    Args:
        device: Object with a ``.type`` attribute (e.g. ``torch.device``), or
            ``None``.

    Returns:
        bool: ``True`` when FlexAttention should be used.
    """
    setting = CONFIG.get("use_flex_attn", "auto")
    if setting is False or not _HAS_FLEX:
        return False
    if setting is True:
        return True

    # Auto mode: only CUDA devices are candidates.
    on_cuda = device is not None and device.type == "cuda"
    if not on_cuda:
        return False
    try:
        capability = torch.cuda.get_device_capability(device)
        return capability[0] >= 8
    except Exception:
        # Capability query failed; use the dense SDPA path instead.
        return False
def _force_fp32_io(module):
"""让某个模块在 FP16 模型里以 FP32 计算:输入转 FP32、输出转回 FP16。
用于 keep_fp32_modules 指定的精度敏感层(如最终输出头、LayerNorm)。"""
@@ -324,7 +352,14 @@ class RepEncoder(nn.Module):
def scaled_dot_product(q, k, v, extension):
"""使用 PyTorch SDPA 后端(自动启用 Flash Attention / Memory Efficient Attention)"""
"""注意力分发:
- 若 extension 带 block_mask → FlexAttention 块对角因果(每用户只在自己序列内
做因果注意力,避免对 ~14000 长拼接序列做 O(S²) 稠密注意力,计算量砍数十倍)。
- 否则 → 标准 SDPA(稠密 mask,数学等价、用于回退/对照)。
"""
if extension is not None and extension.get("block_mask") is not None:
return flex_attention(q, k, v, block_mask=extension["block_mask"])
if extension is not None and "mask" in extension:
attn_mask = extension["mask"].to(device=q.device)
else:
@@ -369,6 +404,29 @@ class TopKGate(nn.Module):
return topk_idx, topk_score, probs
def _smoe_forward_loop(moe, x):
    """Reference per-expert loop implementation of the sparse MoE forward pass.

    Kept as the numerical baseline / fallback for the vectorized path.

    Args:
        moe: Object exposing ``gate`` (callable returning
            ``(topk_idx, topk_score, probs)``), ``k``, ``num_experts`` and an
            indexable ``experts`` collection of callables.
        x: Input tensor whose shape unpacks as ``[B, S, D]``.

    Returns:
        tuple: ``(out, moe_loss)``. ``out`` has the same shape as ``x`` and
        holds, for every token, the score-weighted sum of the outputs of its
        selected experts. ``moe_loss`` is ``std / (mean + 1e-6)`` of the gate
        probabilities summed over the first two dimensions.
    """
    B, S, D = x.shape
    topk_idx, topk_score, probs = moe.gate(x)

    x_flat = x.reshape(-1, D)
    idx_flat = topk_idx.reshape(-1, moe.k)
    score_flat = topk_score.reshape(-1, moe.k)
    # Accumulate into a dedicated flat buffer. Allocating zeros_like(x) and
    # reshaping it afterwards is unsafe: for a non-contiguous x the reshape
    # returns a copy, so the accumulated values never reach the returned
    # tensor.
    out_flat = torch.zeros_like(x_flat)

    for i in range(moe.num_experts):
        # Rows (tokens) and top-k slots that were routed to expert i.
        token_idx, k_idx = (idx_flat == i).nonzero(as_tuple=True)
        if token_idx.numel() == 0:
            continue
        expert_out = moe.experts[i](x_flat[token_idx])
        weight = score_flat[token_idx, k_idx].unsqueeze(-1)
        out_flat[token_idx] += expert_out * weight

    importance = probs.sum(dim=(0, 1))
    moe_loss = importance.std() / (importance.mean() + 1e-6)
    return out_flat.reshape(B, S, D), moe_loss
class SMoE(nn.Module):
def __init__(self, d_model, dim_ff, num_experts, k=2):
super().__init__()
@@ -380,37 +438,43 @@ class SMoE(nn.Module):
])
self.gate = TopKGate(d_model, num_experts, k=k)
self._stacked = False
def _stack_weights(self):
    """Stack every expert's fc1/fc2 parameters into single batched tensors.

    Registers four buffers on the module for use by batched matmul/einsum:
    ``W1`` [E, F, D], ``b1`` [E, F], ``W2`` [E, D, F], ``b2`` [E, D], then
    sets ``self._stacked = True``. Intended to be called lazily on the first
    forward pass, after the experts have been finalized and the module has
    been cast / moved to its device.

    The buffers are detached copies registered with ``persistent=False``:
    they are derived from the experts' own parameters, so they must neither
    appear in ``state_dict()`` (duplicate weights in checkpoints, unexpected
    keys on strict loading) nor hold on to an autograd graph.
    """
    stacked_sources = {
        "W1": [e.fc1.weight for e in self.experts],  # [E, F, D]
        "b1": [e.fc1.bias for e in self.experts],    # [E, F]
        "W2": [e.fc2.weight for e in self.experts],  # [E, D, F]
        "b2": [e.fc2.bias for e in self.experts],    # [E, D]
    }
    for name, tensors in stacked_sources.items():
        # torch.stack always yields a fresh contiguous tensor.
        self.register_buffer(name, torch.stack(tensors).detach(), persistent=False)
    self._stacked = True
def forward(self, x):
# x: [B,S,D]
B, S, D = x.shape
if not CONFIG.get("vectorize_moe", True):
return _smoe_forward_loop(self, x)
if not self._stacked:
self._stack_weights()
B, S, D = x.shape
topk_idx, topk_score, probs = self.gate(x)
out = torch.zeros_like(x)
xf = x.reshape(-1, D) # [N, D]
# 稠密计算所有 expert(GPU 友好、无 Python 循环/同步/gather-scatter):
h = torch.einsum("nd,efd->enf", xf, self.W1) + self.b1.unsqueeze(1) # [E,N,F]
h = F.relu(h)
o = torch.einsum("enf,edf->end", h, self.W2) + self.b2.unsqueeze(1) # [E,N,D]
# flatten
x_flat = x.reshape(-1, D) # [B*S, D]
idx_flat = topk_idx.reshape(-1, self.k) # [B*S, k]
score_flat = topk_score.reshape(-1, self.k)
out_flat = out.reshape(-1, D) # 提前 reshape,避免循环内重复
# 按每个 token 的 top-k 选取并加权(与逐 expert 循环数学等价)
o = o.permute(1, 0, 2) # [N, E, D]
idx = topk_idx.reshape(-1, self.k) # [N, k]
sc = topk_score.reshape(-1, self.k) # [N, k]
sel = torch.gather(o, 1, idx.unsqueeze(-1).expand(-1, -1, D)) # [N, k, D]
out = (sel * sc.unsqueeze(-1)).sum(dim=1).reshape(B, S, D)
for i in range(self.num_experts):
# 找到被路由到 expert i 的 token
mask = (idx_flat == i) # [B*S, k]
token_idx, k_idx = mask.nonzero(as_tuple=True)
if token_idx.numel() == 0:
continue
selected_x = x_flat[token_idx] # [N, D]
expert_out = self.experts[i](selected_x) # [N, D]
weight = score_flat[token_idx, k_idx].unsqueeze(-1)
out_flat[token_idx] += expert_out * weight
importance = probs.sum(dim=(0,1)) # [E]
importance = probs.sum(dim=(0, 1)) # [E]
moe_loss = (importance.std() / (importance.mean() + 1e-6))
return out, moe_loss
@@ -481,13 +545,28 @@ class CTRModel(nn.Module):
out_mask = torch.tril((a == 0).to(torch.int32)).bool()
return out_mask
def build_block_mask(self, user_offsets, S):
    """Build the FlexAttention block-diagonal causal mask.

    A query position may attend to a key/value position only when both
    positions belong to the same user segment and the key/value index does
    not exceed the query index.

    Args:
        user_offsets: Tensor of segment boundaries; consecutive differences
            are taken as the per-user segment lengths.
        S: Sequence length, passed as both ``Q_LEN`` and ``KV_LEN``.
            NOTE(review): presumably equals the sum of the segment lengths;
            confirm against the caller.

    Returns:
        The object returned by ``create_block_mask``.
    """
    dev = user_offsets.device
    seg_lengths = (user_offsets[1:] - user_offsets[:-1]).view(-1)
    seg_ids = torch.arange(seg_lengths.numel(), device=dev)
    # owner[p] is the index of the user segment that contains position p.
    owner = torch.repeat_interleave(seg_ids, seg_lengths)

    def mask_mod(b, h, q_idx, kv_idx):
        same_user = owner[q_idx] == owner[kv_idx]
        causal = q_idx >= kv_idx
        return causal & same_user

    return create_block_mask(mask_mod, B=None, H=None, Q_LEN=S, KV_LEN=S, device=dev)
def forward(self, batch):
seq_input = self.rep_encoder(batch)
seq_mask = self.get_sequence_causal_mask(batch["user_offsets"])
encoder_output, moe_loss = self.seq_encoder(
x=seq_input,
extension={"mask": seq_mask.unsqueeze(0).unsqueeze(0)},
)
user_offsets = batch["user_offsets"]
if _use_flex(seq_input.device):
S = seq_input.shape[0] # rep_encoder 输出 [S, D]S=总 token 数
extension = {"block_mask": self.build_block_mask(user_offsets, S)}
else:
seq_mask = self.get_sequence_causal_mask(user_offsets)
extension = {"mask": seq_mask.unsqueeze(0).unsqueeze(0)}
encoder_output, moe_loss = self.seq_encoder(x=seq_input, extension=extension)
encoder_output = encoder_output.squeeze(0)
pred = self.linear(encoder_output)
pred_logits = torch.clamp(pred, min=-15.0, max=15.0)
@@ -570,8 +649,19 @@ def load_model(ckpt_path, device='cuda:0'):
model.to(dev)
model.eval()
print(f"[INFO] Model ready. Device: {dev}")
use_flex = _use_flex(dev)
print(f"[INFO] attention={'FlexAttention(block-causal)' if use_flex else 'SDPA(dense)'}, "
f"moe={'dense' if CONFIG.get('vectorize_moe', True) else 'loop'}")
if CONFIG.get("compile", False):
try:
model = torch.compile(model, dynamic=True)
print("[INFO] torch.compile enabled (dynamic=True)")
except Exception as e:
print(f"[WARNING] torch.compile failed ({e}), running eager")
print(f"[INFO] Model ready. Device: {dev}")
return model, dev