From 43b0c6c92a1dbe27d6f7e8f1feede2b8089df52a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=88=98=E8=88=AA=E5=AE=87?= <3364451258@qq.com> Date: Sat, 13 Jun 2026 12:20:40 +0800 Subject: [PATCH] =?UTF-8?q?feat:=202:4=20=E7=BB=93=E6=9E=84=E5=8C=96?= =?UTF-8?q?=E7=A8=80=E7=96=8F=EF=BC=88A800=20=E5=8E=9F=E7=94=9F=E5=8A=A0?= =?UTF-8?q?=E9=80=9F=EF=BC=8C=E6=89=80=E6=9C=89=20Linear=20=E5=B1=82?= =?UTF-8?q?=E6=9D=83=E9=87=8D=E5=89=AA=E6=9E=9D=EF=BC=89?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 每 4 个连续权重保留幅度最大的 2 个(50% 稀疏度) - torch.sparse.to_sparse_semi_structured 硬件加速 matmul - 权重形状不变,属参数级剪枝,合规 - try-except 保护:稀疏化失败时回退 dense 权重 --- 代码/code/infer.py | 23 ++++++++++++++++++++++- 1 file changed, 22 insertions(+), 1 deletion(-) diff --git a/代码/code/infer.py b/代码/code/infer.py index d2109c5..ab80694 100644 --- a/代码/code/infer.py +++ b/代码/code/infer.py @@ -511,7 +511,28 @@ def load_model(ckpt_path, device='cuda:0'): model.to(dev) model.eval() - print(f"[INFO] Model ready. Device: {dev}") + + # === 2:4 结构化稀疏:所有 Linear 层权重剪枝,A800 原生 2x 加速 === + try: + sp_count = 0 + for name, module in model.named_modules(): + if isinstance(module, nn.Linear) and module.weight.dim() == 2: + w = module.weight.data + shape = w.shape + # 每 4 个连续元素保留幅度最大的 2 个 + w_flat = w.reshape(-1, 4) + _, top_idx = torch.topk(w_flat.abs(), k=2, dim=1) + mask = torch.zeros_like(w_flat) + mask.scatter_(1, top_idx, 1.0) + pruned = (w_flat * mask).reshape(shape) + # 转为半结构化稀疏格式(A800 SM80 硬件加速) + module.weight = nn.Parameter( + torch.sparse.to_sparse_semi_structured(pruned) + ) + sp_count += 1 + print(f"[INFO] 2:4 sparsity applied to {sp_count} Linear layers") + except Exception as e: + print(f"[WARNING] 2:4 sparsity failed ({e}), continuing with dense weights") return model, dev