feat/auc-recovery-plan #1
@@ -250,6 +250,11 @@ def _parse_args():
|
|||||||
help="逗号分隔的 keep_fp32_modules,如 linear,rep_encoder.input_norm")
|
help="逗号分隔的 keep_fp32_modules,如 linear,rep_encoder.input_norm")
|
||||||
ap.add_argument("--feasign-none", action="store_true",
|
ap.add_argument("--feasign-none", action="store_true",
|
||||||
help="不截断特征(max_feasign_per_slot=None)")
|
help="不截断特征(max_feasign_per_slot=None)")
|
||||||
|
ap.add_argument("--attn", choices=["auto", "flex", "sdpa"], default=None,
|
||||||
|
help="注意力实现:flex=块对角FlexAttention, sdpa=稠密(原), auto=SM80自动")
|
||||||
|
ap.add_argument("--moe", choices=["dense", "loop"], default=None,
|
||||||
|
help="MoE实现:dense=向量化(新), loop=逐expert循环(原)")
|
||||||
|
ap.add_argument("--compile", action="store_true", help="开启 torch.compile")
|
||||||
ap.add_argument("--rebuild", action="store_true", help="强制重建过滤缓存")
|
ap.add_argument("--rebuild", action="store_true", help="强制重建过滤缓存")
|
||||||
return ap.parse_args()
|
return ap.parse_args()
|
||||||
|
|
||||||
@@ -273,5 +278,11 @@ if __name__ == "__main__":
|
|||||||
cfg["merge_threshold"] = a.merge_th
|
cfg["merge_threshold"] = a.merge_th
|
||||||
if a.keep is not None:
|
if a.keep is not None:
|
||||||
cfg["keep_fp32_modules"] = tuple(x for x in a.keep.split(",") if x)
|
cfg["keep_fp32_modules"] = tuple(x for x in a.keep.split(",") if x)
|
||||||
|
if a.attn is not None:
|
||||||
|
cfg["use_flex_attn"] = {"auto": "auto", "flex": True, "sdpa": False}[a.attn]
|
||||||
|
if a.moe is not None:
|
||||||
|
cfg["vectorize_moe"] = (a.moe == "dense")
|
||||||
|
if a.compile:
|
||||||
|
cfg["compile"] = True
|
||||||
mf = None if a.feasign_none else {1: 2}
|
mf = None if a.feasign_none else {1: 2}
|
||||||
run_once(cfg, batch_size=a.bs, max_batches=a.smoke, max_feasign_per_slot=mf, rebuild=a.rebuild)
|
run_once(cfg, batch_size=a.bs, max_batches=a.smoke, max_feasign_per_slot=mf, rebuild=a.rebuild)
|
||||||
|
|||||||
+119
-29
@@ -17,6 +17,15 @@ import torch.nn.functional as F
|
|||||||
from torch.utils.data import Dataset, DataLoader
|
from torch.utils.data import Dataset, DataLoader
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
# FlexAttention(块对角因果注意力,需 PyTorch 2.5+ 且 GPU 计算能力 >= 8.0 / Ampere)
|
||||||
|
try:
|
||||||
|
from torch.nn.attention.flex_attention import flex_attention, create_block_mask
|
||||||
|
_HAS_FLEX = True
|
||||||
|
except Exception:
|
||||||
|
flex_attention = None
|
||||||
|
create_block_mask = None
|
||||||
|
_HAS_FLEX = False
|
||||||
|
|
||||||
|
|
||||||
# ============================================================
|
# ============================================================
|
||||||
# 实验配置开关板
|
# 实验配置开关板
|
||||||
@@ -31,9 +40,28 @@ CONFIG = {
|
|||||||
"signid_mode": "clamp", # "clamp" 或 "modulo":处理超界 sign id 的方式
|
"signid_mode": "clamp", # "clamp" 或 "modulo":处理超界 sign id 的方式
|
||||||
"sync_timing": False, # bench 里设 True,做 torch.cuda.synchronize 真实计时
|
"sync_timing": False, # bench 里设 True,做 torch.cuda.synchronize 真实计时
|
||||||
"filter_test_users": True, # 只处理含测试样本的用户(跳过会被丢弃的用户,省算力)
|
"filter_test_users": True, # 只处理含测试样本的用户(跳过会被丢弃的用户,省算力)
|
||||||
|
"use_flex_attn": "auto", # "auto"(SM80+用flex,否则SDPA) / True / False
|
||||||
|
"vectorize_moe": True, # True=稠密向量化MoE(无Python循环/同步);False=原逐expert循环
|
||||||
|
"compile": False, # 是否 torch.compile(图理干净后再开)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _use_flex(device):
    """Decide whether attention should be routed through FlexAttention.

    The decision is driven by ``CONFIG["use_flex_attn"]``:

    * ``False``: never use FlexAttention.
    * ``True``: always use it, as long as the FlexAttention import succeeded
      at module load (``_HAS_FLEX``).
    * any other value (normally ``"auto"``): use it only on a CUDA device
      whose compute-capability major version is >= 8.

    Args:
        device: A ``torch.device``, a device string/index such as ``"cuda:0"``
            or ``0``, or ``None``.

    Returns:
        bool: ``True`` when FlexAttention should be used.
    """
    mode = CONFIG.get("use_flex_attn", "auto")
    if not _HAS_FLEX or mode is False:
        return False
    if mode is True:
        return True
    if device is None:
        return False

    try:
        # Callers may hand over a plain device spec ("cuda:0") instead of a
        # torch.device; normalise it so the `.type` lookup below cannot fail
        # with AttributeError.
        if isinstance(device, (str, int)):
            device = torch.device(device)
        if device.type != "cuda":
            return False
        major, _ = torch.cuda.get_device_capability(device)
    except Exception:
        # Deliberate best-effort probe: any failure to parse the device or to
        # query the GPU simply means "fall back to the dense SDPA path".
        return False
    return major >= 8
|
||||||
|
|
||||||
|
|
||||||
def _force_fp32_io(module):
|
def _force_fp32_io(module):
|
||||||
"""让某个模块在 FP16 模型里以 FP32 计算:输入转 FP32、输出转回 FP16。
|
"""让某个模块在 FP16 模型里以 FP32 计算:输入转 FP32、输出转回 FP16。
|
||||||
用于 keep_fp32_modules 指定的精度敏感层(如最终输出头、LayerNorm)。"""
|
用于 keep_fp32_modules 指定的精度敏感层(如最终输出头、LayerNorm)。"""
|
||||||
@@ -324,7 +352,14 @@ class RepEncoder(nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
def scaled_dot_product(q, k, v, extension):
|
def scaled_dot_product(q, k, v, extension):
|
||||||
"""使用 PyTorch SDPA 后端(自动启用 Flash Attention / Memory Efficient Attention)"""
|
"""注意力分发:
|
||||||
|
- 若 extension 带 block_mask → FlexAttention 块对角因果(每用户只在自己序列内
|
||||||
|
做因果注意力,避免对 ~14000 长拼接序列做 O(S²) 稠密注意力,计算量砍数十倍)。
|
||||||
|
- 否则 → 标准 SDPA(稠密 mask,数学等价、用于回退/对照)。
|
||||||
|
"""
|
||||||
|
if extension is not None and extension.get("block_mask") is not None:
|
||||||
|
return flex_attention(q, k, v, block_mask=extension["block_mask"])
|
||||||
|
|
||||||
if extension is not None and "mask" in extension:
|
if extension is not None and "mask" in extension:
|
||||||
attn_mask = extension["mask"].to(device=q.device)
|
attn_mask = extension["mask"].to(device=q.device)
|
||||||
else:
|
else:
|
||||||
@@ -369,6 +404,29 @@ class TopKGate(nn.Module):
|
|||||||
|
|
||||||
return topk_idx, topk_score, probs
|
return topk_idx, topk_score, probs
|
||||||
|
|
||||||
|
def _smoe_forward_loop(moe, x):
|
||||||
|
"""原始逐 expert 循环实现(保留作数值等价对照/回退)。"""
|
||||||
|
B, S, D = x.shape
|
||||||
|
topk_idx, topk_score, probs = moe.gate(x)
|
||||||
|
out = torch.zeros_like(x)
|
||||||
|
x_flat = x.reshape(-1, D)
|
||||||
|
idx_flat = topk_idx.reshape(-1, moe.k)
|
||||||
|
score_flat = topk_score.reshape(-1, moe.k)
|
||||||
|
out_flat = out.reshape(-1, D)
|
||||||
|
for i in range(moe.num_experts):
|
||||||
|
mask = (idx_flat == i)
|
||||||
|
token_idx, k_idx = mask.nonzero(as_tuple=True)
|
||||||
|
if token_idx.numel() == 0:
|
||||||
|
continue
|
||||||
|
selected_x = x_flat[token_idx]
|
||||||
|
expert_out = moe.experts[i](selected_x)
|
||||||
|
weight = score_flat[token_idx, k_idx].unsqueeze(-1)
|
||||||
|
out_flat[token_idx] += expert_out * weight
|
||||||
|
importance = probs.sum(dim=(0, 1))
|
||||||
|
moe_loss = (importance.std() / (importance.mean() + 1e-6))
|
||||||
|
return out, moe_loss
|
||||||
|
|
||||||
|
|
||||||
class SMoE(nn.Module):
|
class SMoE(nn.Module):
|
||||||
def __init__(self, d_model, dim_ff, num_experts, k=2):
|
def __init__(self, d_model, dim_ff, num_experts, k=2):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@@ -380,37 +438,43 @@ class SMoE(nn.Module):
|
|||||||
])
|
])
|
||||||
|
|
||||||
self.gate = TopKGate(d_model, num_experts, k=k)
|
self.gate = TopKGate(d_model, num_experts, k=k)
|
||||||
|
self._stacked = False
|
||||||
|
|
||||||
|
def _stack_weights(self):
|
||||||
|
"""把各 expert 的 fc1/fc2 权重堆叠成单一张量,供批量 matmul。
|
||||||
|
延迟到首次 forward 调用:此时已完成 expert 合并与 half()/to(device)。"""
|
||||||
|
self.register_buffer("W1", torch.stack([e.fc1.weight for e in self.experts]).contiguous()) # [E,F,D]
|
||||||
|
self.register_buffer("b1", torch.stack([e.fc1.bias for e in self.experts]).contiguous()) # [E,F]
|
||||||
|
self.register_buffer("W2", torch.stack([e.fc2.weight for e in self.experts]).contiguous()) # [E,D,F]
|
||||||
|
self.register_buffer("b2", torch.stack([e.fc2.bias for e in self.experts]).contiguous()) # [E,D]
|
||||||
|
self._stacked = True
|
||||||
|
|
||||||
def forward(self, x):
    """Run the mixture-of-experts layer on ``x``.

    When ``CONFIG["vectorize_moe"]`` is falsy the call is delegated to
    ``_smoe_forward_loop``.  Otherwise every expert is evaluated densely on
    every token with two einsum calls over the stacked weights, and each
    token's top-k expert outputs are gathered and combined with the gate
    scores.

    Args:
        x: Input tensor unpacked as ``[B, S, D]``.

    Returns:
        tuple: ``(out, moe_loss)``; ``out`` is reshaped to ``[B, S, D]`` and
        ``moe_loss`` is ``std / (mean + 1e-6)`` of ``probs`` summed over the
        first two dimensions.
    """
    if not CONFIG.get("vectorize_moe", True):
        return _smoe_forward_loop(self, x)

    # Stacked weights are built lazily on the first vectorised call.
    if not self._stacked:
        self._stack_weights()

    batch, seq_len, dim = x.shape
    topk_idx, topk_score, probs = self.gate(x)

    tokens = x.reshape(-1, dim)  # [N, D]

    # Dense evaluation of all experts: no Python loop over experts.
    # NOTE(review): this hard-codes fc1 -> ReLU -> fc2; confirm it matches
    # the expert module's own forward (activation, dropout, ...).
    hidden = F.relu(
        torch.einsum("nd,efd->enf", tokens, self.W1) + self.b1.unsqueeze(1))  # [E,N,F]
    per_expert = (
        torch.einsum("enf,edf->end", hidden, self.W2) + self.b2.unsqueeze(1))  # [E,N,D]

    # Select each token's top-k expert outputs and mix them by gate score.
    per_token = per_expert.permute(1, 0, 2)                   # [N, E, D]
    routes = topk_idx.reshape(-1, self.k)                     # [N, k]
    gate_scores = topk_score.reshape(-1, self.k)              # [N, k]
    gather_index = routes.unsqueeze(-1).expand(-1, -1, dim)   # [N, k, D]
    picked = torch.gather(per_token, 1, gather_index)         # [N, k, D]
    out = (picked * gate_scores.unsqueeze(-1)).sum(dim=1).reshape(batch, seq_len, dim)

    importance = probs.sum(dim=(0, 1))  # [E]
    moe_loss = (importance.std() / (importance.mean() + 1e-6))

    return out, moe_loss
|
||||||
|
|
||||||
|
|
||||||
@@ -481,13 +545,28 @@ class CTRModel(nn.Module):
|
|||||||
out_mask = torch.tril((a == 0).to(torch.int32)).bool()
|
out_mask = torch.tril((a == 0).to(torch.int32)).bool()
|
||||||
return out_mask
|
return out_mask
|
||||||
|
|
||||||
|
def build_block_mask(self, user_offsets, S):
    """Build the FlexAttention block mask for packed per-user sequences.

    A query position may attend to a key position only when both belong to
    the same user segment and the key is not later than the query
    (block-diagonal causal attention).

    Args:
        user_offsets: 1-D tensor of cumulative segment boundaries; consecutive
            differences give the per-user lengths.
        S: Total query/key length passed to ``create_block_mask``.

    Returns:
        The object returned by ``create_block_mask`` for a mask shared by all
        batches and heads (``B=None, H=None``).
    """
    lengths = (user_offsets[1:] - user_offsets[:-1]).view(-1)
    device = user_offsets.device
    doc_id = torch.repeat_interleave(
        torch.arange(lengths.numel(), device=device), lengths)

    # Some PyTorch releases evaluate `mask_mod` on indices rounded up to the
    # sparse block size (128 by default), so `doc_id[q_idx]` could read past
    # the end whenever S is not a multiple of it.  Pad the lookup table with a
    # sentinel id (-1 never equals a real segment id) to keep every index in
    # range; the padded positions lie beyond S and do not affect real tokens.
    block = 128
    padded_len = -(-max(int(S), doc_id.numel()) // block) * block
    pad = padded_len - doc_id.numel()
    if pad > 0:
        doc_id = torch.cat([doc_id, doc_id.new_full((pad,), -1)])

    def mask_mod(b, h, q_idx, kv_idx):
        return (q_idx >= kv_idx) & (doc_id[q_idx] == doc_id[kv_idx])

    return create_block_mask(mask_mod, B=None, H=None, Q_LEN=S, KV_LEN=S, device=device)
|
||||||
|
|
||||||
def forward(self, batch):
|
def forward(self, batch):
|
||||||
seq_input = self.rep_encoder(batch)
|
seq_input = self.rep_encoder(batch)
|
||||||
seq_mask = self.get_sequence_causal_mask(batch["user_offsets"])
|
user_offsets = batch["user_offsets"]
|
||||||
encoder_output, moe_loss = self.seq_encoder(
|
if _use_flex(seq_input.device):
|
||||||
x=seq_input,
|
S = seq_input.shape[0] # rep_encoder 输出 [S, D],S=总 token 数
|
||||||
extension={"mask": seq_mask.unsqueeze(0).unsqueeze(0)},
|
extension = {"block_mask": self.build_block_mask(user_offsets, S)}
|
||||||
)
|
else:
|
||||||
|
seq_mask = self.get_sequence_causal_mask(user_offsets)
|
||||||
|
extension = {"mask": seq_mask.unsqueeze(0).unsqueeze(0)}
|
||||||
|
encoder_output, moe_loss = self.seq_encoder(x=seq_input, extension=extension)
|
||||||
encoder_output = encoder_output.squeeze(0)
|
encoder_output = encoder_output.squeeze(0)
|
||||||
pred = self.linear(encoder_output)
|
pred = self.linear(encoder_output)
|
||||||
pred_logits = torch.clamp(pred, min=-15.0, max=15.0)
|
pred_logits = torch.clamp(pred, min=-15.0, max=15.0)
|
||||||
@@ -570,8 +649,19 @@ def load_model(ckpt_path, device='cuda:0'):
|
|||||||
|
|
||||||
model.to(dev)
|
model.to(dev)
|
||||||
model.eval()
|
model.eval()
|
||||||
print(f"[INFO] Model ready. Device: {dev}")
|
|
||||||
|
|
||||||
|
use_flex = _use_flex(dev)
|
||||||
|
print(f"[INFO] attention={'FlexAttention(block-causal)' if use_flex else 'SDPA(dense)'}, "
|
||||||
|
f"moe={'dense' if CONFIG.get('vectorize_moe', True) else 'loop'}")
|
||||||
|
|
||||||
|
if CONFIG.get("compile", False):
|
||||||
|
try:
|
||||||
|
model = torch.compile(model, dynamic=True)
|
||||||
|
print("[INFO] torch.compile enabled (dynamic=True)")
|
||||||
|
except Exception as e:
|
||||||
|
print(f"[WARNING] torch.compile failed ({e}), running eager")
|
||||||
|
|
||||||
|
print(f"[INFO] Model ready. Device: {dev}")
|
||||||
return model, dev
|
return model, dev
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -0,0 +1,92 @@
|
|||||||
|
"""Phase B 数值等价测试:新实现 vs 原实现。子进程跑:
|
||||||
|
|
||||||
|
%cd /home/aistudio/code
|
||||||
|
!python tests/test_equiv.py
|
||||||
|
|
||||||
|
- MoE 稠密向量化 vs 原逐 expert 循环(CPU/GPU 都可,FP32)
|
||||||
|
- FlexAttention 块对角因果 vs 稠密 SDPA(需 CUDA SM80+,否则自动跳过)
|
||||||
|
"""
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
|
||||||
|
# baseline 把依赖装在 --target 目录;import 前补 sys.path
|
||||||
|
for _p in ("/home/aistudio/external-libraries", "/home/aistudio/libraries",
|
||||||
|
os.path.abspath("../libraries"), os.path.abspath("./libraries")):
|
||||||
|
if os.path.isdir(_p) and _p not in sys.path:
|
||||||
|
sys.path.insert(0, _p)
|
||||||
|
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
|
import infer
|
||||||
|
|
||||||
|
|
||||||
|
def _offsets(lengths, device):
|
||||||
|
offs = [0]
|
||||||
|
for L in lengths:
|
||||||
|
offs.append(offs[-1] + L)
|
||||||
|
return torch.tensor(offs, dtype=torch.long, device=device)
|
||||||
|
|
||||||
|
|
||||||
|
def _dense_causal_mask(offs):
|
||||||
|
"""同用户 + 因果(tril),与 CTRModel.get_sequence_causal_mask 语义一致。"""
|
||||||
|
lengths = (offs[1:] - offs[:-1]).view(-1)
|
||||||
|
idx = torch.repeat_interleave(
|
||||||
|
torch.arange(lengths.numel(), device=offs.device), lengths)
|
||||||
|
same = idx.view(1, -1) == idx.view(-1, 1)
|
||||||
|
causal = torch.tril(torch.ones_like(same, dtype=torch.bool))
|
||||||
|
return same & causal
|
||||||
|
|
||||||
|
|
||||||
|
def _block_mask(offs, S):
    """Build a block-diagonal causal FlexAttention mask from segment offsets.

    Args:
        offs: 1-D tensor of cumulative segment boundaries.
        S: Total query/key length handed to ``infer.create_block_mask``.

    Returns:
        The block mask produced by ``infer.create_block_mask``.
    """
    lengths = (offs[1:] - offs[:-1]).view(-1)
    doc_id = torch.repeat_interleave(
        torch.arange(lengths.numel(), device=offs.device), lengths)

    # Some PyTorch releases evaluate `mask_mod` on indices rounded up to the
    # sparse block size (128 by default); pad the lookup table with a sentinel
    # id so `doc_id[...]` stays in range when S is not a multiple of it.
    block = 128
    padded_len = -(-max(int(S), doc_id.numel()) // block) * block
    pad = padded_len - doc_id.numel()
    if pad > 0:
        doc_id = torch.cat([doc_id, doc_id.new_full((pad,), -1)])

    def mask_mod(b, h, q_idx, kv_idx):
        return (q_idx >= kv_idx) & (doc_id[q_idx] == doc_id[kv_idx])

    return infer.create_block_mask(mask_mod, B=None, H=None, Q_LEN=S, KV_LEN=S,
                                   device=offs.device)
|
||||||
|
|
||||||
|
|
||||||
|
def test_moe_dense_matches_loop():
    """Check that the vectorised MoE forward matches the per-expert loop.

    Runs on CUDA when available, otherwise on CPU, with the module's default
    dtype.  Both the output tensor and the auxiliary loss are compared.
    """
    torch.manual_seed(0)
    dev = "cuda" if torch.cuda.is_available() else "cpu"
    moe = infer.SMoE(d_model=512, dim_ff=1024, num_experts=8, k=2).to(dev).eval()
    x = torch.randn(1, 200, 512, device=dev)

    # The switch lives in a module-level dict shared with everything else that
    # imports `infer`; restore it so this test does not leak state.
    previous = infer.CONFIG.get("vectorize_moe", True)
    try:
        with torch.no_grad():
            ref, ref_loss = infer._smoe_forward_loop(moe, x)
            infer.CONFIG["vectorize_moe"] = True
            new, new_loss = moe(x)
    finally:
        infer.CONFIG["vectorize_moe"] = previous

    err = (ref - new).abs().max().item()
    assert torch.allclose(ref, new, atol=1e-4, rtol=1e-4), f"MoE 不等价 max err={err:.3e}"
    assert torch.allclose(ref_loss, new_loss, atol=1e-4, rtol=1e-4), (
        f"MoE aux loss mismatch: loop={ref_loss.item():.6e} dense={new_loss.item():.6e}")
    print(f"[PASS] MoE 稠密向量化 == 逐expert循环 (max err={err:.2e}, dev={dev})")
|
||||||
|
|
||||||
|
|
||||||
|
def test_flex_matches_dense_attention():
    """Compare block-diagonal FlexAttention against dense masked SDPA.

    Skipped (with a printed notice) unless CUDA is available, the
    FlexAttention import succeeded, and the current GPU reports a
    compute-capability major version >= 8.
    """
    flex_ready = (
        torch.cuda.is_available()
        and infer._HAS_FLEX
        and torch.cuda.get_device_capability()[0] >= 8
    )
    if not flex_ready:
        print("[SKIP] FlexAttention 等价测试(需 CUDA SM80+,当前环境不满足)")
        return

    torch.manual_seed(0)
    dev = "cuda"
    H, Dh = 8, 64
    offs = _offsets([10, 25, 7, 40, 18], dev)
    S = int(offs[-1])

    # Drawn in q, k, v order so the random stream matches a sequential draw.
    q, k, v = (torch.randn(1, H, S, Dh, device=dev) for _ in range(3))

    with torch.no_grad():
        dense_ext = {"mask": _dense_causal_mask(offs)[None, None]}
        dense = infer.scaled_dot_product(q, k, v, dense_ext)
        flex_ext = {"block_mask": _block_mask(offs, S)}
        flex = infer.scaled_dot_product(q, k, v, flex_ext)

    err = (dense - flex).abs().max().item()
    assert torch.allclose(dense, flex, atol=2e-2, rtol=2e-2), f"Flex 不等价 max err={err:.3e}"
    print(f"[PASS] FlexAttention 块对角 == 稠密SDPA (max err={err:.2e})")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # Script entry point: run each equivalence check in order, then report.
    for _case in (test_moe_dense_matches_loop, test_flex_matches_dense_attention):
        _case()
    print("[DONE] 等价测试结束")
|
||||||
Reference in New Issue
Block a user