diff --git a/代码/code/infer.py b/代码/code/infer.py
index 7436b70..77f5250 100644
--- a/代码/code/infer.py
+++ b/代码/code/infer.py
@@ -505,16 +505,12 @@ def load_model(ckpt_path, device='cuda:0'):
 
     model.to(dev)
 
-    # === torch.compile 融合 Expert FFN(fc1→relu→fc2),不含动态分支 ===
+    # === torch.compile + dynamic=True:告知编译器形状可变,避免重编译 ===
     try:
-        cc = 0
-        for moe_layer in model.seq_encoder.moe:
-            for expert in moe_layer.experts:
-                expert.forward = torch.compile(expert.forward, mode="default")
-                cc += 1
-        print(f"[INFO] torch.compile applied to {cc} Expert.forward methods")
+        model = torch.compile(model, dynamic=True)
+        print(f"[INFO] torch.compile applied (dynamic=True)")
     except Exception as e:
-        print(f"[WARNING] Expert torch.compile failed ({e}), using original forward")
+        print(f"[WARNING] torch.compile failed ({e}), using original model")
 
     model.eval()
     print(f"[INFO] Model ready. Device: {dev}")