feat: 嵌套张量变长 flash 注意力(--attn varlen),统一 CONFIG.attn 分发
每用户当独立序列、is_causal 块对角因果,一个 flash 内核处理一 batch 内所有
用户,无稠密mask/无padding浪费/开销远低于FlexAttention。CONFIG.attn∈
{sdpa(默认),flex,varlen};bench --attn varlen;test_equiv 加 varlen 等价测试。
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
+3
-3
@@ -291,8 +291,8 @@ def _parse_args():
|
||||
help="逗号分隔的 keep_fp32_modules,如 linear,rep_encoder.input_norm")
|
||||
ap.add_argument("--feasign-none", action="store_true",
|
||||
help="不截断特征(max_feasign_per_slot=None)")
|
||||
ap.add_argument("--attn", choices=["auto", "flex", "sdpa"], default=None,
|
||||
help="注意力实现:flex=块对角FlexAttention, sdpa=稠密(原), auto=SM80自动")
|
||||
ap.add_argument("--attn", choices=["sdpa", "flex", "varlen"], default=None,
|
||||
help="注意力:sdpa=稠密(原), flex=FlexAttention, varlen=嵌套张量变长flash")
|
||||
ap.add_argument("--moe", choices=["dense", "loop"], default=None,
|
||||
help="MoE实现:dense=向量化(新), loop=逐expert循环(原)")
|
||||
ap.add_argument("--compile", action="store_true", help="开启 torch.compile")
|
||||
@@ -322,7 +322,7 @@ if __name__ == "__main__":
|
||||
if a.keep is not None:
|
||||
cfg["keep_fp32_modules"] = tuple(x for x in a.keep.split(",") if x)
|
||||
if a.attn is not None:
|
||||
cfg["use_flex_attn"] = {"auto": "auto", "flex": True, "sdpa": False}[a.attn]
|
||||
cfg["attn"] = a.attn
|
||||
if a.moe is not None:
|
||||
cfg["vectorize_moe"] = (a.moe == "dense")
|
||||
if a.compile:
|
||||
|
||||
+45
-21
@@ -40,28 +40,27 @@ CONFIG = {
|
||||
"signid_mode": "clamp", # "clamp" 或 "modulo":处理超界 sign id 的方式
|
||||
"sync_timing": False, # bench 里设 True,做 torch.cuda.synchronize 真实计时
|
||||
"filter_test_users": True, # 只处理含测试样本的用户(跳过会被丢弃的用户,省算力)
|
||||
# 实测:FlexAttention + 稠密MoE 在本模型上反而慢 5-6 倍(模型是开销瓶颈非算力瓶颈),
|
||||
# 故默认回到已验证最快的 sdpa + loop;flex/dense 仅作 bench 对照选项。
|
||||
"use_flex_attn": False, # "auto"(SM80+用flex,否则SDPA) / True / False
|
||||
# 实测(A800):sdpa+loop=15.1s 最快;flex/dense/compile/小batch 都更慢。
|
||||
# attn: "sdpa"(稠密mask,默认/已验证) / "flex"(FlexAttention,慢) / "varlen"(嵌套张量变长flash)
|
||||
"attn": "sdpa",
|
||||
"vectorize_moe": False, # True=稠密向量化MoE;False=原逐expert循环(默认,已验证更快)
|
||||
"compile": False, # 是否 torch.compile
|
||||
"compile": False, # 是否 torch.compile(实测慢5×,勿开)
|
||||
}
|
||||
|
||||
|
||||
def _use_flex(device):
|
||||
"""决定是否用 FlexAttention:auto 模式下仅在 SM80+(Ampere/A800)启用。"""
|
||||
mode = CONFIG.get("use_flex_attn", "auto")
|
||||
if not _HAS_FLEX or mode is False:
|
||||
return False
|
||||
if mode is True:
|
||||
return True
|
||||
def _resolve_attn(device):
|
||||
"""解析实际使用的注意力实现。flex 需 SM80+ 且可用,否则回退 sdpa。"""
|
||||
attn = CONFIG.get("attn", "sdpa")
|
||||
if attn == "flex":
|
||||
if not _HAS_FLEX:
|
||||
return "sdpa"
|
||||
if device is not None and device.type == "cuda":
|
||||
try:
|
||||
major, _ = torch.cuda.get_device_capability(device)
|
||||
return major >= 8
|
||||
if torch.cuda.get_device_capability(device)[0] < 8:
|
||||
return "sdpa"
|
||||
except Exception:
|
||||
return False
|
||||
return False
|
||||
return "sdpa"
|
||||
return attn
|
||||
|
||||
|
||||
def _force_fp32_io(module):
|
||||
@@ -353,12 +352,35 @@ class RepEncoder(nn.Module):
|
||||
return rep_emb
|
||||
|
||||
|
||||
def _varlen_attention(q, k, v, user_offsets):
|
||||
"""嵌套张量变长 flash 注意力:每个用户当独立序列、is_causal 块对角因果。
|
||||
一个内核处理一 batch 内所有用户,无稠密 mask、无 padding 浪费、开销低。
|
||||
q,k,v: [1, H, S, Dh];user_offsets: [B+1](S 上的用户边界)。返回 [1, H, S, Dh]。
|
||||
"""
|
||||
_, H, S, Dh = q.shape
|
||||
offs = user_offsets.to(torch.int64)
|
||||
# [1,H,S,Dh] -> [S,H,Dh]
|
||||
qv = q.squeeze(0).transpose(0, 1).contiguous()
|
||||
kv = k.squeeze(0).transpose(0, 1).contiguous()
|
||||
vv = v.squeeze(0).transpose(0, 1).contiguous()
|
||||
# 按用户边界做 jagged 嵌套张量:[B, ragged, H, Dh] -> [B, H, ragged, Dh]
|
||||
qn = torch.nested.nested_tensor_from_jagged(qv, offsets=offs).transpose(1, 2)
|
||||
kn = torch.nested.nested_tensor_from_jagged(kv, offsets=offs).transpose(1, 2)
|
||||
vn = torch.nested.nested_tensor_from_jagged(vv, offsets=offs).transpose(1, 2)
|
||||
out = F.scaled_dot_product_attention(qn, kn, vn, is_causal=True) # [B,H,ragged,Dh]
|
||||
out = out.transpose(1, 2).values() # [S, H, Dh]
|
||||
return out.transpose(0, 1).unsqueeze(0).contiguous() # [1, H, S, Dh]
|
||||
|
||||
|
||||
def scaled_dot_product(q, k, v, extension):
|
||||
"""注意力分发:
|
||||
- 若 extension 带 block_mask → FlexAttention 块对角因果(每用户只在自己序列内
|
||||
做因果注意力,避免对 ~14000 长拼接序列做 O(S²) 稠密注意力,计算量砍数十倍)。
|
||||
- 否则 → 标准 SDPA(稠密 mask,数学等价、用于回退/对照)。
|
||||
- varlen_offsets → 嵌套张量变长 flash(每用户独立序列、块对角因果,开销低)。
|
||||
- block_mask → FlexAttention 块对角因果。
|
||||
- mask(默认) → 标准 SDPA 稠密 mask(数学等价、已验证最快)。
|
||||
"""
|
||||
if extension is not None and extension.get("varlen_offsets") is not None:
|
||||
return _varlen_attention(q, k, v, extension["varlen_offsets"])
|
||||
|
||||
if extension is not None and extension.get("block_mask") is not None:
|
||||
return flex_attention(q, k, v, block_mask=extension["block_mask"])
|
||||
|
||||
@@ -562,7 +584,10 @@ class CTRModel(nn.Module):
|
||||
def forward(self, batch):
|
||||
seq_input = self.rep_encoder(batch)
|
||||
user_offsets = batch["user_offsets"]
|
||||
if _use_flex(seq_input.device):
|
||||
attn = _resolve_attn(seq_input.device)
|
||||
if attn == "varlen":
|
||||
extension = {"varlen_offsets": user_offsets}
|
||||
elif attn == "flex":
|
||||
S = seq_input.shape[0] # rep_encoder 输出 [S, D],S=总 token 数
|
||||
extension = {"block_mask": self.build_block_mask(user_offsets, S)}
|
||||
else:
|
||||
@@ -652,8 +677,7 @@ def load_model(ckpt_path, device='cuda:0'):
|
||||
model.to(dev)
|
||||
model.eval()
|
||||
|
||||
use_flex = _use_flex(dev)
|
||||
print(f"[INFO] attention={'FlexAttention(block-causal)' if use_flex else 'SDPA(dense)'}, "
|
||||
print(f"[INFO] attention={_resolve_attn(dev)}, "
|
||||
f"moe={'dense' if CONFIG.get('vectorize_moe', True) else 'loop'}")
|
||||
|
||||
if CONFIG.get("compile", False):
|
||||
|
||||
@@ -64,11 +64,32 @@ def test_moe_dense_matches_loop():
|
||||
print(f"[PASS] MoE 稠密向量化 == 逐expert循环 (max err={err:.2e}, dev={dev})")
|
||||
|
||||
|
||||
def test_varlen_matches_dense_attention():
    """Check that the varlen (nested-tensor) path agrees with dense masked SDPA.

    Builds random fp16 q/k/v over five users of different lengths, runs
    ``infer.scaled_dot_product`` once with a dense mask and once with
    ``varlen_offsets``, and asserts the two outputs match within fp16
    tolerance. Prints a SKIP line and returns when CUDA is unavailable.
    """
    if not torch.cuda.is_available():
        print("[SKIP] varlen 等价测试(需 CUDA)")
        return

    torch.manual_seed(0)
    device = "cuda"
    num_heads, head_dim = 8, 64
    boundaries = _offsets([10, 25, 7, 40, 18], device)
    total_tokens = int(boundaries[-1])

    # Drawn in q, k, v order so the RNG stream is consumed as before.
    q, k, v = (
        torch.randn(1, num_heads, total_tokens, head_dim,
                    device=device, dtype=torch.float16)
        for _ in range(3)
    )

    with torch.no_grad():
        dense_mask = _dense_causal_mask(boundaries)[None, None]
        dense = infer.scaled_dot_product(q, k, v, {"mask": dense_mask})
        varlen = infer.scaled_dot_product(q, k, v, {"varlen_offsets": boundaries})

    # Compare in fp32 so the error itself is not rounded by fp16.
    dense_f32 = dense.float()
    varlen_f32 = varlen.float()
    err = (dense_f32 - varlen_f32).abs().max().item()
    matches = torch.allclose(dense_f32, varlen_f32, atol=2e-2, rtol=2e-2)
    assert matches, f"varlen 不等价 max err={err:.3e}"
    print(f"[PASS] varlen(嵌套张量) == 稠密SDPA (max err={err:.2e})")
|
||||
|
||||
|
||||
def test_flex_matches_dense_attention():
|
||||
ok = (torch.cuda.is_available() and infer._HAS_FLEX
|
||||
and torch.cuda.get_device_capability()[0] >= 8)
|
||||
if not ok:
|
||||
print("[SKIP] FlexAttention 等价测试(需 CUDA SM80+,当前环境不满足)")
|
||||
print("[SKIP] FlexAttention 等价测试(需 CUDA SM80+)")
|
||||
return
|
||||
torch.manual_seed(0)
|
||||
dev = "cuda"
|
||||
@@ -88,5 +109,6 @@ def test_flex_matches_dense_attention():
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_moe_dense_matches_loop()
|
||||
test_varlen_matches_dense_attention()
|
||||
test_flex_matches_dense_attention()
|
||||
print("[DONE] 等价测试结束")
|
||||
|
||||
Reference in New Issue
Block a user