feat: torch.compile 全模型 + dynamic=True(告知编译器输入形状可变,减少因形状变化触发的重编译)
This commit is contained in:
+4
-8
@@ -505,16 +505,12 @@ def load_model(ckpt_path, device='cuda:0'):
|
|||||||
|
|
||||||
model.to(dev)
|
model.to(dev)
|
||||||
|
|
||||||
# === torch.compile 融合 Expert FFN(fc1→relu→fc2),不含动态分支 ===
|
# === torch.compile + dynamic=True:告知编译器形状可变,避免重编译 ===
|
||||||
try:
|
try:
|
||||||
cc = 0
|
model = torch.compile(model, dynamic=True)
|
||||||
for moe_layer in model.seq_encoder.moe:
|
print(f"[INFO] torch.compile applied (dynamic=True)")
|
||||||
for expert in moe_layer.experts:
|
|
||||||
expert.forward = torch.compile(expert.forward, mode="default")
|
|
||||||
cc += 1
|
|
||||||
print(f"[INFO] torch.compile applied to {cc} Expert.forward methods")
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"[WARNING] Expert torch.compile failed ({e}), using original forward")
|
print(f"[WARNING] torch.compile failed ({e}), using original model")
|
||||||
|
|
||||||
model.eval()
|
model.eval()
|
||||||
print(f"[INFO] Model ready. Device: {dev}")
|
print(f"[INFO] Model ready. Device: {dev}")
|
||||||
|
|||||||
Reference in New Issue
Block a user