diff --git a/代码/code/bench.py b/代码/code/bench.py index d94fc02..0c4cb4f 100644 --- a/代码/code/bench.py +++ b/代码/code/bench.py @@ -318,6 +318,7 @@ def _parse_args(): ap.add_argument("--attn", choices=["sdpa", "chunked", "triton", "flex", "varlen"], default=None, help="注意力:sdpa=稠密, chunked=分块SDPA, triton=varlen flash kernel, flex/varlen=对照") ap.add_argument("--chunk-users", type=int, default=None, help="chunked 模式每块用户数") + ap.add_argument("--triton-bm", type=int, default=None, help="Triton query 块大小(32/64/128)") ap.add_argument("--moe", choices=["dense", "loop"], default=None, help="MoE实现:dense=向量化(新), loop=逐expert循环(原)") ap.add_argument("--compile", action="store_true", help="开启 torch.compile") @@ -361,6 +362,8 @@ if __name__ == "__main__": cfg["attn"] = a.attn if a.chunk_users is not None: cfg["chunk_users"] = a.chunk_users + if a.triton_bm is not None: + cfg["triton_block_m"] = a.triton_bm if a.moe is not None: cfg["vectorize_moe"] = (a.moe == "dense") if a.emb_fp16: diff --git a/代码/code/infer.py b/代码/code/infer.py index eb44099..5854344 100644 --- a/代码/code/infer.py +++ b/代码/code/infer.py @@ -106,7 +106,7 @@ def _triton_varlen_attn(q, k, v, meta): """q,k,v: [1, H, S, Dh](contiguous)。meta=(cu, blk_seq, blk_inseq, total_blocks)。返回 [1,H,S,Dh]。""" _, H, S, Dh = q.shape cu, blk_seq, blk_inseq, total_blocks = meta - BLOCK_M = 64 + BLOCK_M = CONFIG.get("triton_block_m", 64) out = torch.empty_like(q) qc = q.contiguous(); kc = k.contiguous(); vc = v.contiguous(); out = out.contiguous() stride_h, stride_s, stride_d = S * Dh, Dh, 1 @@ -136,6 +136,7 @@ CONFIG = { # sdpa 是评测端验证最快(89.96s/58.86)。flex/compile/小batch/varlen 在评测端都更差。 # attn: "chunked"(按用户分块SDPA,降O(S²),本地14.25->7.92s) / "sdpa"(稠密mask) / 其它对照 "attn": "triton", # Triton varlen flash(单kernel,消逐块调用/mask构造开销);无triton回退chunked + "triton_block_m": 64, # Triton query 块大小(可调 32/64/128;块大=调用少) "chunk_users": 4, # chunked 回退时用;评测扫描 3/4/8 中 4 最优(47.84s/67.998) # 稠密MoE去掉了 model(batch) 内唯一的同步点(MoE循环的.nonzero())。若评测计时不 # synchronize,去掉同步点可能让被计时的 model(batch) 大幅缩短。本地force-sync看不出, @@ -835,7 +836,7 @@ class CTRModel(nn.Module): user_offsets = batch["user_offsets"] attn = _resolve_attn(seq_input.device) if attn == "triton": - meta = _triton_block_meta(user_offsets, 64, seq_input.device) + meta = _triton_block_meta(user_offsets, CONFIG.get("triton_block_m", 64), seq_input.device) extension = {"triton_meta": meta} elif attn == "chunked": extension = {"chunks": self.build_chunks(user_offsets, seq_input.device)} @@ -1071,7 +1072,7 @@ def load_model(ckpt_path, device='cuda:0'): H, Dh = model.seq_encoder.n_heads, model.seq_encoder.head_dim dummy_off = torch.tensor([0, 64, 130], device=dev) dq = torch.randn(1, H, 130, Dh, device=dev, dtype=torch.float16) - meta = _triton_block_meta(dummy_off, 64, dev) + meta = _triton_block_meta(dummy_off, CONFIG.get("triton_block_m", 64), dev) _triton_varlen_attn(dq, dq, dq, meta) torch.cuda.synchronize() print("[INFO] triton kernel warmed up")