perf: searchsorted 构造因果mask,消除最后一个同步点(repeat_interleave张量repeats)
dense MoE 去掉MoE的nonzero同步省了评测20s;embedding融合(无同步)只省1s ->真正的杠杆是消同步点。mask构造的repeat_interleave(lengths张量)是model(batch) 内最后一个同步点,改用searchsorted求doc_id(输出size已知,无同步)。等价测试已加。 Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
+16
-2
@@ -49,6 +49,7 @@ CONFIG = {
|
|||||||
# 须靠提交验证。AUC中性、MoE仅占2%算力故风险极低。
|
# 须靠提交验证。AUC中性、MoE仅占2%算力故风险极低。
|
||||||
"vectorize_moe": True, # True=稠密向量化MoE(无同步点);False=原逐expert循环(.nonzero同步)
|
"vectorize_moe": True, # True=稠密向量化MoE(无同步点);False=原逐expert循环(.nonzero同步)
|
||||||
"fuse_embedding": True, # True=28个slot的查表+池化融合为1次(减per-batch kernel启动)
|
"fuse_embedding": True, # True=28个slot的查表+池化融合为1次(减per-batch kernel启动)
|
||||||
|
"syncfree_mask": True, # True=用searchsorted构造因果mask(无同步);False=repeat_interleave(同步)
|
||||||
"compile": False, # 是否 torch.compile(实测慢5×,勿开)
|
"compile": False, # 是否 torch.compile(实测慢5×,勿开)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -596,11 +597,20 @@ class CTRModel(nn.Module):
|
|||||||
lengths = seq_info[1:] - seq_info[:-1]
|
lengths = seq_info[1:] - seq_info[:-1]
|
||||||
lengths = lengths.view(-1)
|
lengths = lengths.view(-1)
|
||||||
indices = torch.cumsum(torch.ones_like(lengths), dim=0) - 1
|
indices = torch.cumsum(torch.ones_like(lengths), dim=0) - 1
|
||||||
result = torch.repeat_interleave(indices, lengths)
|
result = torch.repeat_interleave(indices, lengths) # repeats 是张量 → 同步
|
||||||
a = result.view(1, -1) - result.view(-1, 1)
|
a = result.view(1, -1) - result.view(-1, 1)
|
||||||
out_mask = torch.tril((a == 0).to(torch.int32)).bool()
|
out_mask = torch.tril((a == 0).to(torch.int32)).bool()
|
||||||
return out_mask
|
return out_mask
|
||||||
|
|
||||||
|
def causal_mask_syncfree(self, user_offsets, S, device):
|
||||||
|
"""与 get_sequence_causal_mask 等价,但用 searchsorted 求每个位置的用户号,
|
||||||
|
避免 repeat_interleave(张量repeats) 的隐式同步。"""
|
||||||
|
pos = torch.arange(S, device=device)
|
||||||
|
doc_id = torch.searchsorted(user_offsets[1:].contiguous(), pos, right=True) # [S],无同步
|
||||||
|
same = doc_id.view(-1, 1) == doc_id.view(1, -1)
|
||||||
|
causal = pos.view(-1, 1) >= pos.view(1, -1)
|
||||||
|
return same & causal
|
||||||
|
|
||||||
def build_block_mask(self, user_offsets, S):
|
def build_block_mask(self, user_offsets, S):
|
||||||
"""FlexAttention 块对角因果 mask:q 只能 attend 同一用户且 kv<=q 的位置。"""
|
"""FlexAttention 块对角因果 mask:q 只能 attend 同一用户且 kv<=q 的位置。"""
|
||||||
lengths = (user_offsets[1:] - user_offsets[:-1]).view(-1)
|
lengths = (user_offsets[1:] - user_offsets[:-1]).view(-1)
|
||||||
@@ -623,7 +633,11 @@ class CTRModel(nn.Module):
|
|||||||
S = seq_input.shape[0] # rep_encoder 输出 [S, D],S=总 token 数
|
S = seq_input.shape[0] # rep_encoder 输出 [S, D],S=总 token 数
|
||||||
extension = {"block_mask": self.build_block_mask(user_offsets, S)}
|
extension = {"block_mask": self.build_block_mask(user_offsets, S)}
|
||||||
else:
|
else:
|
||||||
seq_mask = self.get_sequence_causal_mask(user_offsets)
|
if CONFIG.get("syncfree_mask", True):
|
||||||
|
seq_mask = self.causal_mask_syncfree(
|
||||||
|
user_offsets, seq_input.shape[0], seq_input.device)
|
||||||
|
else:
|
||||||
|
seq_mask = self.get_sequence_causal_mask(user_offsets)
|
||||||
extension = {"mask": seq_mask.unsqueeze(0).unsqueeze(0)}
|
extension = {"mask": seq_mask.unsqueeze(0).unsqueeze(0)}
|
||||||
encoder_output, moe_loss = self.seq_encoder(x=seq_input, extension=extension)
|
encoder_output, moe_loss = self.seq_encoder(x=seq_input, extension=extension)
|
||||||
encoder_output = encoder_output.squeeze(0)
|
encoder_output = encoder_output.squeeze(0)
|
||||||
|
|||||||
@@ -64,6 +64,19 @@ def test_moe_dense_matches_loop():
|
|||||||
print(f"[PASS] MoE 稠密向量化 == 逐expert循环 (max err={err:.2e}, dev={dev})")
|
print(f"[PASS] MoE 稠密向量化 == 逐expert循环 (max err={err:.2e}, dev={dev})")
|
||||||
|
|
||||||
|
|
||||||
|
def test_syncfree_mask_matches():
|
||||||
|
dev = "cuda" if torch.cuda.is_available() else "cpu"
|
||||||
|
rep = infer.RepEncoder(vocab_size=100, emb_dim=8, slot_num=28, d_model=8)
|
||||||
|
seq = infer.TransformerEncoder(d_model=8, n_heads=2, num_layers=1, dim_ff=16)
|
||||||
|
model = infer.CTRModel(rep, seq, d_model=8).to(dev)
|
||||||
|
offs = torch.tensor([0, 10, 35, 42, 60], device=dev) # 4 个用户,变长
|
||||||
|
S = int(offs[-1])
|
||||||
|
m1 = model.get_sequence_causal_mask(offs)
|
||||||
|
m2 = model.causal_mask_syncfree(offs, S, torch.device(dev))
|
||||||
|
assert torch.equal(m1, m2), "sync-free mask 与原 mask 不一致"
|
||||||
|
print(f"[PASS] searchsorted mask == repeat_interleave mask (dev={dev})")
|
||||||
|
|
||||||
|
|
||||||
def test_varlen_matches_dense_attention():
|
def test_varlen_matches_dense_attention():
|
||||||
if not torch.cuda.is_available():
|
if not torch.cuda.is_available():
|
||||||
print("[SKIP] varlen 等价测试(需 CUDA)")
|
print("[SKIP] varlen 等价测试(需 CUDA)")
|
||||||
@@ -134,6 +147,7 @@ def test_flex_matches_dense_attention():
|
|||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
test_moe_dense_matches_loop()
|
test_moe_dense_matches_loop()
|
||||||
test_fused_embedding_matches_perslot()
|
test_fused_embedding_matches_perslot()
|
||||||
|
test_syncfree_mask_matches()
|
||||||
test_varlen_matches_dense_attention()
|
test_varlen_matches_dense_attention()
|
||||||
test_flex_matches_dense_attention()
|
test_flex_matches_dense_attention()
|
||||||
print("[DONE] 等价测试结束")
|
print("[DONE] 等价测试结束")
|
||||||
|
|||||||
Reference in New Issue
Block a user