feat/auc-recovery-plan #1

Merged
Serendipity merged 20 commits from feat/auc-recovery-plan into main 2026-06-15 12:33:32 +08:00
3 changed files with 74 additions and 28 deletions
Showing only changes of commit 7791674a32 - Show all commits
+3 -3
View File
@@ -291,8 +291,8 @@ def _parse_args():
help="逗号分隔的 keep_fp32_modules,如 linear,rep_encoder.input_norm")
ap.add_argument("--feasign-none", action="store_true",
help="不截断特征(max_feasign_per_slot=None")
ap.add_argument("--attn", choices=["auto", "flex", "sdpa"], default=None,
help="注意力实现flex=块对角FlexAttention, sdpa=稠密(原), auto=SM80自动")
ap.add_argument("--attn", choices=["sdpa", "flex", "varlen"], default=None,
help="注意力:sdpa=稠密(原), flex=FlexAttention, varlen=嵌套张量变长flash")
ap.add_argument("--moe", choices=["dense", "loop"], default=None,
help="MoE实现:dense=向量化(新), loop=逐expert循环(原)")
ap.add_argument("--compile", action="store_true", help="开启 torch.compile")
@@ -322,7 +322,7 @@ if __name__ == "__main__":
if a.keep is not None:
cfg["keep_fp32_modules"] = tuple(x for x in a.keep.split(",") if x)
if a.attn is not None:
cfg["use_flex_attn"] = {"auto": "auto", "flex": True, "sdpa": False}[a.attn]
cfg["attn"] = a.attn
if a.moe is not None:
cfg["vectorize_moe"] = (a.moe == "dense")
if a.compile:
+48 -24
View File
@@ -40,28 +40,27 @@ CONFIG = {
"signid_mode": "clamp", # "clamp" 或 "modulo":处理超界 sign id 的方式
"sync_timing": False, # bench 里设 True,做 torch.cuda.synchronize 真实计时
"filter_test_users": True, # 只处理含测试样本的用户(跳过会被丢弃的用户,省算力)
# 实测FlexAttention + 稠密MoE 在本模型上反而慢 5-6 倍(模型是开销瓶颈非算力瓶颈),
# 故默认回到已验证最快的 sdpa + loopflex/dense 仅作 bench 对照选项。
"use_flex_attn": False, # "auto"(SM80+用flex,否则SDPA) / True / False
# 实测(A800)sdpa+loop=15.1s 最快;flex/dense/compile/小batch 都更慢。
# attn: "sdpa"(稠密mask,默认/已验证) / "flex"(FlexAttention,慢) / "varlen"(嵌套张量变长flash)
"attn": "sdpa",
"vectorize_moe": False, # True=稠密向量化MoEFalse=原逐expert循环(默认,已验证更快)
"compile": False, # 是否 torch.compile
"compile": False, # 是否 torch.compile(实测慢5×,勿开)
}
def _use_flex(device):
"""决定是否用 FlexAttentionauto 模式下仅在 SM80+Ampere/A800)启用"""
mode = CONFIG.get("use_flex_attn", "auto")
if not _HAS_FLEX or mode is False:
return False
if mode is True:
return True
if device is not None and device.type == "cuda":
try:
major, _ = torch.cuda.get_device_capability(device)
return major >= 8
except Exception:
return False
return False
def _resolve_attn(device):
"""解析实际使用的注意力实现。flex 需 SM80+ 且可用,否则回退 sdpa"""
attn = CONFIG.get("attn", "sdpa")
if attn == "flex":
if not _HAS_FLEX:
return "sdpa"
if device is not None and device.type == "cuda":
try:
if torch.cuda.get_device_capability(device)[0] < 8:
return "sdpa"
except Exception:
return "sdpa"
return attn
def _force_fp32_io(module):
@@ -353,12 +352,35 @@ class RepEncoder(nn.Module):
return rep_emb
def _varlen_attention(q, k, v, user_offsets):
"""嵌套张量变长 flash 注意力:每个用户当独立序列、is_causal 块对角因果。
一个内核处理一 batch 内所有用户,无稠密 mask、无 padding 浪费、开销低。
q,k,v: [1, H, S, Dh]user_offsets: [B+1]S 上的用户边界)。返回 [1, H, S, Dh]。
"""
_, H, S, Dh = q.shape
offs = user_offsets.to(torch.int64)
# [1,H,S,Dh] -> [S,H,Dh]
qv = q.squeeze(0).transpose(0, 1).contiguous()
kv = k.squeeze(0).transpose(0, 1).contiguous()
vv = v.squeeze(0).transpose(0, 1).contiguous()
# 按用户边界做 jagged 嵌套张量:[B, ragged, H, Dh] -> [B, H, ragged, Dh]
qn = torch.nested.nested_tensor_from_jagged(qv, offsets=offs).transpose(1, 2)
kn = torch.nested.nested_tensor_from_jagged(kv, offsets=offs).transpose(1, 2)
vn = torch.nested.nested_tensor_from_jagged(vv, offsets=offs).transpose(1, 2)
out = F.scaled_dot_product_attention(qn, kn, vn, is_causal=True) # [B,H,ragged,Dh]
out = out.transpose(1, 2).values() # [S, H, Dh]
return out.transpose(0, 1).unsqueeze(0).contiguous() # [1, H, S, Dh]
def scaled_dot_product(q, k, v, extension):
"""注意力分发:
- 若 extension 带 block_mask → FlexAttention 块对角因果(每用户只在自己序列内
做因果注意力,避免对 ~14000 长拼接序列做 O(S²) 稠密注意力,计算量砍数十倍)
- 否则 → 标准 SDPA稠密 mask数学等价、用于回退/对照)。
- varlen_offsets → 嵌套张量变长 flash(每用户独立序列、块对角因果,开销低)。
- block_mask → FlexAttention 块对角因果
- mask(默认) → 标准 SDPA 稠密 mask数学等价、已验证最快)。
"""
if extension is not None and extension.get("varlen_offsets") is not None:
return _varlen_attention(q, k, v, extension["varlen_offsets"])
if extension is not None and extension.get("block_mask") is not None:
return flex_attention(q, k, v, block_mask=extension["block_mask"])
@@ -562,7 +584,10 @@ class CTRModel(nn.Module):
def forward(self, batch):
seq_input = self.rep_encoder(batch)
user_offsets = batch["user_offsets"]
if _use_flex(seq_input.device):
attn = _resolve_attn(seq_input.device)
if attn == "varlen":
extension = {"varlen_offsets": user_offsets}
elif attn == "flex":
S = seq_input.shape[0] # rep_encoder 输出 [S, D]S=总 token 数
extension = {"block_mask": self.build_block_mask(user_offsets, S)}
else:
@@ -652,8 +677,7 @@ def load_model(ckpt_path, device='cuda:0'):
model.to(dev)
model.eval()
use_flex = _use_flex(dev)
print(f"[INFO] attention={'FlexAttention(block-causal)' if use_flex else 'SDPA(dense)'}, "
print(f"[INFO] attention={_resolve_attn(dev)}, "
f"moe={'dense' if CONFIG.get('vectorize_moe', True) else 'loop'}")
if CONFIG.get("compile", False):
+23 -1
View File
@@ -64,11 +64,32 @@ def test_moe_dense_matches_loop():
print(f"[PASS] MoE 稠密向量化 == 逐expert循环 (max err={err:.2e}, dev={dev})")
def test_varlen_matches_dense_attention():
if not torch.cuda.is_available():
print("[SKIP] varlen 等价测试(需 CUDA")
return
torch.manual_seed(0)
dev = "cuda"
H, Dh = 8, 64
offs = _offsets([10, 25, 7, 40, 18], dev)
S = int(offs[-1])
q = torch.randn(1, H, S, Dh, device=dev, dtype=torch.float16)
k = torch.randn(1, H, S, Dh, device=dev, dtype=torch.float16)
v = torch.randn(1, H, S, Dh, device=dev, dtype=torch.float16)
with torch.no_grad():
dense = infer.scaled_dot_product(q, k, v, {"mask": _dense_causal_mask(offs)[None, None]})
varlen = infer.scaled_dot_product(q, k, v, {"varlen_offsets": offs})
err = (dense.float() - varlen.float()).abs().max().item()
assert torch.allclose(dense.float(), varlen.float(), atol=2e-2, rtol=2e-2), \
f"varlen 不等价 max err={err:.3e}"
print(f"[PASS] varlen(嵌套张量) == 稠密SDPA (max err={err:.2e})")
def test_flex_matches_dense_attention():
ok = (torch.cuda.is_available() and infer._HAS_FLEX
and torch.cuda.get_device_capability()[0] >= 8)
if not ok:
print("[SKIP] FlexAttention 等价测试(需 CUDA SM80+,当前环境不满足")
print("[SKIP] FlexAttention 等价测试(需 CUDA SM80+")
return
torch.manual_seed(0)
dev = "cuda"
@@ -88,5 +109,6 @@ def test_flex_matches_dense_attention():
if __name__ == "__main__":
test_moe_dense_matches_loop()
test_varlen_matches_dense_attention()
test_flex_matches_dense_attention()
print("[DONE] 等价测试结束")