perf: searchsorted 构造因果mask,消除最后一个同步点(repeat_interleave张量repeats)
dense MoE 去掉MoE的nonzero同步省了评测20s;embedding融合(无同步)只省1s ->真正的杠杆是消同步点。mask构造的repeat_interleave(lengths张量)是model(batch) 内最后一个同步点,改用searchsorted求doc_id(输出size已知,无同步)。等价测试已加。 Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
@@ -64,6 +64,19 @@ def test_moe_dense_matches_loop():
|
||||
print(f"[PASS] MoE 稠密向量化 == 逐expert循环 (max err={err:.2e}, dev={dev})")
|
||||
|
||||
|
||||
def test_syncfree_mask_matches():
|
||||
dev = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
rep = infer.RepEncoder(vocab_size=100, emb_dim=8, slot_num=28, d_model=8)
|
||||
seq = infer.TransformerEncoder(d_model=8, n_heads=2, num_layers=1, dim_ff=16)
|
||||
model = infer.CTRModel(rep, seq, d_model=8).to(dev)
|
||||
offs = torch.tensor([0, 10, 35, 42, 60], device=dev) # 4 个用户,变长
|
||||
S = int(offs[-1])
|
||||
m1 = model.get_sequence_causal_mask(offs)
|
||||
m2 = model.causal_mask_syncfree(offs, S, torch.device(dev))
|
||||
assert torch.equal(m1, m2), "sync-free mask 与原 mask 不一致"
|
||||
print(f"[PASS] searchsorted mask == repeat_interleave mask (dev={dev})")
|
||||
|
||||
|
||||
def test_varlen_matches_dense_attention():
|
||||
if not torch.cuda.is_available():
|
||||
print("[SKIP] varlen 等价测试(需 CUDA)")
|
||||
@@ -134,6 +147,7 @@ def test_flex_matches_dense_attention():
|
||||
if __name__ == "__main__":
|
||||
test_moe_dense_matches_loop()
|
||||
test_fused_embedding_matches_perslot()
|
||||
test_syncfree_mask_matches()
|
||||
test_varlen_matches_dense_attention()
|
||||
test_flex_matches_dense_attention()
|
||||
print("[DONE] 等价测试结束")
|
||||
|
||||
Reference in New Issue
Block a user