perf: _triton_block_meta 消除最后一个host同步(grid用shape派生上界,空block在kernel内mask空跑)

repeat_interleave(张量repeats)的D2H同步换成searchsorted+shape派生grid上界(S//BLOCK_M+n_seq+1)。
对真实block的blk_seq/blk_inseq与原实现一致;空block blk_inseq=0仅1次空迭代。延续'消同步'(最赚方向)。

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
OwnerSunshine530
2026-06-19 20:51:37 +08:00
parent b72e0346a9
commit 7bb2e0f518
2 changed files with 20 additions and 12 deletions
+1 -1
View File
@@ -150,7 +150,7 @@ def test_triton_varlen_matches_dense():
v = torch.randn(1, H, S, Dh, device=dev, dtype=torch.float16)
with torch.no_grad():
dense = infer.scaled_dot_product(q, k, v, {"mask": _dense_causal_mask(offs)[None, None]})
meta = infer._triton_block_meta(offs, 64, q.device)
meta = infer._triton_block_meta(offs, 64, q.device, S)
trit = infer.scaled_dot_product(q, k, v, {"triton_meta": meta})
err = (dense.float() - trit.float()).abs().max().item()
assert torch.allclose(dense.float(), trit.float(), atol=3e-2, rtol=3e-2), \