From f7e1fbfbdc7f1924007b1bb5c2b4ba935958372a Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E5=88=98=E8=88=AA=E5=AE=87?= <3364451258@qq.com>
Date: Fri, 12 Jun 2026 22:11:35 +0800
Subject: [PATCH] feat: inference_mode + torch.compile(default)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

- torch.no_grad() → torch.inference_mode()(禁梯度+禁版本追踪,更快)
- torch.compile(mode='default'):纯算子融合,不用 CUDA Graph,兼容动态 batch 形状
---
 代码/code/infer.py | 7 ++++++-
 1 file changed, 6 insertions(+), 1 deletion(-)

diff --git a/代码/code/infer.py b/代码/code/infer.py
index 29ff02a..70e2775 100644
--- a/代码/code/infer.py
+++ b/代码/code/infer.py
@@ -519,6 +519,11 @@ def load_model(ckpt_path, device='cuda:0'):
 
     model.to(dev)
     model.eval()
+
+    # === torch.compile(default):算子融合,不用 CUDA Graph,兼容动态 batch 形状 ===
+    model = torch.compile(model, mode="default")
+    print("[INFO] torch.compile applied (mode=default)")
+
     print(f"[INFO] Model ready. Device: {dev}")
     return model, dev
 
@@ -703,7 +708,7 @@ def main():
     all_probs = []
     time_sum = 0.0
 
-    with torch.no_grad():
+    with torch.inference_mode():
         for batch in tqdm(all_batches, desc="Inference"):
             batch = move_batch_to_device(batch, dev)
             pred_mask = batch["pred_mask"].bool()