feat: 2:4 非结构化稀疏仅裁剪 Expert FFN(不碰 attention/gate)
- 合规:单个权重置零,矩阵形状不变 - 只裁剪 8层×8expert×2fc = 128 个 Expert Linear - lambda forward 直调 sparse matmul,绕开 nn.Linear 兼容问题
This commit is contained in:
@@ -504,6 +504,32 @@ def load_model(ckpt_path, device='cuda:0'):
|
|||||||
print(f"[WARNING] Checkpoint {ckpt_path} not found, using random weights")
|
print(f"[WARNING] Checkpoint {ckpt_path} not found, using random weights")
|
||||||
|
|
||||||
model.to(dev)
|
model.to(dev)
|
||||||
|
|
||||||
|
# === 2:4 非结构化稀疏:仅裁剪 Expert FFN 权重,不动 attention/gate ===
|
||||||
|
try:
|
||||||
|
sp_count = 0
|
||||||
|
for layer in model.seq_encoder.moe:
|
||||||
|
for expert in layer.experts:
|
||||||
|
for attr in ['fc1', 'fc2']:
|
||||||
|
linear = getattr(expert, attr)
|
||||||
|
w = linear.weight.data.clone()
|
||||||
|
shape = w.shape
|
||||||
|
# 2:4 幅度剪枝:每 4 个连续元素保留 top 2
|
||||||
|
w_flat = w.reshape(-1, 4)
|
||||||
|
_, top_idx = torch.topk(w_flat.abs(), k=2, dim=1)
|
||||||
|
mask = torch.zeros_like(w_flat)
|
||||||
|
mask.scatter_(1, top_idx, 1.0)
|
||||||
|
pruned = (w_flat * mask).reshape(shape)
|
||||||
|
sparse_w = torch.sparse.to_sparse_semi_structured(pruned)
|
||||||
|
bias = linear.bias
|
||||||
|
linear.forward = lambda x, sw=sparse_w, b=bias: (
|
||||||
|
torch.matmul(x, sw.t()) + b if b is not None else torch.matmul(x, sw.t())
|
||||||
|
)
|
||||||
|
sp_count += 1
|
||||||
|
print(f"[INFO] 2:4 sparsity applied to {sp_count} Expert Linear layers")
|
||||||
|
except Exception as e:
|
||||||
|
print(f"[WARNING] 2:4 sparsity failed ({e}), keeping dense weights")
|
||||||
|
|
||||||
model.eval()
|
model.eval()
|
||||||
print(f"[INFO] Model ready. Device: {dev}")
|
print(f"[INFO] Model ready. Device: {dev}")
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user