feat: Expert 权重相似度合并(余弦相似度>0.97 的 expert 合并,减少冗余计算)
- 贪心聚类:并查集按相似度降序合并 - 合并策略:fc1/fc2 权重+bias 取平均,gate 对应行取平均 - k 保护:合并后 expert 数 < k 时自动降 k - 属 Q&A 允许的删除冗余度高操作,不改变层数/维度/head/FFN channel
This commit is contained in:
@@ -500,6 +500,9 @@ def load_model(ckpt_path, device='cuda:0'):
|
||||
model = model.half()
|
||||
model.rep_encoder.emb = model.rep_encoder.emb.to(torch.float32)
|
||||
print("[INFO] Model converted to FP16 (embedding kept in FP32)")
|
||||
|
||||
# === 按 Expert 权重相似度合并冗余 expert ===
|
||||
_merge_experts(model, sim_threshold=0.97)
|
||||
else:
|
||||
print(f"[WARNING] Checkpoint {ckpt_path} not found, using random weights")
|
||||
|
||||
@@ -510,6 +513,117 @@ def load_model(ckpt_path, device='cuda:0'):
|
||||
return model, dev
|
||||
|
||||
|
||||
def _merge_experts(model, sim_threshold=0.97):
|
||||
"""按权重余弦相似度合并冗余 MoE expert。
|
||||
合规:仅删除冗余部分,不改层数/维度/head/FFN channel。"""
|
||||
total_merged = 0
|
||||
for layer_idx, moe in enumerate(model.seq_encoder.moe):
|
||||
num_exp = moe.num_experts
|
||||
if num_exp <= 1:
|
||||
continue
|
||||
|
||||
# 1) 计算 8×8 成对相似度矩阵(fc1+fc2 平均余弦相似度)
|
||||
sim_matrix = torch.zeros(num_exp, num_exp)
|
||||
for i in range(num_exp):
|
||||
for j in range(i + 1, num_exp):
|
||||
w_i = torch.cat([
|
||||
moe.experts[i].fc1.weight.data.flatten().float(),
|
||||
moe.experts[i].fc2.weight.data.flatten().float(),
|
||||
])
|
||||
w_j = torch.cat([
|
||||
moe.experts[j].fc1.weight.data.flatten().float(),
|
||||
moe.experts[j].fc2.weight.data.flatten().float(),
|
||||
])
|
||||
sim = F.cosine_similarity(w_i.unsqueeze(0), w_j.unsqueeze(0)).item()
|
||||
sim_matrix[i, j] = sim
|
||||
sim_matrix[j, i] = sim
|
||||
|
||||
# 2) 贪心聚类:从最高相似度 pair 开始,> threshold 则合并
|
||||
parent = list(range(num_exp))
|
||||
|
||||
def find(x):
|
||||
while parent[x] != x:
|
||||
parent[x] = parent[parent[x]]
|
||||
x = parent[x]
|
||||
return x
|
||||
|
||||
def union(a, b):
|
||||
ra, rb = find(a), find(b)
|
||||
if ra != rb:
|
||||
parent[rb] = ra
|
||||
|
||||
# 按相似度降序遍历所有 pair
|
||||
pairs = [(sim_matrix[i, j].item(), i, j)
|
||||
for i in range(num_exp) for j in range(i + 1, num_exp)]
|
||||
pairs.sort(reverse=True)
|
||||
|
||||
for sim_val, i, j in pairs:
|
||||
if sim_val > sim_threshold:
|
||||
union(i, j)
|
||||
|
||||
# 3) 分组
|
||||
clusters = {}
|
||||
for i in range(num_exp):
|
||||
root = find(i)
|
||||
clusters.setdefault(root, []).append(i)
|
||||
|
||||
# 如果所有 cluster 都只有 1 个 expert,跳过
|
||||
if all(len(c) == 1 for c in clusters.values()):
|
||||
continue
|
||||
|
||||
# 4) 合并每个 cluster
|
||||
new_experts = []
|
||||
new_gate_w = []
|
||||
new_gate_b = []
|
||||
for root, indices in clusters.items():
|
||||
if len(indices) == 1:
|
||||
idx = indices[0]
|
||||
new_experts.append(moe.experts[idx])
|
||||
new_gate_w.append(moe.gate.w_g.weight.data[idx].clone())
|
||||
new_gate_b.append(moe.gate.w_g.bias.data[idx].clone())
|
||||
else:
|
||||
# 平均权重
|
||||
avg_fc1_w = sum(moe.experts[k].fc1.weight.data for k in indices) / len(indices)
|
||||
avg_fc1_b = sum(moe.experts[k].fc1.bias.data for k in indices) / len(indices)
|
||||
avg_fc2_w = sum(moe.experts[k].fc2.weight.data for k in indices) / len(indices)
|
||||
avg_fc2_b = sum(moe.experts[k].fc2.bias.data for k in indices) / len(indices)
|
||||
merged = Expert(moe.experts[0].fc1.in_features,
|
||||
moe.experts[0].fc1.out_features)
|
||||
merged.fc1.weight.data = avg_fc1_w
|
||||
merged.fc1.bias.data = avg_fc1_b
|
||||
merged.fc2.weight.data = avg_fc2_w
|
||||
merged.fc2.bias.data = avg_fc2_b
|
||||
new_experts.append(merged)
|
||||
# 平均 gate 权重
|
||||
avg_gate_w = sum(moe.gate.w_g.weight.data[k].clone() for k in indices) / len(indices)
|
||||
avg_gate_b = sum(moe.gate.w_g.bias.data[k].clone() for k in indices) / len(indices)
|
||||
new_gate_w.append(avg_gate_w)
|
||||
new_gate_b.append(avg_gate_b)
|
||||
total_merged += len(indices) - 1
|
||||
|
||||
# 5) 更新 MoE 层
|
||||
old_num = moe.num_experts
|
||||
new_num = len(new_experts)
|
||||
moe.experts = nn.ModuleList(new_experts)
|
||||
moe.num_experts = new_num
|
||||
new_k = min(moe.k, new_num)
|
||||
moe.k = new_k
|
||||
moe.gate.k = new_k
|
||||
# 替换 gate weight 和 bias
|
||||
moe.gate.w_g = nn.Linear(moe.gate.w_g.in_features, new_num).to(
|
||||
moe.gate.w_g.weight.device)
|
||||
moe.gate.w_g.weight.data = torch.stack(new_gate_w)
|
||||
moe.gate.w_g.bias.data = torch.stack(new_gate_b)
|
||||
moe.gate.num_experts = new_num
|
||||
print(f" Layer {layer_idx}: {old_num} → {new_num} experts "
|
||||
f"(merged {old_num - new_num}, k={moe.k})")
|
||||
|
||||
if total_merged > 0:
|
||||
print(f"[INFO] Total merged experts: {total_merged}")
|
||||
else:
|
||||
print("[INFO] No experts merged (all below similarity threshold)")
|
||||
|
||||
|
||||
# ============================================================
|
||||
# 打分工具(与 evaluation.py 保持一致)
|
||||
# ============================================================
|
||||
|
||||
Reference in New Issue
Block a user