perf: _triton_block_meta 消除最后一个host同步(grid用shape派生上界,空block在kernel内mask空跑)
repeat_interleave(张量repeats)的D2H同步换成searchsorted+shape派生grid上界(S//BLOCK_M+n_seq+1)。 对真实block的blk_seq/blk_inseq与原实现一致;空block blk_inseq=0仅1次空迭代。延续'消同步'(最赚方向)。 Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
@@ -150,7 +150,7 @@ def test_triton_varlen_matches_dense():
|
||||
v = torch.randn(1, H, S, Dh, device=dev, dtype=torch.float16)
|
||||
with torch.no_grad():
|
||||
dense = infer.scaled_dot_product(q, k, v, {"mask": _dense_causal_mask(offs)[None, None]})
|
||||
meta = infer._triton_block_meta(offs, 64, q.device)
|
||||
meta = infer._triton_block_meta(offs, 64, q.device, S)
|
||||
trit = infer.scaled_dot_product(q, k, v, {"triton_meta": meta})
|
||||
err = (dense.float() - trit.float()).abs().max().item()
|
||||
assert torch.allclose(dense.float(), trit.float(), atol=3e-2, rtol=3e-2), \
|
||||
|
||||
Reference in New Issue
Block a user