feat: Triton BLOCK_M 可调(triton_block_m,默认64);bench --triton-bm 扫描

突破:triton评测39.92s/69.72(vs chunked 47.84/67.998)。继续调BLOCK_M榨。

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
OwnerSunshine530
2026-06-17 13:01:50 +08:00
parent 6f7ff9fce8
commit 1083aca9fa
2 changed files with 7 additions and 3 deletions
+4 -3
View File
@@ -106,7 +106,7 @@ def _triton_varlen_attn(q, k, v, meta):
"""q,k,v: [1, H, S, Dh]contiguous)。meta=(cu, blk_seq, blk_inseq, total_blocks)。返回 [1,H,S,Dh]。"""
_, H, S, Dh = q.shape
cu, blk_seq, blk_inseq, total_blocks = meta
BLOCK_M = 64
BLOCK_M = CONFIG.get("triton_block_m", 64)
out = torch.empty_like(q)
qc = q.contiguous(); kc = k.contiguous(); vc = v.contiguous(); out = out.contiguous()
stride_h, stride_s, stride_d = S * Dh, Dh, 1
@@ -136,6 +136,7 @@ CONFIG = {
# sdpa 是评测端验证最快(89.96s/58.86)。flex/compile/小batch/varlen 在评测端都更差。
# attn: "chunked"(按用户分块SDPA,降O(S²),本地14.25->7.92s) / "sdpa"(稠密mask) / 其它对照
"attn": "triton", # Triton varlen flash(单kernel,消逐块调用/mask构造开销);无triton回退chunked
"triton_block_m": 64, # Triton query 块大小(可调 32/64/128;块大=调用少)
"chunk_users": 4, # chunked 回退时用;评测扫描 3/4/8 中 4 最优(47.84s/67.998)
# 稠密MoE去掉了 model(batch) 内唯一的同步点(MoE循环的.nonzero())。若评测计时不
# synchronize,去掉同步点可能让被计时的 model(batch) 大幅缩短。本地force-sync看不出,
@@ -835,7 +836,7 @@ class CTRModel(nn.Module):
user_offsets = batch["user_offsets"]
attn = _resolve_attn(seq_input.device)
if attn == "triton":
meta = _triton_block_meta(user_offsets, 64, seq_input.device)
meta = _triton_block_meta(user_offsets, CONFIG.get("triton_block_m", 64), seq_input.device)
extension = {"triton_meta": meta}
elif attn == "chunked":
extension = {"chunks": self.build_chunks(user_offsets, seq_input.device)}
@@ -1071,7 +1072,7 @@ def load_model(ckpt_path, device='cuda:0'):
H, Dh = model.seq_encoder.n_heads, model.seq_encoder.head_dim
dummy_off = torch.tensor([0, 64, 130], device=dev)
dq = torch.randn(1, H, 130, Dh, device=dev, dtype=torch.float16)
meta = _triton_block_meta(dummy_off, 64, dev)
meta = _triton_block_meta(dummy_off, CONFIG.get("triton_block_m", 64), dev)
_triton_varlen_attn(dq, dq, dq, meta)
torch.cuda.synchronize()
print("[INFO] triton kernel warmed up")