diff --git a/代码/code/infer.py b/代码/code/infer.py index 29ff02a..70e2775 100644 --- a/代码/code/infer.py +++ b/代码/code/infer.py @@ -519,6 +519,11 @@ def load_model(ckpt_path, device='cuda:0'): model.to(dev) model.eval() + + # === torch.compile(default):算子融合,不用 CUDA Graph,兼容动态 batch 形状 === + model = torch.compile(model, mode="default") + print("[INFO] torch.compile applied (mode=default)") + print(f"[INFO] Model ready. Device: {dev}") return model, dev @@ -703,7 +708,7 @@ def main(): all_probs = [] time_sum = 0.0 - with torch.no_grad(): + with torch.inference_mode(): for batch in tqdm(all_batches, desc="Inference"): batch = move_batch_to_device(batch, dev) pred_mask = batch["pred_mask"].bool()