feat: inference_mode + torch.compile(default)

- torch.no_grad() → torch.inference_mode():除禁用梯度计算外,还跳过张量的版本计数与视图追踪,推理开销更低
- torch.compile(mode='default'):仅做算子融合,不启用 CUDA Graph(与 mode='reduce-overhead' 不同),因此兼容动态 batch 形状(形状变化时可能触发重新编译)
This commit is contained in:
2026-06-12 22:11:35 +08:00
parent feb71be5bd
commit f7e1fbfbdc
+6 -1
View File
@@ -519,6 +519,11 @@ def load_model(ckpt_path, device='cuda:0'):
     model.to(dev)
     model.eval()
+    # === torch.compile(default):算子融合,不用 CUDA Graph,兼容动态 batch 形状 ===
+    model = torch.compile(model, mode="default")
+    print("[INFO] torch.compile applied (mode=default)")
     print(f"[INFO] Model ready. Device: {dev}")
     return model, dev
@@ -703,7 +708,7 @@ def main():
     all_probs = []
     time_sum = 0.0
-    with torch.no_grad():
+    with torch.inference_mode():
         for batch in tqdm(all_batches, desc="Inference"):
             batch = move_batch_to_device(batch, dev)
             pred_mask = batch["pred_mask"].bool()