diff --git a/代码/code/infer.py b/代码/code/infer.py index d18c819..7436b70 100644 --- a/代码/code/infer.py +++ b/代码/code/infer.py @@ -510,7 +510,7 @@ def load_model(ckpt_path, device='cuda:0'): cc = 0 for moe_layer in model.seq_encoder.moe: for expert in moe_layer.experts: - expert.forward = torch.compile(expert.forward, mode="reduce-overhead") + expert.forward = torch.compile(expert.forward, mode="default") cc += 1 print(f"[INFO] torch.compile applied to {cc} Expert.forward methods") except Exception as e: