diff --git a/代码/code/infer.py b/代码/code/infer.py index 691f95d..09e3a44 100644 --- a/代码/code/infer.py +++ b/代码/code/infer.py @@ -511,11 +511,6 @@ def load_model(ckpt_path, device='cuda:0'): model.to(dev) model.eval() - - # === torch.compile:算子融合 + 减少 kernel launch 开销 === - model = torch.compile(model, mode="reduce-overhead") - print("[INFO] torch.compile applied (mode=reduce-overhead)") - print(f"[INFO] Model ready. Device: {dev}") return model, dev