From cb2913cda851a2724397e5af449a88b20f006217 Mon Sep 17 00:00:00 2001 From: OwnerSunshine530 Date: Mon, 15 Jun 2026 12:09:40 +0800 Subject: [PATCH] =?UTF-8?q?perf:=20searchsorted=20=E6=9E=84=E9=80=A0?= =?UTF-8?q?=E5=9B=A0=E6=9E=9Cmask=EF=BC=8C=E6=B6=88=E9=99=A4=E6=9C=80?= =?UTF-8?q?=E5=90=8E=E4=B8=80=E4=B8=AA=E5=90=8C=E6=AD=A5=E7=82=B9(repeat?= =?UTF-8?q?=5Finterleave=E5=BC=A0=E9=87=8Frepeats)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit dense MoE 去掉MoE的nonzero同步省了评测20s;embedding融合(无同步)只省1s ->真正的杠杆是消同步点。mask构造的repeat_interleave(lengths张量)是model(batch) 内最后一个同步点,改用searchsorted求doc_id(输出size已知,无同步)。等价测试已加。 Co-Authored-By: Claude Opus 4.8 --- 代码/code/infer.py | 18 ++++++++++++++++-- 代码/code/tests/test_equiv.py | 14 ++++++++++++++ 2 files changed, 30 insertions(+), 2 deletions(-) diff --git a/代码/code/infer.py b/代码/code/infer.py index ff1b64b..7564797 100644 --- a/代码/code/infer.py +++ b/代码/code/infer.py @@ -49,6 +49,7 @@ CONFIG = { # 须靠提交验证。AUC中性、MoE仅占2%算力故风险极低。 "vectorize_moe": True, # True=稠密向量化MoE(无同步点);False=原逐expert循环(.nonzero同步) "fuse_embedding": True, # True=28个slot的查表+池化融合为1次(减per-batch kernel启动) + "syncfree_mask": True, # True=用searchsorted构造因果mask(无同步);False=repeat_interleave(同步) "compile": False, # 是否 torch.compile(实测慢5×,勿开) } @@ -596,11 +597,20 @@ class CTRModel(nn.Module): lengths = seq_info[1:] - seq_info[:-1] lengths = lengths.view(-1) indices = torch.cumsum(torch.ones_like(lengths), dim=0) - 1 - result = torch.repeat_interleave(indices, lengths) + result = torch.repeat_interleave(indices, lengths) # repeats 是张量 → 同步 a = result.view(1, -1) - result.view(-1, 1) out_mask = torch.tril((a == 0).to(torch.int32)).bool() return out_mask + def causal_mask_syncfree(self, user_offsets, S, device): + """与 get_sequence_causal_mask 等价,但用 searchsorted 求每个位置的用户号, + 避免 repeat_interleave(张量repeats) 的隐式同步。""" + pos = torch.arange(S, device=device) + doc_id = torch.searchsorted(user_offsets[1:].contiguous(), pos, right=True) # [S],无同步 + same = doc_id.view(-1, 1) == doc_id.view(1, -1) + causal = pos.view(-1, 1) >= pos.view(1, -1) + return same & causal + def build_block_mask(self, user_offsets, S): """FlexAttention 块对角因果 mask:q 只能 attend 同一用户且 kv<=q 的位置。""" lengths = (user_offsets[1:] - user_offsets[:-1]).view(-1) @@ -623,7 +633,11 @@ class CTRModel(nn.Module): S = seq_input.shape[0] # rep_encoder 输出 [S, D],S=总 token 数 extension = {"block_mask": self.build_block_mask(user_offsets, S)} else: - seq_mask = self.get_sequence_causal_mask(user_offsets) + if CONFIG.get("syncfree_mask", True): + seq_mask = self.causal_mask_syncfree( + user_offsets, seq_input.shape[0], seq_input.device) + else: + seq_mask = self.get_sequence_causal_mask(user_offsets) extension = {"mask": seq_mask.unsqueeze(0).unsqueeze(0)} encoder_output, moe_loss = self.seq_encoder(x=seq_input, extension=extension) encoder_output = encoder_output.squeeze(0) diff --git a/代码/code/tests/test_equiv.py b/代码/code/tests/test_equiv.py index dcbcc81..2cb0d99 100644 --- a/代码/code/tests/test_equiv.py +++ b/代码/code/tests/test_equiv.py @@ -64,6 +64,19 @@ def test_moe_dense_matches_loop(): print(f"[PASS] MoE 稠密向量化 == 逐expert循环 (max err={err:.2e}, dev={dev})") +def test_syncfree_mask_matches(): + dev = "cuda" if torch.cuda.is_available() else "cpu" + rep = infer.RepEncoder(vocab_size=100, emb_dim=8, slot_num=28, d_model=8) + seq = infer.TransformerEncoder(d_model=8, n_heads=2, num_layers=1, dim_ff=16) + model = infer.CTRModel(rep, seq, d_model=8).to(dev) + offs = torch.tensor([0, 10, 35, 42, 60], device=dev) # 4 个用户,变长 + S = int(offs[-1]) + m1 = model.get_sequence_causal_mask(offs) + m2 = model.causal_mask_syncfree(offs, S, torch.device(dev)) + assert torch.equal(m1, m2), "sync-free mask 与原 mask 不一致" + print(f"[PASS] searchsorted mask == repeat_interleave mask (dev={dev})") + + def test_varlen_matches_dense_attention(): if not torch.cuda.is_available(): print("[SKIP] varlen 等价测试(需 CUDA)") @@ -134,6 +147,7 @@ def test_flex_matches_dense_attention(): if __name__ == "__main__": test_moe_dense_matches_loop() test_fused_embedding_matches_perslot() + test_syncfree_mask_matches() test_varlen_matches_dense_attention() test_flex_matches_dense_attention() print("[DONE] 等价测试结束")