feat: 2:4 结构化稀疏(A800 原生加速,所有 Linear 层权重剪枝)
- 每 4 个连续权重保留幅度最大的 2 个(50% 稀疏度) - torch.sparse.to_sparse_semi_structured 硬件加速 matmul - 权重形状不变,属参数级剪枝,合规 - try-except 保护:稀疏化失败时回退 dense 权重
This commit is contained in:
+22
-1
@@ -511,7 +511,28 @@ def load_model(ckpt_path, device='cuda:0'):
|
|||||||
|
|
||||||
model.to(dev)
|
model.to(dev)
|
||||||
model.eval()
|
model.eval()
|
||||||
print(f"[INFO] Model ready. Device: {dev}")
|
|
||||||
|
# === 2:4 结构化稀疏:所有 Linear 层权重剪枝,A800 原生 2x 加速 ===
|
||||||
|
try:
|
||||||
|
sp_count = 0
|
||||||
|
for name, module in model.named_modules():
|
||||||
|
if isinstance(module, nn.Linear) and module.weight.dim() == 2:
|
||||||
|
w = module.weight.data
|
||||||
|
shape = w.shape
|
||||||
|
# 每 4 个连续元素保留幅度最大的 2 个
|
||||||
|
w_flat = w.reshape(-1, 4)
|
||||||
|
_, top_idx = torch.topk(w_flat.abs(), k=2, dim=1)
|
||||||
|
mask = torch.zeros_like(w_flat)
|
||||||
|
mask.scatter_(1, top_idx, 1.0)
|
||||||
|
pruned = (w_flat * mask).reshape(shape)
|
||||||
|
# 转为半结构化稀疏格式(A800 SM80 硬件加速)
|
||||||
|
module.weight = nn.Parameter(
|
||||||
|
torch.sparse.to_sparse_semi_structured(pruned)
|
||||||
|
)
|
||||||
|
sp_count += 1
|
||||||
|
print(f"[INFO] 2:4 sparsity applied to {sp_count} Linear layers")
|
||||||
|
except Exception as e:
|
||||||
|
print(f"[WARNING] 2:4 sparsity failed ({e}), continuing with dense weights")
|
||||||
|
|
||||||
return model, dev
|
return model, dev
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user