feat: inference_mode + torch.compile(default)
- torch.no_grad() → torch.inference_mode()(禁梯度+禁版本追踪,更快) - torch.compile(mode='default'):纯算子融合,不用 CUDA Graph,兼容动态 batch 形状
This commit is contained in:
+6
-1
@@ -519,6 +519,11 @@ def load_model(ckpt_path, device='cuda:0'):
|
|||||||
|
|
||||||
model.to(dev)
|
model.to(dev)
|
||||||
model.eval()
|
model.eval()
|
||||||
|
|
||||||
|
# === torch.compile(default):算子融合,不用 CUDA Graph,兼容动态 batch 形状 ===
|
||||||
|
model = torch.compile(model, mode="default")
|
||||||
|
print("[INFO] torch.compile applied (mode=default)")
|
||||||
|
|
||||||
print(f"[INFO] Model ready. Device: {dev}")
|
print(f"[INFO] Model ready. Device: {dev}")
|
||||||
|
|
||||||
return model, dev
|
return model, dev
|
||||||
@@ -703,7 +708,7 @@ def main():
|
|||||||
all_probs = []
|
all_probs = []
|
||||||
time_sum = 0.0
|
time_sum = 0.0
|
||||||
|
|
||||||
with torch.no_grad():
|
with torch.inference_mode():
|
||||||
for batch in tqdm(all_batches, desc="Inference"):
|
for batch in tqdm(all_batches, desc="Inference"):
|
||||||
batch = move_batch_to_device(batch, dev)
|
batch = move_batch_to_device(batch, dev)
|
||||||
pred_mask = batch["pred_mask"].bool()
|
pred_mask = batch["pred_mask"].bool()
|
||||||
|
|||||||
Reference in New Issue
Block a user