feat: torch.compile 单独编译 Expert.forward(fc1→relu→fc2 融合)

- 仅编译 Expert.forward,不碰 MoE 循环和 attention
- 纯静态函数无分支,编译成功率高
- 替代 2:4 稀疏方案
This commit is contained in:
2026-06-13 14:20:01 +08:00
parent 51ef3f66b2
commit a74af49456
+8 -22
View File
@@ -505,30 +505,16 @@ def load_model(ckpt_path, device='cuda:0'):
model.to(dev)
# === 2:4 非结构化稀疏:仅裁剪 Expert FFN 权重,不动 attention/gate ===
# === torch.compile 融合 Expert FFNfc1→relu→fc2),不含动态分支 ===
try:
sp_count = 0
for layer in model.seq_encoder.moe:
for expert in layer.experts:
for attr in ['fc1', 'fc2']:
linear = getattr(expert, attr)
w = linear.weight.data.clone()
shape = w.shape
# 2:4 幅度剪枝:每 4 个连续元素保留 top 2
w_flat = w.reshape(-1, 4)
_, top_idx = torch.topk(w_flat.abs(), k=2, dim=1)
mask = torch.zeros_like(w_flat)
mask.scatter_(1, top_idx, 1.0)
pruned = (w_flat * mask).reshape(shape)
sparse_w = torch.sparse.to_sparse_semi_structured(pruned)
bias = linear.bias
linear.forward = lambda x, sw=sparse_w, b=bias: (
torch.matmul(x, sw.t()) + b if b is not None else torch.matmul(x, sw.t())
)
sp_count += 1
print(f"[INFO] 2:4 sparsity applied to {sp_count} Expert Linear layers")
cc = 0
for moe_layer in model.seq_encoder.moe:
for expert in moe_layer.experts:
expert.forward = torch.compile(expert.forward, mode="reduce-overhead")
cc += 1
print(f"[INFO] torch.compile applied to {cc} Expert.forward methods")
except Exception as e:
print(f"[WARNING] 2:4 sparsity failed ({e}), keeping dense weights")
print(f"[WARNING] Expert torch.compile failed ({e}), using original forward")
model.eval()
print(f"[INFO] Model ready. Device: {dev}")