diff --git a/CLAUDE.md b/CLAUDE.md index b84b543..d50610d 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -167,7 +167,7 @@ Baseline 数据:推理 229s,AUC 0.759,PCOC 1.110,得分 25.85。 4. ✅ **inference_mode()** — 替代 `no_grad()`,92.5s(+2s 小幅提升) 5. ❌ **torch.compile** — reduce-overhead 和 default 模式均因动态 batch 形状反效果,彻底放弃 6. ❌ **MoE Top-1 gating** — PCOC 从 1.059 炸到 2.075,已回退 -7. 🔲 **2:4 结构化稀疏** — A800 原生加速,权重形状不变(显式允许) +7. ❌ **2:4 结构化稀疏** — PCOC 炸到 2.067,耗时反增 265s。to_sparse_semi_structured 与 nn.Linear 不兼容 CUDA Graph 已评估并放弃(batch 形状不固定,不适用)。 @@ -190,8 +190,9 @@ CUDA Graph 已评估并放弃(batch 形状不固定,不适用)。 | 日期 | 提交次数 | 得分 | AUC | PCOC | 耗时 | 优化手段 | 备注 | |------|----------|------|-----|------|------|----------|------| +| 06/13 | 11 | 0 | 0.748 | 2.067 | 265.5s | 2:4 sparse | ❌ 炸毁 | | 06/13 | 10 | **57.45** | 0.7526 | 1.059 | 92.5s | + inference_mode | **当前最优** | -| 06/13 | 9 | 51.42 | 0.7525 | 1.059 | 118.4s | + compile(default) | 反效果,已移除 | +| 06/13 | 9 | 51.42 | 0.7525 | 1.059 | 118.4s | + compile(default) | 反效果 | | 06/12 | 8 | 0 | 0.736 | 2.075 | 119.6s | MoE k=1 + compile | PCOC 炸毁 | | 06/12 | 6 | 56.98 | 0.7526 | 1.059 | 94.5s | + Flash Attention | | | 06/12 | 3 | 43.55 | 0.7525 | 1.059 | 152s | + FP16 量化 | | diff --git a/代码/code/infer.py b/代码/code/infer.py index ab80694..d2109c5 100644 --- a/代码/code/infer.py +++ b/代码/code/infer.py @@ -511,28 +511,7 @@ def load_model(ckpt_path, device='cuda:0'): model.to(dev) model.eval() - - # === 2:4 结构化稀疏:所有 Linear 层权重剪枝,A800 原生 2x 加速 === - try: - sp_count = 0 - for name, module in model.named_modules(): - if isinstance(module, nn.Linear) and module.weight.dim() == 2: - w = module.weight.data - shape = w.shape - # 每 4 个连续元素保留幅度最大的 2 个 - w_flat = w.reshape(-1, 4) - _, top_idx = torch.topk(w_flat.abs(), k=2, dim=1) - mask = torch.zeros_like(w_flat) - mask.scatter_(1, top_idx, 1.0) - pruned = (w_flat * mask).reshape(shape) - # 转为半结构化稀疏格式(A800 SM80 硬件加速) - module.weight = nn.Parameter( - torch.sparse.to_sparse_semi_structured(pruned) - ) - sp_count += 1 - print(f"[INFO] 2:4 sparsity applied to {sp_count} Linear layers") - except Exception as e: - print(f"[WARNING] 2:4 sparsity failed ({e}), continuing with dense weights") + print(f"[INFO] Model ready. Device: {dev}") return model, dev