perf: _triton_block_meta 消除最后一个host同步(grid用shape派生上界,空block在kernel内mask空跑)

repeat_interleave(张量repeats)的D2H同步换成searchsorted+shape派生grid上界(S//BLOCK_M+n_seq+1)。
对真实block的blk_seq/blk_inseq与原实现一致;空block blk_inseq=0仅1次空迭代。延续'消同步'(最赚方向)。

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
OwnerSunshine530
2026-06-19 20:51:37 +08:00
parent b72e0346a9
commit 7bb2e0f518
2 changed files with 20 additions and 12 deletions
+19 -11
View File
@@ -89,17 +89,24 @@ if _HAS_TRITON:
tl.store(o_ptrs, acc.to(tl.float16), mask=q_mask[:, None]) tl.store(o_ptrs, acc.to(tl.float16), mask=q_mask[:, None])
def _triton_block_meta(user_offsets, BLOCK_M, device): def _triton_block_meta(user_offsets, BLOCK_M, device, S):
"""从 user_offsets 算 block→段映射(每 batch 一次、8 层复用;含 1 次同步读 total_blocks)。""" """从 user_offsets 算 block→段映射。**无 host 同步**grid 用 shape 派生的上界
grid_upper=S//BLOCK_M+n_seq+1(≥真实 total_blocks),超出的空 block 在 kernel 内被
mask 空跑(blk_inseq=0 → 仅 1 次空迭代)。对真实 block 的 (blk_seq,blk_inseq) 与原实现一致。"""
cu = user_offsets.to(torch.int32) cu = user_offsets.to(torch.int32)
n_seq = cu.numel() - 1 # shape,无同步
seqlens = (cu[1:] - cu[:-1]).to(torch.int64) seqlens = (cu[1:] - cu[:-1]).to(torch.int64)
blocks_per = (seqlens + BLOCK_M - 1) // BLOCK_M blocks_per = (seqlens + BLOCK_M - 1) // BLOCK_M # [n_seq] GPU
n_seq = seqlens.numel() cum = torch.cumsum(blocks_per, 0) # cum[i]=前 i+1 个用户的块数
blk_seq = torch.repeat_interleave(torch.arange(n_seq, device=device), blocks_per) cum_prev = cum - blocks_per # 用户 i 之前的块数
total_blocks = blk_seq.numel() grid_upper = S // BLOCK_M + n_seq + 1 # HOST intS,n_seq 来自 shape
starts = torch.cumsum(blocks_per, 0) - blocks_per b_ids = torch.arange(grid_upper, device=device)
blk_inseq = torch.arange(total_blocks, device=device) - starts[blk_seq] blk_seq = torch.searchsorted(cum, b_ids, right=True) # [grid_upper];空块→n_seq
return cu.contiguous(), blk_seq.to(torch.int32).contiguous(), blk_inseq.to(torch.int32).contiguous(), total_blocks safe = blk_seq.clamp(max=n_seq - 1)
blk_inseq = torch.where(blk_seq < n_seq, b_ids - cum_prev[safe], torch.zeros_like(b_ids))
cu_pad = torch.cat([cu, cu[-1:]]) # [n_seq+2]cu_pad[n_seq+1]=S → 空块空区间
return (cu_pad.contiguous(), blk_seq.to(torch.int32).contiguous(),
blk_inseq.to(torch.int32).contiguous(), grid_upper)
def _triton_varlen_attn(q, k, v, meta): def _triton_varlen_attn(q, k, v, meta):
@@ -912,7 +919,8 @@ class CTRModel(nn.Module):
user_offsets = batch["user_offsets"] user_offsets = batch["user_offsets"]
attn = _resolve_attn(seq_input.device) attn = _resolve_attn(seq_input.device)
if attn == "triton": if attn == "triton":
meta = _triton_block_meta(user_offsets, CONFIG.get("triton_block_m", 64), seq_input.device) meta = _triton_block_meta(user_offsets, CONFIG.get("triton_block_m", 64),
seq_input.device, seq_input.shape[0])
extension = {"triton_meta": meta} extension = {"triton_meta": meta}
elif attn == "chunked": elif attn == "chunked":
extension = {"chunks": self.build_chunks(user_offsets, seq_input.device)} extension = {"chunks": self.build_chunks(user_offsets, seq_input.device)}
@@ -1151,7 +1159,7 @@ def load_model(ckpt_path, device='cuda:0'):
H, Dh = model.seq_encoder.n_heads, model.seq_encoder.head_dim H, Dh = model.seq_encoder.n_heads, model.seq_encoder.head_dim
dummy_off = torch.tensor([0, 64, 130], device=dev) dummy_off = torch.tensor([0, 64, 130], device=dev)
dq = torch.randn(1, H, 130, Dh, device=dev, dtype=torch.float16) dq = torch.randn(1, H, 130, Dh, device=dev, dtype=torch.float16)
meta = _triton_block_meta(dummy_off, CONFIG.get("triton_block_m", 64), dev) meta = _triton_block_meta(dummy_off, CONFIG.get("triton_block_m", 64), dev, 130)
_triton_varlen_attn(dq, dq, dq, meta) _triton_varlen_attn(dq, dq, dq, meta)
torch.cuda.synchronize() torch.cuda.synchronize()
print("[INFO] triton kernel warmed up") print("[INFO] triton kernel warmed up")
+1 -1
View File
@@ -150,7 +150,7 @@ def test_triton_varlen_matches_dense():
v = torch.randn(1, H, S, Dh, device=dev, dtype=torch.float16) v = torch.randn(1, H, S, Dh, device=dev, dtype=torch.float16)
with torch.no_grad(): with torch.no_grad():
dense = infer.scaled_dot_product(q, k, v, {"mask": _dense_causal_mask(offs)[None, None]}) dense = infer.scaled_dot_product(q, k, v, {"mask": _dense_causal_mask(offs)[None, None]})
meta = infer._triton_block_meta(offs, 64, q.device) meta = infer._triton_block_meta(offs, 64, q.device, S)
trit = infer.scaled_dot_product(q, k, v, {"triton_meta": meta}) trit = infer.scaled_dot_product(q, k, v, {"triton_meta": meta})
err = (dense.float() - trit.float()).abs().max().item() err = (dense.float() - trit.float()).abs().max().item()
assert torch.allclose(dense.float(), trit.float(), atol=3e-2, rtol=3e-2), \ assert torch.allclose(dense.float(), trit.float(), atol=3e-2, rtol=3e-2), \