feat: 分块SDPA注意力(--attn chunked),按用户边界切块降O(S²)

每块~chunk_users个用户、块内因果SDPA(评测端已验证、无嵌套开销),sum(块S²)
远小于总S²。仅1次同步读切分边界。之前本地bs=16快13%被MoE同步吃掉,现MoE
同步已消除,切块红利应全露出。CONFIG.attn=chunked/chunk_users;等价测试已加。

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
OwnerSunshine530
2026-06-15 13:13:13 +08:00
parent 1249bbdbbc
commit 3d28f61a98
3 changed files with 60 additions and 5 deletions
+32 -3
View File
@@ -42,8 +42,9 @@ CONFIG = {
"filter_test_users": True, # 只处理含测试样本的用户(跳过会被丢弃的用户,省算力)
# 实测:varlen 本地快(10.28s)但评测端慢(148s,嵌套张量构造开销随batch数放大)→已退回。
# sdpa 是评测端验证最快(89.96s/58.86)。flex/compile/小batch/varlen 在评测端都更差。
# attn: "sdpa"(稠密mask,默认/评测最优) / "varlen"(本地快评测慢) / "flex"(慢)
# attn: "sdpa"(稠密mask) / "chunked"(按用户分块SDPA,降O(S²)) / "varlen"(评测慢) / "flex"(慢)
"attn": "sdpa",
"chunk_users": 16, # chunked 模式下每块的用户数(切小拼接序列以降注意力O(S²))
# 稠密MoE去掉了 model(batch) 内唯一的同步点(MoE循环的.nonzero())。若评测计时不
# synchronize,去掉同步点可能让被计时的 model(batch) 大幅缩短。本地force-sync看不出,
# 须靠提交验证。AUC中性、MoE仅占2%算力故风险极低。
@@ -408,10 +409,19 @@ def _varlen_attention(q, k, v, user_offsets):
def scaled_dot_product(q, k, v, extension):
"""注意力分发:
- varlen_offsets → 嵌套张量变长 flash(每用户独立序列、块对角因果,开销)。
- chunks → 按用户分块的 SDPA(每块块内因果,降 O(S²),无嵌套开销)。
- varlen_offsets → 嵌套张量变长 flash(评测端慢,仅对照)。
- block_mask → FlexAttention 块对角因果。
- mask(默认) → 标准 SDPA 稠密 mask(数学等价、已验证最快)。
"""
if extension is not None and extension.get("chunks") is not None:
outs = []
for s0, s1, m in extension["chunks"]:
outs.append(F.scaled_dot_product_attention(
q[:, :, s0:s1], k[:, :, s0:s1], v[:, :, s0:s1],
attn_mask=m, dropout_p=0.0, is_causal=False))
return torch.cat(outs, dim=2)
if extension is not None and extension.get("varlen_offsets") is not None:
return _varlen_attention(q, k, v, extension["varlen_offsets"])
@@ -603,6 +613,23 @@ class CTRModel(nn.Module):
out_mask = torch.tril((a == 0).to(torch.int32)).bool()
return out_mask
def build_chunks(self, user_offsets, device):
"""把拼接序列按用户边界切成每块 ~chunk_users 个用户,返回 [(s0,s1,mask), ...]。
每块块内因果,注意力 O(块内S²) 远小于 O(总S²)。仅 1 次同步(读切分边界)。"""
chunk_users = int(CONFIG.get("chunk_users", 16))
B = user_offsets.numel() - 1 # 用户数(读 shape,无同步)
idx = list(range(0, B + 1, chunk_users))
if idx[-1] != B:
idx.append(B)
bounds = user_offsets[idx].tolist() # 1 次同步:取各块的 token 边界
chunks = []
for c in range(len(bounds) - 1):
s0, s1 = bounds[c], bounds[c + 1]
local_off = user_offsets[idx[c]:idx[c + 1] + 1] - s0 # 该块内的用户边界(GPU
m = self.causal_mask_syncfree(local_off, s1 - s0, device).unsqueeze(0).unsqueeze(0)
chunks.append((s0, s1, m))
return chunks
def causal_mask_syncfree(self, user_offsets, S, device):
"""与 get_sequence_causal_mask 等价,但用 searchsorted 求每个位置的用户号,
避免 repeat_interleave(张量repeats) 的隐式同步。"""
@@ -628,7 +655,9 @@ class CTRModel(nn.Module):
seq_input = self.rep_encoder(batch)
user_offsets = batch["user_offsets"]
attn = _resolve_attn(seq_input.device)
if attn == "varlen":
if attn == "chunked":
extension = {"chunks": self.build_chunks(user_offsets, seq_input.device)}
elif attn == "varlen":
extension = {"varlen_offsets": user_offsets}
elif attn == "flex":
S = seq_input.shape[0] # rep_encoder 输出 [S, D]S=总 token 数