feat: infer.py 接入 CONFIG 实验开关 + 新增 bench.py 测量闭环

- infer.py: 模块级 CONFIG(fp16/keep_fp32_modules/expert_merge/
  merge_threshold/signid_mode/sync_timing),默认值=当前最优行为;
  load_model 按 CONFIG 控制半精度/FP32敏感层/expert合并;
  RepEncoder 支持 clamp/modulo 两种 sign-id 处理;
  新增 _force_fp32_io 钩子让敏感层在FP16模型里以FP32 IO 计算。
- bench.py: 设置 CONFIG → 跑推理 → cuda.synchronize 真实计时 →
  _cal_score 打印 AUC/PCOC/延迟/总分,支持配置/batch扫描。不进提交包。
- EXPERIMENTS.md: 实验记录表。

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
OwnerSunshine530
2026-06-14 16:48:38 +08:00
parent 0bd6ec440d
commit 9d5a5a52f2
3 changed files with 185 additions and 6 deletions
+56 -6
View File
@@ -18,6 +18,41 @@ from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
# ============================================================
# 实验配置开关板
# 提交时保持下面的默认值 = 当前最优行为;评测系统不碰它,按默认值跑。
# bench.py 会在 import 之后用 infer.CONFIG.update(...) 覆盖这些值。
# ============================================================
CONFIG = {
"fp16": True, # True=半精度推理;False=FP32 参考跑(确立 AUC 天花板)
"keep_fp32_modules": (), # fp16 下仍保留 FP32 的子模块名前缀,如 ("linear",)
"expert_merge": True, # 是否做 expert 权重相似度合并
"merge_threshold": 0.90, # 合并的余弦相似度阈值
"signid_mode": "clamp", # "clamp" 或 "modulo":处理超界 sign id 的方式
"sync_timing": False, # bench 里设 True,做 torch.cuda.synchronize 真实计时
}
def _force_fp32_io(module):
"""让某个模块在 FP16 模型里以 FP32 计算:输入转 FP32、输出转回 FP16。
用于 keep_fp32_modules 指定的精度敏感层(如最终输出头、LayerNorm)。"""
module.float()
def _pre(m, args):
return tuple(
a.float() if torch.is_tensor(a) and a.is_floating_point() else a
for a in args
)
def _post(m, args, output):
if torch.is_tensor(output) and output.is_floating_point():
return output.half()
return output
module.register_forward_pre_hook(_pre)
module.register_forward_hook(_post)
# ============================================================
# 数据加载(来自 train/dataset.py
# ============================================================
@@ -263,7 +298,10 @@ class RepEncoder(nn.Module):
for i in range(self.slot_num):
values, offsets = batch[i + 1]
offsets = offsets.to(values.device)
values = values.clamp(0, max_idx) # 超出 vocab_size 的 sign id 截断,避免越界
if CONFIG["signid_mode"] == "modulo":
values = values % self.emb.num_embeddings # 取模哈希(与训练一致时用)
else:
values = values.clamp(0, max_idx) # 超出 vocab_size 的 sign id 截断,避免越界
sign_emb = self.emb(values).to(target_dtype)
res = torch.segment_reduce(sign_emb, reduce='sum', offsets=offsets, initial=0)
pooled_embs.append(res)
@@ -496,13 +534,25 @@ def load_model(ckpt_path, device='cuda:0'):
model.load_state_dict(ckpt['model_state_dict'])
print(f"[INFO] Loaded checkpoint from {ckpt_path} (epoch={ckpt.get('epoch', '?')})")
# === FP16 量化:模型参数转半精度,Embedding 保留 FP32 ===
model = model.half()
model.rep_encoder.emb = model.rep_encoder.emb.to(torch.float32)
print("[INFO] Model converted to FP16 (embedding kept in FP32)")
if CONFIG["fp16"]:
model = model.half()
# Embedding 始终保留 FP32(int 索引查表,不受浮点精度影响)
model.rep_encoder.emb = model.rep_encoder.emb.to(torch.float32)
# 额外保留 FP32 的精度敏感模块(输入/输出自动转换)
for name, module in model.named_modules():
if name and any(name.startswith(p) for p in CONFIG["keep_fp32_modules"]):
_force_fp32_io(module)
print(f"[INFO] FP16 on; FP32-kept: "
f"{('rep_encoder.emb',) + tuple(CONFIG['keep_fp32_modules'])}")
else:
model = model.float()
print("[INFO] FP32 reference (no half)")
# === 按 Expert 权重相似度合并冗余 expert ===
_merge_experts(model, sim_threshold=0.90)
if CONFIG["expert_merge"]:
_merge_experts(model, sim_threshold=CONFIG["merge_threshold"])
else:
print("[INFO] expert_merge off")
else:
print(f"[WARNING] Checkpoint {ckpt_path} not found, using random weights")