From 3e1d5b8e5961a19ab13297ac39f7f68c5ed82e13 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=88=98=E8=88=AA=E5=AE=87?= <3364451258@qq.com> Date: Sun, 14 Jun 2026 11:16:04 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20Expert=20=E6=9D=83=E9=87=8D=E7=9B=B8?= =?UTF-8?q?=E4=BC=BC=E5=BA=A6=E5=90=88=E5=B9=B6=EF=BC=88=E4=BD=99=E5=BC=A6?= =?UTF-8?q?=E7=9B=B8=E4=BC=BC=E5=BA=A6>0.97=20=E7=9A=84=20expert=20?= =?UTF-8?q?=E5=90=88=E5=B9=B6=EF=BC=8C=E5=87=8F=E5=B0=91=E5=86=97=E4=BD=99?= =?UTF-8?q?=E8=AE=A1=E7=AE=97=EF=BC=89?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 贪心聚类:并查集按相似度降序合并 - 合并策略:fc1/fc2 权重+bias 取平均,gate 对应行取平均 - k 保护:合并后 expert 数 < k 时自动降 k - 属 Q&A 允许的删除冗余度高操作,不改变层数/维度/head/FFN channel --- 代码/code/infer.py | 114 +++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 114 insertions(+) diff --git a/代码/code/infer.py b/代码/code/infer.py index 9bca191..6df0dd4 100644 --- a/代码/code/infer.py +++ b/代码/code/infer.py @@ -500,6 +500,9 @@ def load_model(ckpt_path, device='cuda:0'): model = model.half() model.rep_encoder.emb = model.rep_encoder.emb.to(torch.float32) print("[INFO] Model converted to FP16 (embedding kept in FP32)") + + # === 按 Expert 权重相似度合并冗余 expert === + _merge_experts(model, sim_threshold=0.97) else: print(f"[WARNING] Checkpoint {ckpt_path} not found, using random weights") @@ -510,6 +513,117 @@ def load_model(ckpt_path, device='cuda:0'): return model, dev +def _merge_experts(model, sim_threshold=0.97): + """按权重余弦相似度合并冗余 MoE expert。 + 合规:仅删除冗余部分,不改层数/维度/head/FFN channel。""" + total_merged = 0 + for layer_idx, moe in enumerate(model.seq_encoder.moe): + num_exp = moe.num_experts + if num_exp <= 1: + continue + + # 1) 计算 8×8 成对相似度矩阵(fc1+fc2 平均余弦相似度) + sim_matrix = torch.zeros(num_exp, num_exp) + for i in range(num_exp): + for j in range(i + 1, num_exp): + w_i = torch.cat([ + moe.experts[i].fc1.weight.data.flatten().float(), + moe.experts[i].fc2.weight.data.flatten().float(), + ]) + w_j = torch.cat([ + moe.experts[j].fc1.weight.data.flatten().float(), + moe.experts[j].fc2.weight.data.flatten().float(), + ]) + sim = F.cosine_similarity(w_i.unsqueeze(0), w_j.unsqueeze(0)).item() + sim_matrix[i, j] = sim + sim_matrix[j, i] = sim + + # 2) 贪心聚类:从最高相似度 pair 开始,> threshold 则合并 + parent = list(range(num_exp)) + + def find(x): + while parent[x] != x: + parent[x] = parent[parent[x]] + x = parent[x] + return x + + def union(a, b): + ra, rb = find(a), find(b) + if ra != rb: + parent[rb] = ra + + # 按相似度降序遍历所有 pair + pairs = [(sim_matrix[i, j].item(), i, j) + for i in range(num_exp) for j in range(i + 1, num_exp)] + pairs.sort(reverse=True) + + for sim_val, i, j in pairs: + if sim_val > sim_threshold: + union(i, j) + + # 3) 分组 + clusters = {} + for i in range(num_exp): + root = find(i) + clusters.setdefault(root, []).append(i) + + # 如果所有 cluster 都只有 1 个 expert,跳过 + if all(len(c) == 1 for c in clusters.values()): + continue + + # 4) 合并每个 cluster + new_experts = [] + new_gate_w = [] + new_gate_b = [] + for root, indices in clusters.items(): + if len(indices) == 1: + idx = indices[0] + new_experts.append(moe.experts[idx]) + new_gate_w.append(moe.gate.w_g.weight.data[idx].clone()) + new_gate_b.append(moe.gate.w_g.bias.data[idx].clone()) + else: + # 平均权重 + avg_fc1_w = sum(moe.experts[k].fc1.weight.data for k in indices) / len(indices) + avg_fc1_b = sum(moe.experts[k].fc1.bias.data for k in indices) / len(indices) + avg_fc2_w = sum(moe.experts[k].fc2.weight.data for k in indices) / len(indices) + avg_fc2_b = sum(moe.experts[k].fc2.bias.data for k in indices) / len(indices) + merged = Expert(moe.experts[0].fc1.in_features, + moe.experts[0].fc1.out_features) + merged.fc1.weight.data = avg_fc1_w + merged.fc1.bias.data = avg_fc1_b + merged.fc2.weight.data = avg_fc2_w + merged.fc2.bias.data = avg_fc2_b + new_experts.append(merged) + # 平均 gate 权重 + avg_gate_w = sum(moe.gate.w_g.weight.data[k].clone() for k in indices) / len(indices) + avg_gate_b = sum(moe.gate.w_g.bias.data[k].clone() for k in indices) / len(indices) + new_gate_w.append(avg_gate_w) + new_gate_b.append(avg_gate_b) + total_merged += len(indices) - 1 + + # 5) 更新 MoE 层 + old_num = moe.num_experts + new_num = len(new_experts) + moe.experts = nn.ModuleList(new_experts) + moe.num_experts = new_num + new_k = min(moe.k, new_num) + moe.k = new_k + moe.gate.k = new_k + # 替换 gate weight 和 bias + moe.gate.w_g = nn.Linear(moe.gate.w_g.in_features, new_num).to( + moe.gate.w_g.weight.device) + moe.gate.w_g.weight.data = torch.stack(new_gate_w) + moe.gate.w_g.bias.data = torch.stack(new_gate_b) + moe.gate.num_experts = new_num + print(f" Layer {layer_idx}: {old_num} → {new_num} experts " + f"(merged {old_num - new_num}, k={moe.k})") + + if total_merged > 0: + print(f"[INFO] Total merged experts: {total_merged}") + else: + print("[INFO] No experts merged (all below similarity threshold)") + + # ============================================================ # 打分工具(与 evaluation.py 保持一致) # ============================================================