revert: 移除 2:4 稀疏(PCOC 2.067 + 耗时反增 265s,to_sparse_semi_structured 与 nn.Linear 不兼容)

回退到稳定版:FP16 + Flash Attention + inference_mode(57.45 分)
This commit is contained in:
2026-06-13 12:34:29 +08:00
parent e6519b7b1a
commit e69ba714e5
2 changed files with 4 additions and 24 deletions
+1 -22
View File
@@ -511,28 +511,7 @@ def load_model(ckpt_path, device='cuda:0'):
model.to(dev)
model.eval()
# === 2:4 结构化稀疏:所有 Linear 层权重剪枝,A800 原生 2x 加速 ===
try:
sp_count = 0
for name, module in model.named_modules():
if isinstance(module, nn.Linear) and module.weight.dim() == 2:
w = module.weight.data
shape = w.shape
# 每 4 个连续元素保留幅度最大的 2 个
w_flat = w.reshape(-1, 4)
_, top_idx = torch.topk(w_flat.abs(), k=2, dim=1)
mask = torch.zeros_like(w_flat)
mask.scatter_(1, top_idx, 1.0)
pruned = (w_flat * mask).reshape(shape)
# 转为半结构化稀疏格式(A800 SM80 硬件加速)
module.weight = nn.Parameter(
torch.sparse.to_sparse_semi_structured(pruned)
)
sp_count += 1
print(f"[INFO] 2:4 sparsity applied to {sp_count} Linear layers")
except Exception as e:
print(f"[WARNING] 2:4 sparsity failed ({e}), continuing with dense weights")
print(f"[INFO] Model ready. Device: {dev}")
return model, dev