feat: torch.compile 全模型 + dynamic=True(告知编译器形状可变,避免重编译)
This commit is contained in:
+4
-8
@@ -505,16 +505,12 @@ def load_model(ckpt_path, device='cuda:0'):
|
||||
|
||||
model.to(dev)
|
||||
|
||||
# === torch.compile 融合 Expert FFN(fc1→relu→fc2),不含动态分支 ===
|
||||
# === torch.compile + dynamic=True:告知编译器形状可变,避免重编译 ===
|
||||
try:
|
||||
cc = 0
|
||||
for moe_layer in model.seq_encoder.moe:
|
||||
for expert in moe_layer.experts:
|
||||
expert.forward = torch.compile(expert.forward, mode="default")
|
||||
cc += 1
|
||||
print(f"[INFO] torch.compile applied to {cc} Expert.forward methods")
|
||||
model = torch.compile(model, dynamic=True)
|
||||
print(f"[INFO] torch.compile applied (dynamic=True)")
|
||||
except Exception as e:
|
||||
print(f"[WARNING] Expert torch.compile failed ({e}), using original forward")
|
||||
print(f"[WARNING] torch.compile failed ({e}), using original model")
|
||||
|
||||
model.eval()
|
||||
print(f"[INFO] Model ready. Device: {dev}")
|
||||
|
||||
Reference in New Issue
Block a user