From 3d28f61a98654da58d511cf3bfe3b074689fc924 Mon Sep 17 00:00:00 2001 From: OwnerSunshine530 Date: Mon, 15 Jun 2026 13:13:13 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E5=88=86=E5=9D=97SDPA=E6=B3=A8?= =?UTF-8?q?=E6=84=8F=E5=8A=9B(--attn=20chunked)=EF=BC=8C=E6=8C=89=E7=94=A8?= =?UTF-8?q?=E6=88=B7=E8=BE=B9=E7=95=8C=E5=88=87=E5=9D=97=E9=99=8DO(S=C2=B2?= =?UTF-8?q?)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 每块~chunk_users个用户、块内因果SDPA(评测端已验证、无嵌套开销),sum(块S²) 远小于总S²。仅1次同步读切分边界。之前本地bs=16快13%被MoE同步吃掉,现MoE 同步已消除,切块红利应全露出。CONFIG.attn=chunked/chunk_users;等价测试已加。 Co-Authored-By: Claude Opus 4.8 --- 代码/code/bench.py | 7 +++++-- 代码/code/infer.py | 35 ++++++++++++++++++++++++++++++++--- 代码/code/tests/test_equiv.py | 23 +++++++++++++++++++++++ 3 files changed, 60 insertions(+), 5 deletions(-) diff --git a/代码/code/bench.py b/代码/code/bench.py index adcdb21..c0c4f67 100644 --- a/代码/code/bench.py +++ b/代码/code/bench.py @@ -291,8 +291,9 @@ def _parse_args(): help="逗号分隔的 keep_fp32_modules,如 linear,rep_encoder.input_norm") ap.add_argument("--feasign-none", action="store_true", help="不截断特征(max_feasign_per_slot=None)") - ap.add_argument("--attn", choices=["sdpa", "flex", "varlen"], default=None, - help="注意力:sdpa=稠密(原), flex=FlexAttention, varlen=嵌套张量变长flash") + ap.add_argument("--attn", choices=["sdpa", "chunked", "flex", "varlen"], default=None, + help="注意力:sdpa=稠密, chunked=按用户分块SDPA, flex/varlen=对照") + ap.add_argument("--chunk-users", type=int, default=None, help="chunked 模式每块用户数") ap.add_argument("--moe", choices=["dense", "loop"], default=None, help="MoE实现:dense=向量化(新), loop=逐expert循环(原)") ap.add_argument("--compile", action="store_true", help="开启 torch.compile") @@ -324,6 +325,8 @@ if __name__ == "__main__": cfg["keep_fp32_modules"] = tuple(x for x in a.keep.split(",") if x) if a.attn is not None: cfg["attn"] = a.attn + if a.chunk_users is not None: + cfg["chunk_users"] = a.chunk_users if a.moe is not None: cfg["vectorize_moe"] = (a.moe == "dense") if a.emb_fp16: diff --git a/代码/code/infer.py b/代码/code/infer.py index a5b033c..45eadb0 100644 --- a/代码/code/infer.py +++ b/代码/code/infer.py @@ -42,8 +42,9 @@ CONFIG = { "filter_test_users": True, # 只处理含测试样本的用户(跳过会被丢弃的用户,省算力) # 实测:varlen 本地快(10.28s)但评测端慢(148s,嵌套张量构造开销随batch数放大)→已退回。 # sdpa 是评测端验证最快(89.96s/58.86)。flex/compile/小batch/varlen 在评测端都更差。 - # attn: "sdpa"(稠密mask,默认/评测最优) / "varlen"(本地快评测慢) / "flex"(慢) + # attn: "sdpa"(稠密mask) / "chunked"(按用户分块SDPA,降O(S²)) / "varlen"(评测慢) / "flex"(慢) "attn": "sdpa", + "chunk_users": 16, # chunked 模式下每块的用户数(切小拼接序列以降注意力O(S²)) # 稠密MoE去掉了 model(batch) 内唯一的同步点(MoE循环的.nonzero())。若评测计时不 # synchronize,去掉同步点可能让被计时的 model(batch) 大幅缩短。本地force-sync看不出, # 须靠提交验证。AUC中性、MoE仅占2%算力故风险极低。 @@ -408,10 +409,19 @@ def _varlen_attention(q, k, v, user_offsets): def scaled_dot_product(q, k, v, extension): """注意力分发: - - varlen_offsets → 嵌套张量变长 flash(每用户独立序列、块对角因果,开销低)。 + - chunks → 按用户分块的 SDPA(每块块内因果,降 O(S²),无嵌套开销)。 + - varlen_offsets → 嵌套张量变长 flash(评测端慢,仅对照)。 - block_mask → FlexAttention 块对角因果。 - mask(默认) → 标准 SDPA 稠密 mask(数学等价、已验证最快)。 """ + if extension is not None and extension.get("chunks") is not None: + outs = [] + for s0, s1, m in extension["chunks"]: + outs.append(F.scaled_dot_product_attention( + q[:, :, s0:s1], k[:, :, s0:s1], v[:, :, s0:s1], + attn_mask=m, dropout_p=0.0, is_causal=False)) + return torch.cat(outs, dim=2) + if extension is not None and extension.get("varlen_offsets") is not None: return _varlen_attention(q, k, v, extension["varlen_offsets"]) @@ -603,6 +613,23 @@ class CTRModel(nn.Module): out_mask = torch.tril((a == 0).to(torch.int32)).bool() return out_mask + def build_chunks(self, user_offsets, device): + """把拼接序列按用户边界切成每块 ~chunk_users 个用户,返回 [(s0,s1,mask), ...]。 + 每块块内因果,注意力 O(块内S²) 远小于 O(总S²)。仅 1 次同步(读切分边界)。""" + chunk_users = int(CONFIG.get("chunk_users", 16)) + B = user_offsets.numel() - 1 # 用户数(读 shape,无同步) + idx = list(range(0, B + 1, chunk_users)) + if idx[-1] != B: + idx.append(B) + bounds = user_offsets[idx].tolist() # 1 次同步:取各块的 token 边界 + chunks = [] + for c in range(len(bounds) - 1): + s0, s1 = bounds[c], bounds[c + 1] + local_off = user_offsets[idx[c]:idx[c + 1] + 1] - s0 # 该块内的用户边界(GPU) + m = self.causal_mask_syncfree(local_off, s1 - s0, device).unsqueeze(0).unsqueeze(0) + chunks.append((s0, s1, m)) + return chunks + def causal_mask_syncfree(self, user_offsets, S, device): """与 get_sequence_causal_mask 等价,但用 searchsorted 求每个位置的用户号, 避免 repeat_interleave(张量repeats) 的隐式同步。""" @@ -628,7 +655,9 @@ class CTRModel(nn.Module): seq_input = self.rep_encoder(batch) user_offsets = batch["user_offsets"] attn = _resolve_attn(seq_input.device) - if attn == "varlen": + if attn == "chunked": + extension = {"chunks": self.build_chunks(user_offsets, seq_input.device)} + elif attn == "varlen": extension = {"varlen_offsets": user_offsets} elif attn == "flex": S = seq_input.shape[0] # rep_encoder 输出 [S, D],S=总 token 数 diff --git a/代码/code/tests/test_equiv.py b/代码/code/tests/test_equiv.py index 2cb0d99..4ecd922 100644 --- a/代码/code/tests/test_equiv.py +++ b/代码/code/tests/test_equiv.py @@ -64,6 +64,28 @@ def test_moe_dense_matches_loop(): print(f"[PASS] MoE 稠密向量化 == 逐expert循环 (max err={err:.2e}, dev={dev})") +def test_chunked_matches_dense_attention(): + dev = "cuda" if torch.cuda.is_available() else "cpu" + rep = infer.RepEncoder(vocab_size=100, emb_dim=8, slot_num=28, d_model=8) + seq = infer.TransformerEncoder(d_model=8, n_heads=2, num_layers=1, dim_ff=16) + model = infer.CTRModel(rep, seq, d_model=8).to(dev) + torch.manual_seed(0) + H, Dh = 8, 64 + offs = _offsets([10, 25, 7, 40, 18, 5, 33], dev) # 7 个用户 + S = int(offs[-1]) + q = torch.randn(1, H, S, Dh, device=dev) + k = torch.randn(1, H, S, Dh, device=dev) + v = torch.randn(1, H, S, Dh, device=dev) + with torch.no_grad(): + dense = infer.scaled_dot_product(q, k, v, {"mask": _dense_causal_mask(offs)[None, None]}) + infer.CONFIG["chunk_users"] = 3 # 每块 3 个用户 + chunks = model.build_chunks(offs, torch.device(dev)) + chunked = infer.scaled_dot_product(q, k, v, {"chunks": chunks}) + err = (dense - chunked).abs().max().item() + assert torch.allclose(dense, chunked, atol=1e-4, rtol=1e-4), f"chunked 不等价 max err={err:.3e}" + print(f"[PASS] chunked SDPA == 稠密SDPA (max err={err:.2e}, dev={dev})") + + def test_syncfree_mask_matches(): dev = "cuda" if torch.cuda.is_available() else "cpu" rep = infer.RepEncoder(vocab_size=100, emb_dim=8, slot_num=28, d_model=8) @@ -148,6 +170,7 @@ if __name__ == "__main__": test_moe_dense_matches_loop() test_fused_embedding_matches_perslot() test_syncfree_mask_matches() + test_chunked_matches_dense_attention() test_varlen_matches_dense_attention() test_flex_matches_dense_attention() print("[DONE] 等价测试结束")