NOTE(review): line breaks and leading indentation below were reconstructed from a whitespace-collapsed copy of this patch.
NOTE(review): hunk line counts match the @@ headers, but indentation depth is inferred; regenerate from the repository before applying.
diff --git a/代码/code/infer.py b/代码/code/infer.py
index 29dcac2..18740ec 100644
--- a/代码/code/infer.py
+++ b/代码/code/infer.py
@@ -89,17 +89,24 @@ if _HAS_TRITON:
         tl.store(o_ptrs, acc.to(tl.float16), mask=q_mask[:, None])
 
 
-def _triton_block_meta(user_offsets, BLOCK_M, device):
-    """从 user_offsets 算 block→段映射(每 batch 一次、8 层复用;含 1 次同步读 total_blocks)。"""
+def _triton_block_meta(user_offsets, BLOCK_M, device, S):
+    """从 user_offsets 算 block→段映射。**无 host 同步**:grid 用 shape 派生的上界
+    grid_upper=S//BLOCK_M+n_seq+1(≥真实 total_blocks),超出的空 block 在 kernel 内被
+    mask 空跑(blk_inseq=0 → 仅 1 次空迭代)。对真实 block 的 (blk_seq,blk_inseq) 与原实现一致。"""
     cu = user_offsets.to(torch.int32)
+    n_seq = cu.numel() - 1 # shape,无同步
     seqlens = (cu[1:] - cu[:-1]).to(torch.int64)
-    blocks_per = (seqlens + BLOCK_M - 1) // BLOCK_M
-    n_seq = seqlens.numel()
-    blk_seq = torch.repeat_interleave(torch.arange(n_seq, device=device), blocks_per)
-    total_blocks = blk_seq.numel()
-    starts = torch.cumsum(blocks_per, 0) - blocks_per
-    blk_inseq = torch.arange(total_blocks, device=device) - starts[blk_seq]
-    return cu.contiguous(), blk_seq.to(torch.int32).contiguous(), blk_inseq.to(torch.int32).contiguous(), total_blocks
+    blocks_per = (seqlens + BLOCK_M - 1) // BLOCK_M # [n_seq] GPU
+    cum = torch.cumsum(blocks_per, 0) # cum[i]=前 i+1 个用户的块数
+    cum_prev = cum - blocks_per # 用户 i 之前的块数
+    grid_upper = S // BLOCK_M + n_seq + 1 # HOST int(S,n_seq 来自 shape)
+    b_ids = torch.arange(grid_upper, device=device)
+    blk_seq = torch.searchsorted(cum, b_ids, right=True) # [grid_upper];空块→n_seq
+    safe = blk_seq.clamp(max=n_seq - 1)
+    blk_inseq = torch.where(blk_seq < n_seq, b_ids - cum_prev[safe], torch.zeros_like(b_ids))
+    cu_pad = torch.cat([cu, cu[-1:]]) # [n_seq+2],cu_pad[n_seq+1]=S → 空块空区间
+    return (cu_pad.contiguous(), blk_seq.to(torch.int32).contiguous(),
+            blk_inseq.to(torch.int32).contiguous(), grid_upper)
 
 
 def _triton_varlen_attn(q, k, v, meta):
@@ -912,7 +919,8 @@ class CTRModel(nn.Module):
         user_offsets = batch["user_offsets"]
         attn = _resolve_attn(seq_input.device)
         if attn == "triton":
-            meta = _triton_block_meta(user_offsets, CONFIG.get("triton_block_m", 64), seq_input.device)
+            meta = _triton_block_meta(user_offsets, CONFIG.get("triton_block_m", 64),
+                                      seq_input.device, seq_input.shape[0])
             extension = {"triton_meta": meta}
         elif attn == "chunked":
             extension = {"chunks": self.build_chunks(user_offsets, seq_input.device)}
@@ -1151,7 +1159,7 @@ def load_model(ckpt_path, device='cuda:0'):
         H, Dh = model.seq_encoder.n_heads, model.seq_encoder.head_dim
         dummy_off = torch.tensor([0, 64, 130], device=dev)
         dq = torch.randn(1, H, 130, Dh, device=dev, dtype=torch.float16)
-        meta = _triton_block_meta(dummy_off, CONFIG.get("triton_block_m", 64), dev)
+        meta = _triton_block_meta(dummy_off, CONFIG.get("triton_block_m", 64), dev, 130)
         _triton_varlen_attn(dq, dq, dq, meta)
         torch.cuda.synchronize()
         print("[INFO] triton kernel warmed up")
diff --git a/代码/code/tests/test_equiv.py b/代码/code/tests/test_equiv.py
index 53e6407..b9b9aa3 100644
--- a/代码/code/tests/test_equiv.py
+++ b/代码/code/tests/test_equiv.py
@@ -150,7 +150,7 @@ def test_triton_varlen_matches_dense():
     v = torch.randn(1, H, S, Dh, device=dev, dtype=torch.float16)
     with torch.no_grad():
         dense = infer.scaled_dot_product(q, k, v, {"mask": _dense_causal_mask(offs)[None, None]})
-        meta = infer._triton_block_meta(offs, 64, q.device)
+        meta = infer._triton_block_meta(offs, 64, q.device, S)
         trit = infer.scaled_dot_product(q, k, v, {"triton_meta": meta})
     err = (dense.float() - trit.float()).abs().max().item()
     assert torch.allclose(dense.float(), trit.float(), atol=3e-2, rtol=3e-2), \