feat(Phase B): FlexAttention 块对角注意力 + MoE 稠密向量化
- scaled_dot_product 分发:block_mask->FlexAttention(每用户仅自身序列内因果, 避免对~14000长拼接序列做O(S²)稠密注意力);否则SDPA稠密(回退/对照)。 - CTRModel.build_block_mask 构造块对角因果mask;_use_flex 在SM80+自动启用。 - SMoE 稠密向量化(einsum批量算所有expert后按top-k gather),消除Python循环/同步; 保留 _smoe_forward_loop 作数值等价对照。CONFIG.vectorize_moe 可切。 - load_model 加可选 torch.compile。 - tests/test_equiv.py:MoE稠密vs循环、Flex vs稠密SDPA 数值等价(无pytest依赖)。 - bench.py 加 --attn/--moe/--compile 便于A800上对比测速。 需 A800(SM80) 实测;CPU/V100 自动回退 SDPA。 Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
+119
-29
@@ -17,6 +17,15 @@ import torch.nn.functional as F
|
||||
from torch.utils.data import Dataset, DataLoader
|
||||
from tqdm import tqdm
|
||||
|
||||
# FlexAttention(块对角因果注意力,需 PyTorch 2.5+ 且 GPU 计算能力 >= 8.0 / Ampere)
|
||||
try:
|
||||
from torch.nn.attention.flex_attention import flex_attention, create_block_mask
|
||||
_HAS_FLEX = True
|
||||
except Exception:
|
||||
flex_attention = None
|
||||
create_block_mask = None
|
||||
_HAS_FLEX = False
|
||||
|
||||
|
||||
# ============================================================
|
||||
# 实验配置开关板
|
||||
@@ -31,9 +40,28 @@ CONFIG = {
|
||||
"signid_mode": "clamp", # "clamp" 或 "modulo":处理超界 sign id 的方式
|
||||
"sync_timing": False, # bench 里设 True,做 torch.cuda.synchronize 真实计时
|
||||
"filter_test_users": True, # 只处理含测试样本的用户(跳过会被丢弃的用户,省算力)
|
||||
"use_flex_attn": "auto", # "auto"(SM80+用flex,否则SDPA) / True / False
|
||||
"vectorize_moe": True, # True=稠密向量化MoE(无Python循环/同步);False=原逐expert循环
|
||||
"compile": False, # 是否 torch.compile(图理干净后再开)
|
||||
}
|
||||
|
||||
|
||||
def _use_flex(device):
|
||||
"""决定是否用 FlexAttention:auto 模式下仅在 SM80+(Ampere/A800)启用。"""
|
||||
mode = CONFIG.get("use_flex_attn", "auto")
|
||||
if not _HAS_FLEX or mode is False:
|
||||
return False
|
||||
if mode is True:
|
||||
return True
|
||||
if device is not None and device.type == "cuda":
|
||||
try:
|
||||
major, _ = torch.cuda.get_device_capability(device)
|
||||
return major >= 8
|
||||
except Exception:
|
||||
return False
|
||||
return False
|
||||
|
||||
|
||||
def _force_fp32_io(module):
|
||||
"""让某个模块在 FP16 模型里以 FP32 计算:输入转 FP32、输出转回 FP16。
|
||||
用于 keep_fp32_modules 指定的精度敏感层(如最终输出头、LayerNorm)。"""
|
||||
@@ -324,7 +352,14 @@ class RepEncoder(nn.Module):
|
||||
|
||||
|
||||
def scaled_dot_product(q, k, v, extension):
|
||||
"""使用 PyTorch SDPA 后端(自动启用 Flash Attention / Memory Efficient Attention)"""
|
||||
"""注意力分发:
|
||||
- 若 extension 带 block_mask → FlexAttention 块对角因果(每用户只在自己序列内
|
||||
做因果注意力,避免对 ~14000 长拼接序列做 O(S²) 稠密注意力,计算量砍数十倍)。
|
||||
- 否则 → 标准 SDPA(稠密 mask,数学等价、用于回退/对照)。
|
||||
"""
|
||||
if extension is not None and extension.get("block_mask") is not None:
|
||||
return flex_attention(q, k, v, block_mask=extension["block_mask"])
|
||||
|
||||
if extension is not None and "mask" in extension:
|
||||
attn_mask = extension["mask"].to(device=q.device)
|
||||
else:
|
||||
@@ -369,6 +404,29 @@ class TopKGate(nn.Module):
|
||||
|
||||
return topk_idx, topk_score, probs
|
||||
|
||||
def _smoe_forward_loop(moe, x):
|
||||
"""原始逐 expert 循环实现(保留作数值等价对照/回退)。"""
|
||||
B, S, D = x.shape
|
||||
topk_idx, topk_score, probs = moe.gate(x)
|
||||
out = torch.zeros_like(x)
|
||||
x_flat = x.reshape(-1, D)
|
||||
idx_flat = topk_idx.reshape(-1, moe.k)
|
||||
score_flat = topk_score.reshape(-1, moe.k)
|
||||
out_flat = out.reshape(-1, D)
|
||||
for i in range(moe.num_experts):
|
||||
mask = (idx_flat == i)
|
||||
token_idx, k_idx = mask.nonzero(as_tuple=True)
|
||||
if token_idx.numel() == 0:
|
||||
continue
|
||||
selected_x = x_flat[token_idx]
|
||||
expert_out = moe.experts[i](selected_x)
|
||||
weight = score_flat[token_idx, k_idx].unsqueeze(-1)
|
||||
out_flat[token_idx] += expert_out * weight
|
||||
importance = probs.sum(dim=(0, 1))
|
||||
moe_loss = (importance.std() / (importance.mean() + 1e-6))
|
||||
return out, moe_loss
|
||||
|
||||
|
||||
class SMoE(nn.Module):
|
||||
def __init__(self, d_model, dim_ff, num_experts, k=2):
|
||||
super().__init__()
|
||||
@@ -380,37 +438,43 @@ class SMoE(nn.Module):
|
||||
])
|
||||
|
||||
self.gate = TopKGate(d_model, num_experts, k=k)
|
||||
self._stacked = False
|
||||
|
||||
def _stack_weights(self):
|
||||
"""把各 expert 的 fc1/fc2 权重堆叠成单一张量,供批量 matmul。
|
||||
延迟到首次 forward 调用:此时已完成 expert 合并与 half()/to(device)。"""
|
||||
self.register_buffer("W1", torch.stack([e.fc1.weight for e in self.experts]).contiguous()) # [E,F,D]
|
||||
self.register_buffer("b1", torch.stack([e.fc1.bias for e in self.experts]).contiguous()) # [E,F]
|
||||
self.register_buffer("W2", torch.stack([e.fc2.weight for e in self.experts]).contiguous()) # [E,D,F]
|
||||
self.register_buffer("b2", torch.stack([e.fc2.bias for e in self.experts]).contiguous()) # [E,D]
|
||||
self._stacked = True
|
||||
|
||||
def forward(self, x):
|
||||
# x: [B,S,D]
|
||||
B, S, D = x.shape
|
||||
if not CONFIG.get("vectorize_moe", True):
|
||||
return _smoe_forward_loop(self, x)
|
||||
|
||||
if not self._stacked:
|
||||
self._stack_weights()
|
||||
|
||||
B, S, D = x.shape
|
||||
topk_idx, topk_score, probs = self.gate(x)
|
||||
|
||||
out = torch.zeros_like(x)
|
||||
xf = x.reshape(-1, D) # [N, D]
|
||||
# 稠密计算所有 expert(GPU 友好、无 Python 循环/同步/gather-scatter):
|
||||
h = torch.einsum("nd,efd->enf", xf, self.W1) + self.b1.unsqueeze(1) # [E,N,F]
|
||||
h = F.relu(h)
|
||||
o = torch.einsum("enf,edf->end", h, self.W2) + self.b2.unsqueeze(1) # [E,N,D]
|
||||
|
||||
# flatten
|
||||
x_flat = x.reshape(-1, D) # [B*S, D]
|
||||
idx_flat = topk_idx.reshape(-1, self.k) # [B*S, k]
|
||||
score_flat = topk_score.reshape(-1, self.k)
|
||||
out_flat = out.reshape(-1, D) # 提前 reshape,避免循环内重复
|
||||
# 按每个 token 的 top-k 选取并加权(与逐 expert 循环数学等价)
|
||||
o = o.permute(1, 0, 2) # [N, E, D]
|
||||
idx = topk_idx.reshape(-1, self.k) # [N, k]
|
||||
sc = topk_score.reshape(-1, self.k) # [N, k]
|
||||
sel = torch.gather(o, 1, idx.unsqueeze(-1).expand(-1, -1, D)) # [N, k, D]
|
||||
out = (sel * sc.unsqueeze(-1)).sum(dim=1).reshape(B, S, D)
|
||||
|
||||
for i in range(self.num_experts):
|
||||
# 找到被路由到 expert i 的 token
|
||||
mask = (idx_flat == i) # [B*S, k]
|
||||
|
||||
token_idx, k_idx = mask.nonzero(as_tuple=True)
|
||||
if token_idx.numel() == 0:
|
||||
continue
|
||||
|
||||
selected_x = x_flat[token_idx] # [N, D]
|
||||
expert_out = self.experts[i](selected_x) # [N, D]
|
||||
weight = score_flat[token_idx, k_idx].unsqueeze(-1)
|
||||
out_flat[token_idx] += expert_out * weight
|
||||
|
||||
importance = probs.sum(dim=(0,1)) # [E]
|
||||
importance = probs.sum(dim=(0, 1)) # [E]
|
||||
moe_loss = (importance.std() / (importance.mean() + 1e-6))
|
||||
|
||||
return out, moe_loss
|
||||
|
||||
|
||||
@@ -481,13 +545,28 @@ class CTRModel(nn.Module):
|
||||
out_mask = torch.tril((a == 0).to(torch.int32)).bool()
|
||||
return out_mask
|
||||
|
||||
def build_block_mask(self, user_offsets, S):
|
||||
"""FlexAttention 块对角因果 mask:q 只能 attend 同一用户且 kv<=q 的位置。"""
|
||||
lengths = (user_offsets[1:] - user_offsets[:-1]).view(-1)
|
||||
device = user_offsets.device
|
||||
doc_id = torch.repeat_interleave(
|
||||
torch.arange(lengths.numel(), device=device), lengths)
|
||||
|
||||
def mask_mod(b, h, q_idx, kv_idx):
|
||||
return (q_idx >= kv_idx) & (doc_id[q_idx] == doc_id[kv_idx])
|
||||
|
||||
return create_block_mask(mask_mod, B=None, H=None, Q_LEN=S, KV_LEN=S, device=device)
|
||||
|
||||
def forward(self, batch):
|
||||
seq_input = self.rep_encoder(batch)
|
||||
seq_mask = self.get_sequence_causal_mask(batch["user_offsets"])
|
||||
encoder_output, moe_loss = self.seq_encoder(
|
||||
x=seq_input,
|
||||
extension={"mask": seq_mask.unsqueeze(0).unsqueeze(0)},
|
||||
)
|
||||
user_offsets = batch["user_offsets"]
|
||||
if _use_flex(seq_input.device):
|
||||
S = seq_input.shape[0] # rep_encoder 输出 [S, D],S=总 token 数
|
||||
extension = {"block_mask": self.build_block_mask(user_offsets, S)}
|
||||
else:
|
||||
seq_mask = self.get_sequence_causal_mask(user_offsets)
|
||||
extension = {"mask": seq_mask.unsqueeze(0).unsqueeze(0)}
|
||||
encoder_output, moe_loss = self.seq_encoder(x=seq_input, extension=extension)
|
||||
encoder_output = encoder_output.squeeze(0)
|
||||
pred = self.linear(encoder_output)
|
||||
pred_logits = torch.clamp(pred, min=-15.0, max=15.0)
|
||||
@@ -570,8 +649,19 @@ def load_model(ckpt_path, device='cuda:0'):
|
||||
|
||||
model.to(dev)
|
||||
model.eval()
|
||||
print(f"[INFO] Model ready. Device: {dev}")
|
||||
|
||||
use_flex = _use_flex(dev)
|
||||
print(f"[INFO] attention={'FlexAttention(block-causal)' if use_flex else 'SDPA(dense)'}, "
|
||||
f"moe={'dense' if CONFIG.get('vectorize_moe', True) else 'loop'}")
|
||||
|
||||
if CONFIG.get("compile", False):
|
||||
try:
|
||||
model = torch.compile(model, dynamic=True)
|
||||
print("[INFO] torch.compile enabled (dynamic=True)")
|
||||
except Exception as e:
|
||||
print(f"[WARNING] torch.compile failed ({e}), running eager")
|
||||
|
||||
print(f"[INFO] Model ready. Device: {dev}")
|
||||
return model, dev
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user