feat: 嵌套张量变长 flash 注意力(--attn varlen),统一 CONFIG.attn 分发

每用户当独立序列、is_causal 块对角因果,一个 flash 内核处理一 batch 内所有
用户,无稠密mask/无padding浪费/开销远低于FlexAttention。CONFIG.attn∈
{sdpa(默认),flex,varlen};bench --attn varlen;test_equiv 加 varlen 等价测试。

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
OwnerSunshine530
2026-06-15 09:06:11 +08:00
parent 9eaf5f5511
commit 7791674a32
3 changed files with 74 additions and 28 deletions
+48 -24
View File
@@ -40,28 +40,27 @@ CONFIG = {
"signid_mode": "clamp", # "clamp" 或 "modulo":处理超界 sign id 的方式
"sync_timing": False, # bench 里设 True,做 torch.cuda.synchronize 真实计时
"filter_test_users": True, # 只处理含测试样本的用户(跳过会被丢弃的用户,省算力)
# 实测:FlexAttention + 稠密MoE 在本模型上反而慢 5-6 倍(模型是开销瓶颈非算力瓶颈),
# 故默认回到已验证最快的 sdpa + loop;flex/dense 仅作 bench 对照选项。
"use_flex_attn": False, # "auto"(SM80+用flex,否则SDPA) / True / False
# 实测(A800)sdpa+loop=15.1s 最快;flex/dense/compile/小batch 都更慢。
# attn: "sdpa"(稠密mask,默认/已验证) / "flex"(FlexAttention,慢) / "varlen"(嵌套张量变长flash)
"attn": "sdpa",
"vectorize_moe": False, # True=稠密向量化MoEFalse=原逐expert循环(默认,已验证更快)
"compile": False, # 是否 torch.compile
"compile": False, # 是否 torch.compile(实测慢5×,勿开)
}
def _use_flex(device):
    """Decide whether FlexAttention should be used.

    Reads ``CONFIG["use_flex_attn"]`` (defaulting to ``"auto"``):

    * ``False`` or FlexAttention unavailable (``_HAS_FLEX`` falsy) -> ``False``
    * ``True`` -> ``True``
    * anything else ("auto") -> ``True`` only on a CUDA device whose compute
      capability major version is >= 8 (SM80+, e.g. Ampere/A800).

    Args:
        device: a device object exposing ``.type`` (or ``None``).

    Returns:
        bool: whether the FlexAttention path should be taken.
    """
    setting = CONFIG.get("use_flex_attn", "auto")

    if not _HAS_FLEX:
        return False
    if setting is False:
        return False
    if setting is True:
        return True

    # "auto": only CUDA devices are eligible.
    on_cuda = device is not None and device.type == "cuda"
    if not on_cuda:
        return False

    # Best-effort capability probe; any failure means "do not use flex".
    try:
        major = torch.cuda.get_device_capability(device)[0]
    except Exception:
        return False
    return major >= 8
def _resolve_attn(device):
    """Resolve the attention implementation that will actually be used.

    Reads ``CONFIG["attn"]`` (defaulting to ``"sdpa"``). Any value other than
    ``"flex"`` is returned unchanged. ``"flex"`` is downgraded to ``"sdpa"``
    when FlexAttention is unavailable (``_HAS_FLEX`` falsy), or when the
    device is CUDA and its compute capability major version is below 8 (or
    the capability query fails).

    Args:
        device: a device object exposing ``.type`` (or ``None``).

    Returns:
        str: the attention implementation name to dispatch on.
    """
    requested = CONFIG.get("attn", "sdpa")

    # Only the flex path has hardware/availability requirements.
    if requested != "flex":
        return requested
    if not _HAS_FLEX:
        return "sdpa"
    # The capability check is only performed for CUDA devices.
    if device is None or device.type != "cuda":
        return requested

    # Best-effort probe: any failure falls back to sdpa.
    try:
        major = torch.cuda.get_device_capability(device)[0]
    except Exception:
        return "sdpa"
    return "sdpa" if major < 8 else requested
def _force_fp32_io(module):
@@ -353,12 +352,35 @@ class RepEncoder(nn.Module):
return rep_emb
def _varlen_attention(q, k, v, user_offsets):
"""嵌套张量变长 flash 注意力:每个用户当独立序列、is_causal 块对角因果。
一个内核处理一 batch 内所有用户,无稠密 mask、无 padding 浪费、开销低。
q,k,v: [1, H, S, Dh]user_offsets: [B+1]S 上的用户边界)。返回 [1, H, S, Dh]。
"""
_, H, S, Dh = q.shape
offs = user_offsets.to(torch.int64)
# [1,H,S,Dh] -> [S,H,Dh]
qv = q.squeeze(0).transpose(0, 1).contiguous()
kv = k.squeeze(0).transpose(0, 1).contiguous()
vv = v.squeeze(0).transpose(0, 1).contiguous()
# 按用户边界做 jagged 嵌套张量:[B, ragged, H, Dh] -> [B, H, ragged, Dh]
qn = torch.nested.nested_tensor_from_jagged(qv, offsets=offs).transpose(1, 2)
kn = torch.nested.nested_tensor_from_jagged(kv, offsets=offs).transpose(1, 2)
vn = torch.nested.nested_tensor_from_jagged(vv, offsets=offs).transpose(1, 2)
out = F.scaled_dot_product_attention(qn, kn, vn, is_causal=True) # [B,H,ragged,Dh]
out = out.transpose(1, 2).values() # [S, H, Dh]
return out.transpose(0, 1).unsqueeze(0).contiguous() # [1, H, S, Dh]
def scaled_dot_product(q, k, v, extension):
"""注意力分发:
- 若 extension 带 block_mask → FlexAttention 块对角因果(每用户只在自己序列内
做因果注意力,避免对 ~14000 长拼接序列做 O(S²) 稠密注意力,计算量砍数十倍)
- 否则 → 标准 SDPA稠密 mask数学等价、用于回退/对照)。
- varlen_offsets → 嵌套张量变长 flash(每用户独立序列、块对角因果,开销低)。
- block_mask → FlexAttention 块对角因果
- mask(默认) → 标准 SDPA 稠密 mask数学等价、已验证最快)。
"""
if extension is not None and extension.get("varlen_offsets") is not None:
return _varlen_attention(q, k, v, extension["varlen_offsets"])
if extension is not None and extension.get("block_mask") is not None:
return flex_attention(q, k, v, block_mask=extension["block_mask"])
@@ -562,7 +584,10 @@ class CTRModel(nn.Module):
def forward(self, batch):
seq_input = self.rep_encoder(batch)
user_offsets = batch["user_offsets"]
if _use_flex(seq_input.device):
attn = _resolve_attn(seq_input.device)
if attn == "varlen":
extension = {"varlen_offsets": user_offsets}
elif attn == "flex":
S = seq_input.shape[0] # rep_encoder 输出 [S, D]S=总 token 数
extension = {"block_mask": self.build_block_mask(user_offsets, S)}
else:
@@ -652,8 +677,7 @@ def load_model(ckpt_path, device='cuda:0'):
model.to(dev)
model.eval()
use_flex = _use_flex(dev)
print(f"[INFO] attention={'FlexAttention(block-causal)' if use_flex else 'SDPA(dense)'}, "
print(f"[INFO] attention={_resolve_attn(dev)}, "
f"moe={'dense' if CONFIG.get('vectorize_moe', True) else 'loop'}")
if CONFIG.get("compile", False):