perf: _triton_block_meta 消除最后一个 host 同步(grid 使用由 shape 派生的上界,超出的空 block 在 kernel 内被 mask、空跑)
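将 repeat_interleave(repeats 为张量)引入的 D2H 同步替换为 searchsorted,并用由 shape 派生的 grid 上界 grid_upper = S//BLOCK_M + n_seq + 1 取代真实的 total_blocks。该上界成立,因为 Σ⌈len_i/BLOCK_M⌉ ≤ ⌊S/BLOCK_M⌋ + n_seq。对真实 block,blk_seq/blk_inseq 与原实现一致;超出的空 block 取 blk_inseq=0,仅执行 1 次空迭代。本次改动延续'消同步'这一收益最高的优化方向。
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>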
This commit is contained in:
+19
-11
@@ -89,17 +89,24 @@ if _HAS_TRITON:
|
||||
tl.store(o_ptrs, acc.to(tl.float16), mask=q_mask[:, None])
|
||||
|
||||
|
||||
def _triton_block_meta(user_offsets, BLOCK_M, device):
|
||||
"""从 user_offsets 算 block→段映射(每 batch 一次、8 层复用;含 1 次同步读 total_blocks)。"""
|
||||
def _triton_block_meta(user_offsets, BLOCK_M, device, S):
|
||||
"""从 user_offsets 算 block→段映射。**无 host 同步**:grid 用 shape 派生的上界
|
||||
grid_upper=S//BLOCK_M+n_seq+1(≥真实 total_blocks),超出的空 block 在 kernel 内被
|
||||
mask 空跑(blk_inseq=0 → 仅 1 次空迭代)。对真实 block 的 (blk_seq,blk_inseq) 与原实现一致。"""
|
||||
cu = user_offsets.to(torch.int32)
|
||||
n_seq = cu.numel() - 1 # shape,无同步
|
||||
seqlens = (cu[1:] - cu[:-1]).to(torch.int64)
|
||||
blocks_per = (seqlens + BLOCK_M - 1) // BLOCK_M
|
||||
n_seq = seqlens.numel()
|
||||
blk_seq = torch.repeat_interleave(torch.arange(n_seq, device=device), blocks_per)
|
||||
total_blocks = blk_seq.numel()
|
||||
starts = torch.cumsum(blocks_per, 0) - blocks_per
|
||||
blk_inseq = torch.arange(total_blocks, device=device) - starts[blk_seq]
|
||||
return cu.contiguous(), blk_seq.to(torch.int32).contiguous(), blk_inseq.to(torch.int32).contiguous(), total_blocks
|
||||
blocks_per = (seqlens + BLOCK_M - 1) // BLOCK_M # [n_seq] GPU
|
||||
cum = torch.cumsum(blocks_per, 0) # cum[i]=前 i+1 个用户的块数
|
||||
cum_prev = cum - blocks_per # 用户 i 之前的块数
|
||||
grid_upper = S // BLOCK_M + n_seq + 1 # HOST int(S,n_seq 来自 shape)
|
||||
b_ids = torch.arange(grid_upper, device=device)
|
||||
blk_seq = torch.searchsorted(cum, b_ids, right=True) # [grid_upper];空块→n_seq
|
||||
safe = blk_seq.clamp(max=n_seq - 1)
|
||||
blk_inseq = torch.where(blk_seq < n_seq, b_ids - cum_prev[safe], torch.zeros_like(b_ids))
|
||||
cu_pad = torch.cat([cu, cu[-1:]]) # [n_seq+2],cu_pad[n_seq+1]=S → 空块空区间
|
||||
return (cu_pad.contiguous(), blk_seq.to(torch.int32).contiguous(),
|
||||
blk_inseq.to(torch.int32).contiguous(), grid_upper)
|
||||
|
||||
|
||||
def _triton_varlen_attn(q, k, v, meta):
|
||||
@@ -912,7 +919,8 @@ class CTRModel(nn.Module):
|
||||
user_offsets = batch["user_offsets"]
|
||||
attn = _resolve_attn(seq_input.device)
|
||||
if attn == "triton":
|
||||
meta = _triton_block_meta(user_offsets, CONFIG.get("triton_block_m", 64), seq_input.device)
|
||||
meta = _triton_block_meta(user_offsets, CONFIG.get("triton_block_m", 64),
|
||||
seq_input.device, seq_input.shape[0])
|
||||
extension = {"triton_meta": meta}
|
||||
elif attn == "chunked":
|
||||
extension = {"chunks": self.build_chunks(user_offsets, seq_input.device)}
|
||||
@@ -1151,7 +1159,7 @@ def load_model(ckpt_path, device='cuda:0'):
|
||||
H, Dh = model.seq_encoder.n_heads, model.seq_encoder.head_dim
|
||||
dummy_off = torch.tensor([0, 64, 130], device=dev)
|
||||
dq = torch.randn(1, H, 130, Dh, device=dev, dtype=torch.float16)
|
||||
meta = _triton_block_meta(dummy_off, CONFIG.get("triton_block_m", 64), dev)
|
||||
meta = _triton_block_meta(dummy_off, CONFIG.get("triton_block_m", 64), dev, 130)
|
||||
_triton_varlen_attn(dq, dq, dq, meta)
|
||||
torch.cuda.synchronize()
|
||||
print("[INFO] triton kernel warmed up")
|
||||
|
||||
@@ -150,7 +150,7 @@ def test_triton_varlen_matches_dense():
|
||||
v = torch.randn(1, H, S, Dh, device=dev, dtype=torch.float16)
|
||||
with torch.no_grad():
|
||||
dense = infer.scaled_dot_product(q, k, v, {"mask": _dense_causal_mask(offs)[None, None]})
|
||||
meta = infer._triton_block_meta(offs, 64, q.device)
|
||||
meta = infer._triton_block_meta(offs, 64, q.device, S)
|
||||
trit = infer.scaled_dot_product(q, k, v, {"triton_meta": meta})
|
||||
err = (dense.float() - trit.float()).abs().max().item()
|
||||
assert torch.allclose(dense.float(), trit.float(), atol=3e-2, rtol=3e-2), \
|
||||
|
||||
Reference in New Issue
Block a user