From 6f7ff9fce8948eb6335651d2098f2f2b291c3070 Mon Sep 17 00:00:00 2001 From: OwnerSunshine530 Date: Wed, 17 Jun 2026 12:23:11 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20Triton=20kernel=20load=5Fmodel=E9=A2=84?= =?UTF-8?q?=E7=83=AD(=E9=81=BF=E5=85=8D=E9=A6=96batch=E5=90=ABJIT=E7=BC=96?= =?UTF-8?q?=E8=AF=91)=20+=20=E9=BB=98=E8=AE=A4attn=3Dtriton?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-Authored-By: Claude Opus 4.8 --- 代码/code/infer.py | 18 ++++++++++++++++-- 1 file changed, 16 insertions(+), 2 deletions(-) diff --git a/代码/code/infer.py b/代码/code/infer.py index aa7f442..eb44099 100644 --- a/代码/code/infer.py +++ b/代码/code/infer.py @@ -135,8 +135,8 @@ CONFIG = { # 实测:varlen 本地快(10.28s)但评测端慢(148s,嵌套张量构造开销随batch数放大)→已退回。 # sdpa 是评测端验证最快(89.96s/58.86)。flex/compile/小batch/varlen 在评测端都更差。 # attn: "chunked"(按用户分块SDPA,降O(S²),本地14.25->7.92s) / "sdpa"(稠密mask) / 其它对照 - "attn": "chunked", - "chunk_users": 4, # 评测扫描 3/4/8:chunk=4 最优(47.84s/67.998),3更慢8持平→此维度榨干 + "attn": "triton", # Triton varlen flash(单kernel,消逐块调用/mask构造开销);无triton回退chunked + "chunk_users": 4, # chunked 回退时用;评测扫描 3/4/8 中 4 最优(47.84s/67.998) # 稠密MoE去掉了 model(batch) 内唯一的同步点(MoE循环的.nonzero())。若评测计时不 # synchronize,去掉同步点可能让被计时的 model(batch) 大幅缩短。本地force-sync看不出, # 须靠提交验证。AUC中性、MoE仅占2%算力故风险极低。 @@ -1064,6 +1064,20 @@ def load_model(ckpt_path, device='cuda:0'): global _MODEL_REF _MODEL_REF = model # 供 collate_fn 就地算 RepEncoder + + # 预热 Triton kernel(不计时阶段触发 JIT 编译,避免首个 model(batch) 含编译时间) + if _resolve_attn(dev) == "triton": + try: + H, Dh = model.seq_encoder.n_heads, model.seq_encoder.head_dim + dummy_off = torch.tensor([0, 64, 130], device=dev) + dq = torch.randn(1, H, 130, Dh, device=dev, dtype=torch.float16) + meta = _triton_block_meta(dummy_off, 64, dev) + _triton_varlen_attn(dq, dq, dq, meta) + torch.cuda.synchronize() + print("[INFO] triton kernel warmed up") + except Exception as e: + print(f"[WARNING] triton warmup failed ({e})") + print(f"[INFO] Model ready. Device: {dev}") return model, dev