perf: RepEncoder 融合 28-slot 查表+池化为单次(减per-batch kernel启动,无新增同步)
延续 dense MoE 的胜因(消除 per-batch 开销,其收益在评测端被放大)。将 28 次 embedding + 28 次 segment_reduce 融合为 1 次;用 numel 读取 shape 以避免同步;base 累加无同步。保留 _rep_forward_perslot 作为等价对照。CONFIG.fuse_embedding 默认为 True。 Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
@@ -85,6 +85,30 @@ def test_varlen_matches_dense_attention():
|
||||
print(f"[PASS] varlen(嵌套张量) == 稠密SDPA (max err={err:.2e})")
|
||||
|
||||
|
||||
def test_fused_embedding_matches_perslot():
    """Check that RepEncoder gives the same output with embedding fusion on and off.

    A random batch is built for every slot, the encoder is run once with
    ``infer.CONFIG["fuse_embedding"]`` set to False (reference) and once with
    it set to True, and the two outputs are compared with ``torch.allclose``.

    The global flag is restored to its previous value afterwards, even if the
    encoder raises, so this test cannot leak configuration into later tests.

    Raises:
        AssertionError: if the two outputs differ beyond atol/rtol of 1e-4.
    """
    torch.manual_seed(0)
    dev = "cuda" if torch.cuda.is_available() else "cpu"
    slot_num, emb_dim, d_model = 28, 512, 512
    enc = infer.RepEncoder(vocab_size=10000, emb_dim=emb_dim, slot_num=slot_num,
                           d_model=d_model).to(dev).eval()

    # Build a batch of N=6 samples: each slot holds 0..4 signs per sample,
    # so empty slots (count 0) are exercised as a boundary case.
    N = 6
    batch = {}
    for s in range(1, slot_num + 1):
        counts = torch.randint(0, 5, (N,))
        vals = torch.randint(0, 10000, (int(counts.sum()),), device=dev)
        # Offsets are the running total of counts with a leading zero.
        offs = torch.cat([torch.zeros(1, dtype=torch.long), counts.cumsum(0)]).to(dev)
        batch[s] = (vals, offs)

    # Remember the caller's setting so the toggling below is not observable
    # outside this test.
    saved_flag = infer.CONFIG["fuse_embedding"]
    try:
        with torch.no_grad():
            infer.CONFIG["fuse_embedding"] = False
            ref = enc(batch)
            infer.CONFIG["fuse_embedding"] = True
            new = enc(batch)
    finally:
        infer.CONFIG["fuse_embedding"] = saved_flag

    err = (ref - new).abs().max().item()
    assert torch.allclose(ref, new, atol=1e-4, rtol=1e-4), f"embedding融合不等价 max err={err:.3e}"
    print(f"[PASS] embedding 融合 == 逐slot (max err={err:.2e}, dev={dev})")
|
||||
|
||||
|
||||
def test_flex_matches_dense_attention():
|
||||
ok = (torch.cuda.is_available() and infer._HAS_FLEX
|
||||
and torch.cuda.get_device_capability()[0] >= 8)
|
||||
@@ -109,6 +133,7 @@ def test_flex_matches_dense_attention():
|
||||
|
||||
if __name__ == "__main__":
    # Run every equivalence check in a fixed order, then report completion.
    equivalence_checks = (
        test_moe_dense_matches_loop,
        test_fused_embedding_matches_perslot,
        test_varlen_matches_dense_attention,
        test_flex_matches_dense_attention,
    )
    for run_check in equivalence_checks:
        run_check()
    print("[DONE] 等价测试结束")
|
||||
|
||||
Reference in New Issue
Block a user