def _mean_tensor(tensors):
    """Return the element-wise mean of same-shaped tensors.

    The mean is accumulated in at least FP32 (FP16/BF16 inputs are upcast)
    and cast back to the dtype of the first input, so averaging half-precision
    weights does not lose precision in the accumulation.
    """
    tensors = list(tensors)
    out_dtype = tensors[0].dtype
    acc_dtype = torch.promote_types(out_dtype, torch.float32)
    stacked = torch.stack([t.detach().to(acc_dtype) for t in tensors])
    return stacked.mean(dim=0).to(out_dtype)


def _cluster_similar_experts(experts, num_exp, sim_threshold):
    """Group expert indices whose fc1+fc2 weights are nearly parallel.

    Two experts are linked when the cosine similarity of their concatenated,
    flattened ``fc1.weight`` and ``fc2.weight`` (computed in FP32) is strictly
    greater than ``sim_threshold``. Links are transitive (union-find), so a
    cluster may contain two experts that are only connected through a third.

    Returns a list of index lists; clusters are ordered by their smallest
    member and each list is ascending.
    """
    # Flatten every expert once instead of once per pair.
    flat = [
        torch.cat([
            experts[i].fc1.weight.detach().flatten().float(),
            experts[i].fc2.weight.detach().flatten().float(),
        ])
        for i in range(num_exp)
    ]

    parent = list(range(num_exp))

    def find(x):
        # Path-halving lookup of the cluster representative.
        while parent[x] != x:
            parent[x] = parent[parent[x]]
            x = parent[x]
        return x

    # With a fixed threshold the resulting partition does not depend on the
    # order in which pairs are visited, so no sorting by similarity is needed.
    for i in range(num_exp):
        for j in range(i + 1, num_exp):
            sim = F.cosine_similarity(
                flat[i].unsqueeze(0), flat[j].unsqueeze(0)
            ).item()
            if sim > sim_threshold:
                root_i, root_j = find(i), find(j)
                if root_i != root_j:
                    parent[root_j] = root_i

    clusters = {}
    for i in range(num_exp):
        clusters.setdefault(find(i), []).append(i)
    return list(clusters.values())


def _average_experts(members):
    """Build one expert whose fc1/fc2 parameters are the mean of ``members``.

    The result is a deep copy of the first member, so it keeps that member's
    class, dtype, device, train/eval mode and any attributes other than
    fc1/fc2; only the fc1/fc2 weights (and biases, when present) are replaced.
    """
    import copy  # local import: keeps this block independent of file-level imports

    merged = copy.deepcopy(members[0])
    with torch.no_grad():
        for layer_name in ("fc1", "fc2"):
            target = getattr(merged, layer_name)
            sources = [getattr(m, layer_name) for m in members]
            target.weight.copy_(_mean_tensor(s.weight for s in sources))
            if target.bias is not None:
                target.bias.copy_(_mean_tensor(s.bias for s in sources))
    return merged


def _merge_experts(model, sim_threshold=0.97):
    """Merge redundant MoE experts by cosine similarity of their weights.

    For every MoE layer in ``model.seq_encoder.moe``, experts whose
    concatenated fc1/fc2 weights have cosine similarity above
    ``sim_threshold`` are clustered, and each cluster is replaced by a single
    expert holding the averaged fc1/fc2 parameters. The gate projection
    ``moe.gate.w_g`` is rebuilt with one output row per remaining expert
    (rows of merged experts are averaged), and ``num_experts`` / ``k`` are
    updated on both the MoE layer and its gate (``k`` is clamped to the new
    expert count).

    Compliance note from the original author: only redundant experts are
    removed; the number of layers, dimensions, heads and FFN channels are
    left unchanged.

    In this file the function is invoked from ``load_model`` right after the
    model has been converted to FP16, with ``sim_threshold=0.97``.

    Args:
        model: object exposing ``seq_encoder.moe``, an iterable of MoE layers.
            Each layer must provide ``experts`` (indexable modules with
            ``fc1``/``fc2`` linear layers), ``num_experts``, ``k`` and
            ``gate.w_g`` (an ``nn.Linear`` with one output per expert).
        sim_threshold: experts are linked only when their similarity is
            strictly greater than this value.

    Returns:
        The total number of experts removed across all layers.
    """
    # NOTE(review): only gate.w_g is resized. If the gate owns other
    # per-expert parameters they are not touched here - confirm against the
    # gate implementation.
    total_merged = 0
    with torch.no_grad():
        for layer_idx, moe in enumerate(model.seq_encoder.moe):
            num_exp = moe.num_experts
            if num_exp <= 1:
                continue

            clusters = _cluster_similar_experts(
                moe.experts, num_exp, sim_threshold
            )
            # Every cluster is a singleton: nothing to merge in this layer.
            if len(clusters) == num_exp:
                continue

            old_gate = moe.gate.w_g
            has_bias = old_gate.bias is not None

            new_experts = []
            gate_w_rows = []
            gate_b_rows = []
            for indices in clusters:
                if len(indices) == 1:
                    # Keep the untouched expert object itself.
                    new_experts.append(moe.experts[indices[0]])
                else:
                    new_experts.append(
                        _average_experts([moe.experts[k] for k in indices])
                    )
                    total_merged += len(indices) - 1
                # Averaging a single row returns that row unchanged.
                gate_w_rows.append(
                    _mean_tensor(old_gate.weight[k] for k in indices)
                )
                if has_bias:
                    gate_b_rows.append(
                        _mean_tensor(old_gate.bias[k] for k in indices)
                    )

            # Update the MoE layer bookkeeping.
            old_num = moe.num_experts
            new_num = len(new_experts)
            moe.experts = nn.ModuleList(new_experts)
            moe.num_experts = new_num
            new_k = min(moe.k, new_num)
            moe.k = new_k
            moe.gate.k = new_k

            # Rebuild the gate projection on the same device/dtype as before.
            new_gate = nn.Linear(
                old_gate.in_features, new_num, bias=has_bias
            ).to(device=old_gate.weight.device, dtype=old_gate.weight.dtype)
            new_gate.weight.copy_(torch.stack(gate_w_rows))
            if has_bias:
                new_gate.bias.copy_(torch.stack(gate_b_rows))
            moe.gate.w_g = new_gate
            moe.gate.num_experts = new_num

            print(f" Layer {layer_idx}: {old_num} → {new_num} experts "
                  f"(merged {old_num - new_num}, k={moe.k})")

    if total_merged > 0:
        print(f"[INFO] Total merged experts: {total_merged}")
    else:
        print("[INFO] No experts merged (all below similarity threshold)")
    return total_merged