diff --git a/代码/code/bench.py b/代码/code/bench.py
index 0557e16..d94fc02 100644
--- a/代码/code/bench.py
+++ b/代码/code/bench.py
@@ -315,8 +315,8 @@ def _parse_args():
                     help="逗号分隔的 keep_fp32_modules,如 linear,rep_encoder.input_norm")
     ap.add_argument("--feasign-none", action="store_true",
                     help="不截断特征(max_feasign_per_slot=None)")
-    ap.add_argument("--attn", choices=["sdpa", "chunked", "flex", "varlen"], default=None,
-                    help="注意力:sdpa=稠密, chunked=按用户分块SDPA, flex/varlen=对照")
+    ap.add_argument("--attn", choices=["sdpa", "chunked", "triton", "flex", "varlen"], default=None,
+                    help="注意力:sdpa=稠密, chunked=分块SDPA, triton=varlen flash kernel, flex/varlen=对照")
     ap.add_argument("--chunk-users", type=int, default=None, help="chunked 模式每块用户数")
     ap.add_argument("--moe", choices=["dense", "loop"], default=None,
                     help="MoE实现:dense=向量化(新), loop=逐expert循环(原)")
diff --git a/代码/code/infer.py b/代码/code/infer.py
index 7e172c2..394d3b7 100644
--- a/代码/code/infer.py
+++ b/代码/code/infer.py
@@ -26,6 +26,120 @@ except Exception:
     create_block_mask = None
     _HAS_FLEX = False
 
+# Triton varlen causal flash attention (block-diagonal, single kernel; removes per-chunk calls and mask construction).
+try:
+    import triton
+    import triton.language as tl
+    _HAS_TRITON = True
+except Exception:
+    triton = None
+    tl = None
+    _HAS_TRITON = False
+
+
+if _HAS_TRITON:
+    @triton.jit
+    def _varlen_flash_fwd(
+        Q, K, V, Out,
+        cu_seqlens, blk_seq, blk_inseq,
+        stride_h, stride_s, stride_d,
+        scale, n_seq,
+        BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, D: tl.constexpr,
+    ):
+        # Each program handles one BLOCK_M-row query tile of one sequence for one head.
+        # blk_seq[pid] is the tile's sequence index, blk_inseq[pid] its tile index inside
+        # that sequence, and cu_seqlens[s] / cu_seqlens[s + 1] bound the sequence's rows.
+        # n_seq is part of the launch signature but is not read by the kernel.
+        pid = tl.program_id(0)  # global query tile
+        h = tl.program_id(1)  # head
+        s = tl.load(blk_seq + pid)
+        bis = tl.load(blk_inseq + pid)
+        seq_start = tl.load(cu_seqlens + s)
+        seq_end = tl.load(cu_seqlens + s + 1)
+
+        q_row0 = seq_start + bis * BLOCK_M
+        offs_m = q_row0 + tl.arange(0, BLOCK_M)  # global row index of each query token
+        offs_d = tl.arange(0, D)
+        q_mask = offs_m < seq_end
+        q_ptrs = Q + h * stride_h + offs_m[:, None] * stride_s + offs_d[None, :] * stride_d
+        q = tl.load(q_ptrs, mask=q_mask[:, None], other=0.0).to(tl.float32)
+
+        m_i = tl.full([BLOCK_M], -float("inf"), tl.float32)  # running row maximum
+        l_i = tl.zeros([BLOCK_M], tl.float32)  # running softmax denominator
+        acc = tl.zeros([BLOCK_M, D], tl.float32)  # running weighted sum of V
+
+        q_pos = offs_m - seq_start  # query position within its sequence
+        kv_end = q_row0 + BLOCK_M  # causal: keys never extend past the end of this query tile
+        for kn in range(seq_start, kv_end, BLOCK_N):
+            offs_n = kn + tl.arange(0, BLOCK_N)
+            k_mask = offs_n < seq_end
+            k_ptrs = K + h * stride_h + offs_n[:, None] * stride_s + offs_d[None, :] * stride_d
+            k = tl.load(k_ptrs, mask=k_mask[:, None], other=0.0).to(tl.float32)
+            qk = tl.dot(q, tl.trans(k)) * scale  # [BLOCK_M, BLOCK_N]
+            k_pos = offs_n - seq_start
+            valid = (q_pos[:, None] >= k_pos[None, :]) & k_mask[None, :]
+            qk = tl.where(valid, qk, -float("inf"))
+            m_new = tl.maximum(m_i, tl.max(qk, 1))
+            p = tl.exp(qk - m_new[:, None])
+            alpha = tl.exp(m_i - m_new)
+            l_i = l_i * alpha + tl.sum(p, 1)
+            v_ptrs = V + h * stride_h + offs_n[:, None] * stride_s + offs_d[None, :] * stride_d
+            v = tl.load(v_ptrs, mask=k_mask[:, None], other=0.0).to(tl.float32)
+            acc = acc * alpha[:, None] + tl.dot(p, v)
+            m_i = m_new
+
+        acc = acc / l_i[:, None]
+        o_ptrs = Out + h * stride_h + offs_m[:, None] * stride_s + offs_d[None, :] * stride_d
+        # Store in the output buffer's dtype instead of a hard-coded fp16 so bf16/fp32 keep their range.
+        tl.store(o_ptrs, acc.to(Out.dtype.element_ty), mask=q_mask[:, None])
+
+
+def _triton_block_meta(user_offsets, BLOCK_M, device):
+    """Build the query-tile -> sequence mapping consumed by the Triton kernel.
+
+    Computed once per batch and reused by all 8 layers; reading total_blocks
+    costs one host/device synchronisation.
+
+    Returns (cu, blk_seq, blk_inseq, total_blocks): int32 row offsets, the
+    sequence index of each tile, each tile's index inside its sequence, and the tile count.
+    """
+    # Move the offsets to `device` up front so CPU-resident offsets do not get
+    # mixed with the index tensors created on `device` below.
+    cu = user_offsets.to(device=device, dtype=torch.int32)
+    seqlens = (cu[1:] - cu[:-1]).to(torch.int64)
+    blocks_per = (seqlens + BLOCK_M - 1) // BLOCK_M
+    n_seq = seqlens.numel()
+    blk_seq = torch.repeat_interleave(torch.arange(n_seq, device=device), blocks_per)
+    total_blocks = blk_seq.numel()
+    starts = torch.cumsum(blocks_per, 0) - blocks_per
+    blk_inseq = torch.arange(total_blocks, device=device) - starts[blk_seq]
+    return cu.contiguous(), blk_seq.to(torch.int32).contiguous(), blk_inseq.to(torch.int32).contiguous(), total_blocks
+
+
+def _triton_varlen_attn(q, k, v, meta):
+    """Run block-diagonal causal attention through the Triton varlen kernel.
+
+    q, k, v: [1, H, S, Dh]. meta = (cu, blk_seq, blk_inseq, total_blocks), built
+    by _triton_block_meta with the same BLOCK_M as below. Returns [1, H, S, Dh].
+    """
+    _, H, S, Dh = q.shape
+    cu, blk_seq, blk_inseq, total_blocks = meta
+    BLOCK_M = 64  # must equal the BLOCK_M that `meta` was built with
+    qc = q.contiguous()
+    kc = k.contiguous()
+    vc = v.contiguous()
+    # Allocate from the contiguous q so `out` already has the strides passed to the
+    # kernel; empty_like(q) would inherit q's strides and need an extra copy.
+    out = torch.empty_like(qc)
+    stride_h, stride_s, stride_d = S * Dh, Dh, 1
+    grid = (total_blocks, H)
+    _varlen_flash_fwd[grid](
+        qc, kc, vc, out, cu, blk_seq, blk_inseq,
+        stride_h, stride_s, stride_d, 1.0 / math.sqrt(Dh), cu.numel() - 1,
+        BLOCK_M=BLOCK_M, BLOCK_N=64, D=Dh,
+    )
+    return out
+
 
 # ============================================================
 # 实验配置开关板
@@ -65,12 +179,17 @@ CONFIG = {
 
 
 def _resolve_attn(device):
-    """解析实际使用的注意力实现。flex 需 SM80+ 且可用,否则回退 sdpa。"""
+    """Resolve the attention implementation actually used. triton/flex need CUDA (SM80+ for flex); otherwise fall back to chunked/sdpa."""
     attn = CONFIG.get("attn", "sdpa")
+    is_cuda = device is not None and device.type == "cuda"
+    if attn == "triton":
+        if not (_HAS_TRITON and is_cuda):
+            return "chunked"  # Triton unavailable -> fall back to the already-validated chunked path
+        return "triton"
     if attn == "flex":
         if not _HAS_FLEX:
             return "sdpa"
-        if device is not None and device.type == "cuda":
+        if is_cuda:
             try:
                 if torch.cuda.get_device_capability(device)[0] < 8:
                     return "sdpa"
@@ -475,6 +594,9 @@ def scaled_dot_product(q, k, v, extension):
     - block_mask → FlexAttention 块对角因果。
     - mask(默认) → 标准 SDPA 稠密 mask(数学等价、已验证最快)。
     """
+    if extension is not None and extension.get("triton_meta") is not None:
+        return _triton_varlen_attn(q, k, v, extension["triton_meta"])
+
     if extension is not None and extension.get("chunks") is not None:
         outs = []
         for s0, s1, m in extension["chunks"]:
@@ -734,7 +856,10 @@ class CTRModel(nn.Module):
         seq_input = self.rep_encoder(batch)
         user_offsets = batch["user_offsets"]
         attn = _resolve_attn(seq_input.device)
-        if attn == "chunked":
+        if attn == "triton":
+            meta = _triton_block_meta(user_offsets, 64, seq_input.device)
+            extension = {"triton_meta": meta}
+        elif attn == "chunked":
             extension = {"chunks": self.build_chunks(user_offsets, seq_input.device)}
         elif attn == "varlen":
             extension = {"varlen_offsets": user_offsets}
diff --git a/代码/code/tests/test_equiv.py b/代码/code/tests/test_equiv.py
index 1e7b7af..a4fa482 100644
--- a/代码/code/tests/test_equiv.py
+++ b/代码/code/tests/test_equiv.py
@@ -112,6 +112,28 @@ def test_sparse_pool_matches():
     print(f"[PASS] sparse_pool == segment_reduce (max err={err:.2e}, dev={dev})")
 
 
+def test_triton_varlen_matches_dense():
+    if not (torch.cuda.is_available() and infer._HAS_TRITON):
+        print("[SKIP] Triton varlen 等价测试(需 CUDA + triton)")
+        return
+    torch.manual_seed(0)
+    dev = "cuda"
+    H, Dh = 8, 64
+    offs = _offsets([10, 64, 1, 130, 64, 200], dev)  # covers multi-tile, single-token and exactly-one-tile sequences
+    S = int(offs[-1])
+    q = torch.randn(1, H, S, Dh, device=dev, dtype=torch.float16)
+    k = torch.randn(1, H, S, Dh, device=dev, dtype=torch.float16)
+    v = torch.randn(1, H, S, Dh, device=dev, dtype=torch.float16)
+    with torch.no_grad():
+        dense = infer.scaled_dot_product(q, k, v, {"mask": _dense_causal_mask(offs)[None, None]})
+        meta = infer._triton_block_meta(offs, 64, q.device)
+        trit = infer.scaled_dot_product(q, k, v, {"triton_meta": meta})
+    err = (dense.float() - trit.float()).abs().max().item()
+    assert torch.allclose(dense.float(), trit.float(), atol=3e-2, rtol=3e-2), \
+        f"Triton varlen 不等价 max err={err:.3e}"
+    print(f"[PASS] Triton varlen flash == 稠密SDPA (max err={err:.2e})")
+
+
 def test_syncfree_mask_matches():
     dev = "cuda" if torch.cuda.is_available() else "cpu"
     rep = infer.RepEncoder(vocab_size=100, emb_dim=8, slot_num=28, d_model=8)
@@ -197,6 +219,7 @@ if __name__ == "__main__":
     test_fused_embedding_matches_perslot()
     test_sparse_pool_matches()
     test_syncfree_mask_matches()
+    test_triton_varlen_matches_dense()
     test_chunked_matches_dense_attention()
     test_varlen_matches_dense_attention()
     test_flex_matches_dense_attention()