feat: 分块SDPA注意力(--attn chunked),按用户边界切块降O(S²)

每块~chunk_users个用户、块内因果SDPA(评测端已验证、无嵌套开销),sum(块S²)
远小于总S²。仅1次同步读切分边界。之前本地bs=16快13%被MoE同步吃掉,现MoE
同步已消除,切块红利应全露出。CONFIG.attn=chunked/chunk_users;等价测试已加。

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
OwnerSunshine530
2026-06-15 13:13:13 +08:00
parent 1249bbdbbc
commit 3d28f61a98
3 changed files with 60 additions and 5 deletions
+5 -2
View File
@@ -291,8 +291,9 @@ def _parse_args():
help="逗号分隔的 keep_fp32_modules,如 linear,rep_encoder.input_norm") help="逗号分隔的 keep_fp32_modules,如 linear,rep_encoder.input_norm")
ap.add_argument("--feasign-none", action="store_true", ap.add_argument("--feasign-none", action="store_true",
help="不截断特征(max_feasign_per_slot=None") help="不截断特征(max_feasign_per_slot=None")
ap.add_argument("--attn", choices=["sdpa", "flex", "varlen"], default=None, ap.add_argument("--attn", choices=["sdpa", "chunked", "flex", "varlen"], default=None,
help="注意力:sdpa=稠密(原), flex=FlexAttention, varlen=嵌套张量变长flash") help="注意力:sdpa=稠密, chunked=按用户分块SDPA, flex/varlen=对照")
ap.add_argument("--chunk-users", type=int, default=None, help="chunked 模式每块用户数")
ap.add_argument("--moe", choices=["dense", "loop"], default=None, ap.add_argument("--moe", choices=["dense", "loop"], default=None,
help="MoE实现:dense=向量化(新), loop=逐expert循环(原)") help="MoE实现:dense=向量化(新), loop=逐expert循环(原)")
ap.add_argument("--compile", action="store_true", help="开启 torch.compile") ap.add_argument("--compile", action="store_true", help="开启 torch.compile")
@@ -324,6 +325,8 @@ if __name__ == "__main__":
cfg["keep_fp32_modules"] = tuple(x for x in a.keep.split(",") if x) cfg["keep_fp32_modules"] = tuple(x for x in a.keep.split(",") if x)
if a.attn is not None: if a.attn is not None:
cfg["attn"] = a.attn cfg["attn"] = a.attn
if a.chunk_users is not None:
cfg["chunk_users"] = a.chunk_users
if a.moe is not None: if a.moe is not None:
cfg["vectorize_moe"] = (a.moe == "dense") cfg["vectorize_moe"] = (a.moe == "dense")
if a.emb_fp16: if a.emb_fp16:
+32 -3
View File
@@ -42,8 +42,9 @@ CONFIG = {
"filter_test_users": True, # 只处理含测试样本的用户(跳过会被丢弃的用户,省算力) "filter_test_users": True, # 只处理含测试样本的用户(跳过会被丢弃的用户,省算力)
# 实测:varlen 本地快(10.28s)但评测端慢(148s,嵌套张量构造开销随batch数放大)→已退回。 # 实测:varlen 本地快(10.28s)但评测端慢(148s,嵌套张量构造开销随batch数放大)→已退回。
# sdpa 是评测端验证最快(89.96s/58.86)。flex/compile/小batch/varlen 在评测端都更差。 # sdpa 是评测端验证最快(89.96s/58.86)。flex/compile/小batch/varlen 在评测端都更差。
# attn: "sdpa"(稠密mask,默认/评测最优) / "varlen"(本地快评测慢) / "flex"(慢) # attn: "sdpa"(稠密mask) / "chunked"(按用户分块SDPA,降O(S²)) / "varlen"(评测慢) / "flex"(慢)
"attn": "sdpa", "attn": "sdpa",
"chunk_users": 16, # chunked 模式下每块的用户数(切小拼接序列以降注意力O(S²))
# 稠密MoE去掉了 model(batch) 内唯一的同步点(MoE循环的.nonzero())。若评测计时不 # 稠密MoE去掉了 model(batch) 内唯一的同步点(MoE循环的.nonzero())。若评测计时不
# synchronize,去掉同步点可能让被计时的 model(batch) 大幅缩短。本地force-sync看不出, # synchronize,去掉同步点可能让被计时的 model(batch) 大幅缩短。本地force-sync看不出,
# 须靠提交验证。AUC中性、MoE仅占2%算力故风险极低。 # 须靠提交验证。AUC中性、MoE仅占2%算力故风险极低。
@@ -408,10 +409,19 @@ def _varlen_attention(q, k, v, user_offsets):
def scaled_dot_product(q, k, v, extension): def scaled_dot_product(q, k, v, extension):
"""注意力分发: """注意力分发:
- varlen_offsets → 嵌套张量变长 flash(每用户独立序列、块对角因果,开销)。 - chunks → 按用户分块的 SDPA(每块块内因果,降 O(S²),无嵌套开销)。
- varlen_offsets → 嵌套张量变长 flash(评测端慢,仅对照)。
- block_mask → FlexAttention 块对角因果。 - block_mask → FlexAttention 块对角因果。
- mask(默认) → 标准 SDPA 稠密 mask(数学等价、已验证最快)。 - mask(默认) → 标准 SDPA 稠密 mask(数学等价、已验证最快)。
""" """
if extension is not None and extension.get("chunks") is not None:
outs = []
for s0, s1, m in extension["chunks"]:
outs.append(F.scaled_dot_product_attention(
q[:, :, s0:s1], k[:, :, s0:s1], v[:, :, s0:s1],
attn_mask=m, dropout_p=0.0, is_causal=False))
return torch.cat(outs, dim=2)
if extension is not None and extension.get("varlen_offsets") is not None: if extension is not None and extension.get("varlen_offsets") is not None:
return _varlen_attention(q, k, v, extension["varlen_offsets"]) return _varlen_attention(q, k, v, extension["varlen_offsets"])
@@ -603,6 +613,23 @@ class CTRModel(nn.Module):
out_mask = torch.tril((a == 0).to(torch.int32)).bool() out_mask = torch.tril((a == 0).to(torch.int32)).bool()
return out_mask return out_mask
def build_chunks(self, user_offsets, device):
"""把拼接序列按用户边界切成每块 ~chunk_users 个用户,返回 [(s0,s1,mask), ...]。
每块块内因果,注意力 O(块内S²) 远小于 O(总S²)。仅 1 次同步(读切分边界)。"""
chunk_users = int(CONFIG.get("chunk_users", 16))
B = user_offsets.numel() - 1 # 用户数(读 shape,无同步)
idx = list(range(0, B + 1, chunk_users))
if idx[-1] != B:
idx.append(B)
bounds = user_offsets[idx].tolist() # 1 次同步:取各块的 token 边界
chunks = []
for c in range(len(bounds) - 1):
s0, s1 = bounds[c], bounds[c + 1]
local_off = user_offsets[idx[c]:idx[c + 1] + 1] - s0 # 该块内的用户边界(GPU
m = self.causal_mask_syncfree(local_off, s1 - s0, device).unsqueeze(0).unsqueeze(0)
chunks.append((s0, s1, m))
return chunks
def causal_mask_syncfree(self, user_offsets, S, device): def causal_mask_syncfree(self, user_offsets, S, device):
"""与 get_sequence_causal_mask 等价,但用 searchsorted 求每个位置的用户号, """与 get_sequence_causal_mask 等价,但用 searchsorted 求每个位置的用户号,
避免 repeat_interleave(张量repeats) 的隐式同步。""" 避免 repeat_interleave(张量repeats) 的隐式同步。"""
@@ -628,7 +655,9 @@ class CTRModel(nn.Module):
seq_input = self.rep_encoder(batch) seq_input = self.rep_encoder(batch)
user_offsets = batch["user_offsets"] user_offsets = batch["user_offsets"]
attn = _resolve_attn(seq_input.device) attn = _resolve_attn(seq_input.device)
if attn == "varlen": if attn == "chunked":
extension = {"chunks": self.build_chunks(user_offsets, seq_input.device)}
elif attn == "varlen":
extension = {"varlen_offsets": user_offsets} extension = {"varlen_offsets": user_offsets}
elif attn == "flex": elif attn == "flex":
S = seq_input.shape[0] # rep_encoder 输出 [S, D]S=总 token 数 S = seq_input.shape[0] # rep_encoder 输出 [S, D]S=总 token 数
+23
View File
@@ -64,6 +64,28 @@ def test_moe_dense_matches_loop():
print(f"[PASS] MoE 稠密向量化 == 逐expert循环 (max err={err:.2e}, dev={dev})") print(f"[PASS] MoE 稠密向量化 == 逐expert循环 (max err={err:.2e}, dev={dev})")
def test_chunked_matches_dense_attention():
dev = "cuda" if torch.cuda.is_available() else "cpu"
rep = infer.RepEncoder(vocab_size=100, emb_dim=8, slot_num=28, d_model=8)
seq = infer.TransformerEncoder(d_model=8, n_heads=2, num_layers=1, dim_ff=16)
model = infer.CTRModel(rep, seq, d_model=8).to(dev)
torch.manual_seed(0)
H, Dh = 8, 64
offs = _offsets([10, 25, 7, 40, 18, 5, 33], dev) # 7 个用户
S = int(offs[-1])
q = torch.randn(1, H, S, Dh, device=dev)
k = torch.randn(1, H, S, Dh, device=dev)
v = torch.randn(1, H, S, Dh, device=dev)
with torch.no_grad():
dense = infer.scaled_dot_product(q, k, v, {"mask": _dense_causal_mask(offs)[None, None]})
infer.CONFIG["chunk_users"] = 3 # 每块 3 个用户
chunks = model.build_chunks(offs, torch.device(dev))
chunked = infer.scaled_dot_product(q, k, v, {"chunks": chunks})
err = (dense - chunked).abs().max().item()
assert torch.allclose(dense, chunked, atol=1e-4, rtol=1e-4), f"chunked 不等价 max err={err:.3e}"
print(f"[PASS] chunked SDPA == 稠密SDPA (max err={err:.2e}, dev={dev})")
def test_syncfree_mask_matches(): def test_syncfree_mask_matches():
dev = "cuda" if torch.cuda.is_available() else "cpu" dev = "cuda" if torch.cuda.is_available() else "cpu"
rep = infer.RepEncoder(vocab_size=100, emb_dim=8, slot_num=28, d_model=8) rep = infer.RepEncoder(vocab_size=100, emb_dim=8, slot_num=28, d_model=8)
@@ -148,6 +170,7 @@ if __name__ == "__main__":
test_moe_dense_matches_loop() test_moe_dense_matches_loop()
test_fused_embedding_matches_perslot() test_fused_embedding_matches_perslot()
test_syncfree_mask_matches() test_syncfree_mask_matches()
test_chunked_matches_dense_attention()
test_varlen_matches_dense_attention() test_varlen_matches_dense_attention()
test_flex_matches_dense_attention() test_flex_matches_dense_attention()
print("[DONE] 等价测试结束") print("[DONE] 等价测试结束")