From 4dbee8309716683ab7e823b0be99137d4abea4c5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=88=98=E8=88=AA=E5=AE=87?= <3364451258@qq.com> Date: Sat, 13 Jun 2026 14:09:42 +0800 Subject: [PATCH] =?UTF-8?q?feat:=202:4=20=E9=9D=9E=E7=BB=93=E6=9E=84?= =?UTF-8?q?=E5=8C=96=E7=A8=80=E7=96=8F=E4=BB=85=E8=A3=81=E5=89=AA=20Expert?= =?UTF-8?q?=20FFN=EF=BC=88=E4=B8=8D=E7=A2=B0=20attention/gate=EF=BC=89?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 合规:单个权重置零,矩阵形状不变 - 只裁剪 8层×8expert×2fc = 128 个 Expert Linear - lambda forward 直调 sparse matmul,绕开 nn.Linear 兼容问题 --- 代码/code/infer.py | 26 ++++++++++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/代码/code/infer.py b/代码/code/infer.py index 9bca191..58bef95 100644 --- a/代码/code/infer.py +++ b/代码/code/infer.py @@ -504,6 +504,32 @@ def load_model(ckpt_path, device='cuda:0'): print(f"[WARNING] Checkpoint {ckpt_path} not found, using random weights") model.to(dev) + + # === 2:4 非结构化稀疏:仅裁剪 Expert FFN 权重,不动 attention/gate === + try: + sp_count = 0 + for layer in model.seq_encoder.moe: + for expert in layer.experts: + for attr in ['fc1', 'fc2']: + linear = getattr(expert, attr) + w = linear.weight.data.clone() + shape = w.shape + # 2:4 幅度剪枝:每 4 个连续元素保留 top 2 + w_flat = w.reshape(-1, 4) + _, top_idx = torch.topk(w_flat.abs(), k=2, dim=1) + mask = torch.zeros_like(w_flat) + mask.scatter_(1, top_idx, 1.0) + pruned = (w_flat * mask).reshape(shape) + sparse_w = torch.sparse.to_sparse_semi_structured(pruned) + bias = linear.bias + linear.forward = lambda x, sw=sparse_w, b=bias: ( + torch.matmul(x, sw.t()) + b if b is not None else torch.matmul(x, sw.t()) + ) + sp_count += 1 + print(f"[INFO] 2:4 sparsity applied to {sp_count} Expert Linear layers") + except Exception as e: + print(f"[WARNING] 2:4 sparsity failed ({e}), keeping dense weights") + model.eval() print(f"[INFO] Model ready. Device: {dev}")