feat: torch.compile 全模型 + dynamic=True(告知编译器形状可变,避免重编译)

This commit is contained in:
2026-06-13 14:37:38 +08:00
parent 480a81a033
commit 7b429cf7fb
+4 -8
View File
@@ -505,16 +505,12 @@ def load_model(ckpt_path, device='cuda:0'):
model.to(dev) model.to(dev)
# === torch.compile 融合 Expert FFNfc1→relu→fc2),不含动态分支 === # === torch.compile + dynamic=True:告知编译器形状可变,避免重编译 ===
try: try:
cc = 0 model = torch.compile(model, dynamic=True)
for moe_layer in model.seq_encoder.moe: print(f"[INFO] torch.compile applied (dynamic=True)")
for expert in moe_layer.experts:
expert.forward = torch.compile(expert.forward, mode="default")
cc += 1
print(f"[INFO] torch.compile applied to {cc} Expert.forward methods")
except Exception as e: except Exception as e:
print(f"[WARNING] Expert torch.compile failed ({e}), using original forward") print(f"[WARNING] torch.compile failed ({e}), using original model")
model.eval() model.eval()
print(f"[INFO] Model ready. Device: {dev}") print(f"[INFO] Model ready. Device: {dev}")