From a74af49456cbb64130e08d761f752c12eac57b1c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=88=98=E8=88=AA=E5=AE=87?= <3364451258@qq.com> Date: Sat, 13 Jun 2026 14:20:01 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20torch.compile=20=E5=8D=95=E7=8B=AC?= =?UTF-8?q?=E7=BC=96=E8=AF=91=20Expert.forward=EF=BC=88fc1=E2=86=92relu?= =?UTF-8?q?=E2=86=92fc2=20=E8=9E=8D=E5=90=88=EF=BC=89?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 仅编译 Expert.forward,不碰 MoE 循环和 attention - 纯静态函数无分支,编译成功率高 - 替代 2:4 稀疏方案 --- 代码/code/infer.py | 30 ++++++++---------------------- 1 file changed, 8 insertions(+), 22 deletions(-) diff --git a/代码/code/infer.py b/代码/code/infer.py index 58bef95..d18c819 100644 --- a/代码/code/infer.py +++ b/代码/code/infer.py @@ -505,30 +505,16 @@ def load_model(ckpt_path, device='cuda:0'): model.to(dev) - # === 2:4 非结构化稀疏:仅裁剪 Expert FFN 权重,不动 attention/gate === + # === torch.compile 融合 Expert FFN(fc1→relu→fc2),不含动态分支 === try: - sp_count = 0 - for layer in model.seq_encoder.moe: - for expert in layer.experts: - for attr in ['fc1', 'fc2']: - linear = getattr(expert, attr) - w = linear.weight.data.clone() - shape = w.shape - # 2:4 幅度剪枝:每 4 个连续元素保留 top 2 - w_flat = w.reshape(-1, 4) - _, top_idx = torch.topk(w_flat.abs(), k=2, dim=1) - mask = torch.zeros_like(w_flat) - mask.scatter_(1, top_idx, 1.0) - pruned = (w_flat * mask).reshape(shape) - sparse_w = torch.sparse.to_sparse_semi_structured(pruned) - bias = linear.bias - linear.forward = lambda x, sw=sparse_w, b=bias: ( - torch.matmul(x, sw.t()) + b if b is not None else torch.matmul(x, sw.t()) - ) - sp_count += 1 - print(f"[INFO] 2:4 sparsity applied to {sp_count} Expert Linear layers") + cc = 0 + for moe_layer in model.seq_encoder.moe: + for expert in moe_layer.experts: + expert.forward = torch.compile(expert.forward, mode="reduce-overhead") + cc += 1 + print(f"[INFO] torch.compile applied to {cc} Expert.forward methods") except Exception as e: - print(f"[WARNING] 2:4 sparsity failed ({e}), keeping dense weights") + print(f"[WARNING] Expert torch.compile failed ({e}), using original forward") model.eval() print(f"[INFO] Model ready. Device: {dev}")