feat: Expert 权重相似度合并(余弦相似度>0.97 的 expert 合并,减少冗余计算)
- 贪心聚类:并查集按相似度降序合并 - 合并策略:fc1/fc2 权重+bias 取平均,gate 对应行取平均 - k 保护:合并后 expert 数 < k 时自动降 k - 属 Q&A 允许的删除冗余度高操作,不改变层数/维度/head/FFN channel
This commit is contained in:
@@ -500,6 +500,9 @@ def load_model(ckpt_path, device='cuda:0'):
|
|||||||
model = model.half()
|
model = model.half()
|
||||||
model.rep_encoder.emb = model.rep_encoder.emb.to(torch.float32)
|
model.rep_encoder.emb = model.rep_encoder.emb.to(torch.float32)
|
||||||
print("[INFO] Model converted to FP16 (embedding kept in FP32)")
|
print("[INFO] Model converted to FP16 (embedding kept in FP32)")
|
||||||
|
|
||||||
|
# === 按 Expert 权重相似度合并冗余 expert ===
|
||||||
|
_merge_experts(model, sim_threshold=0.97)
|
||||||
else:
|
else:
|
||||||
print(f"[WARNING] Checkpoint {ckpt_path} not found, using random weights")
|
print(f"[WARNING] Checkpoint {ckpt_path} not found, using random weights")
|
||||||
|
|
||||||
@@ -510,6 +513,117 @@ def load_model(ckpt_path, device='cuda:0'):
|
|||||||
return model, dev
|
return model, dev
|
||||||
|
|
||||||
|
|
||||||
|
def _merge_experts(model, sim_threshold=0.97):
|
||||||
|
"""按权重余弦相似度合并冗余 MoE expert。
|
||||||
|
合规:仅删除冗余部分,不改层数/维度/head/FFN channel。"""
|
||||||
|
total_merged = 0
|
||||||
|
for layer_idx, moe in enumerate(model.seq_encoder.moe):
|
||||||
|
num_exp = moe.num_experts
|
||||||
|
if num_exp <= 1:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 1) 计算 8×8 成对相似度矩阵(fc1+fc2 平均余弦相似度)
|
||||||
|
sim_matrix = torch.zeros(num_exp, num_exp)
|
||||||
|
for i in range(num_exp):
|
||||||
|
for j in range(i + 1, num_exp):
|
||||||
|
w_i = torch.cat([
|
||||||
|
moe.experts[i].fc1.weight.data.flatten().float(),
|
||||||
|
moe.experts[i].fc2.weight.data.flatten().float(),
|
||||||
|
])
|
||||||
|
w_j = torch.cat([
|
||||||
|
moe.experts[j].fc1.weight.data.flatten().float(),
|
||||||
|
moe.experts[j].fc2.weight.data.flatten().float(),
|
||||||
|
])
|
||||||
|
sim = F.cosine_similarity(w_i.unsqueeze(0), w_j.unsqueeze(0)).item()
|
||||||
|
sim_matrix[i, j] = sim
|
||||||
|
sim_matrix[j, i] = sim
|
||||||
|
|
||||||
|
# 2) 贪心聚类:从最高相似度 pair 开始,> threshold 则合并
|
||||||
|
parent = list(range(num_exp))
|
||||||
|
|
||||||
|
def find(x):
|
||||||
|
while parent[x] != x:
|
||||||
|
parent[x] = parent[parent[x]]
|
||||||
|
x = parent[x]
|
||||||
|
return x
|
||||||
|
|
||||||
|
def union(a, b):
|
||||||
|
ra, rb = find(a), find(b)
|
||||||
|
if ra != rb:
|
||||||
|
parent[rb] = ra
|
||||||
|
|
||||||
|
# 按相似度降序遍历所有 pair
|
||||||
|
pairs = [(sim_matrix[i, j].item(), i, j)
|
||||||
|
for i in range(num_exp) for j in range(i + 1, num_exp)]
|
||||||
|
pairs.sort(reverse=True)
|
||||||
|
|
||||||
|
for sim_val, i, j in pairs:
|
||||||
|
if sim_val > sim_threshold:
|
||||||
|
union(i, j)
|
||||||
|
|
||||||
|
# 3) 分组
|
||||||
|
clusters = {}
|
||||||
|
for i in range(num_exp):
|
||||||
|
root = find(i)
|
||||||
|
clusters.setdefault(root, []).append(i)
|
||||||
|
|
||||||
|
# 如果所有 cluster 都只有 1 个 expert,跳过
|
||||||
|
if all(len(c) == 1 for c in clusters.values()):
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 4) 合并每个 cluster
|
||||||
|
new_experts = []
|
||||||
|
new_gate_w = []
|
||||||
|
new_gate_b = []
|
||||||
|
for root, indices in clusters.items():
|
||||||
|
if len(indices) == 1:
|
||||||
|
idx = indices[0]
|
||||||
|
new_experts.append(moe.experts[idx])
|
||||||
|
new_gate_w.append(moe.gate.w_g.weight.data[idx].clone())
|
||||||
|
new_gate_b.append(moe.gate.w_g.bias.data[idx].clone())
|
||||||
|
else:
|
||||||
|
# 平均权重
|
||||||
|
avg_fc1_w = sum(moe.experts[k].fc1.weight.data for k in indices) / len(indices)
|
||||||
|
avg_fc1_b = sum(moe.experts[k].fc1.bias.data for k in indices) / len(indices)
|
||||||
|
avg_fc2_w = sum(moe.experts[k].fc2.weight.data for k in indices) / len(indices)
|
||||||
|
avg_fc2_b = sum(moe.experts[k].fc2.bias.data for k in indices) / len(indices)
|
||||||
|
merged = Expert(moe.experts[0].fc1.in_features,
|
||||||
|
moe.experts[0].fc1.out_features)
|
||||||
|
merged.fc1.weight.data = avg_fc1_w
|
||||||
|
merged.fc1.bias.data = avg_fc1_b
|
||||||
|
merged.fc2.weight.data = avg_fc2_w
|
||||||
|
merged.fc2.bias.data = avg_fc2_b
|
||||||
|
new_experts.append(merged)
|
||||||
|
# 平均 gate 权重
|
||||||
|
avg_gate_w = sum(moe.gate.w_g.weight.data[k].clone() for k in indices) / len(indices)
|
||||||
|
avg_gate_b = sum(moe.gate.w_g.bias.data[k].clone() for k in indices) / len(indices)
|
||||||
|
new_gate_w.append(avg_gate_w)
|
||||||
|
new_gate_b.append(avg_gate_b)
|
||||||
|
total_merged += len(indices) - 1
|
||||||
|
|
||||||
|
# 5) 更新 MoE 层
|
||||||
|
old_num = moe.num_experts
|
||||||
|
new_num = len(new_experts)
|
||||||
|
moe.experts = nn.ModuleList(new_experts)
|
||||||
|
moe.num_experts = new_num
|
||||||
|
new_k = min(moe.k, new_num)
|
||||||
|
moe.k = new_k
|
||||||
|
moe.gate.k = new_k
|
||||||
|
# 替换 gate weight 和 bias
|
||||||
|
moe.gate.w_g = nn.Linear(moe.gate.w_g.in_features, new_num).to(
|
||||||
|
moe.gate.w_g.weight.device)
|
||||||
|
moe.gate.w_g.weight.data = torch.stack(new_gate_w)
|
||||||
|
moe.gate.w_g.bias.data = torch.stack(new_gate_b)
|
||||||
|
moe.gate.num_experts = new_num
|
||||||
|
print(f" Layer {layer_idx}: {old_num} → {new_num} experts "
|
||||||
|
f"(merged {old_num - new_num}, k={moe.k})")
|
||||||
|
|
||||||
|
if total_merged > 0:
|
||||||
|
print(f"[INFO] Total merged experts: {total_merged}")
|
||||||
|
else:
|
||||||
|
print("[INFO] No experts merged (all below similarity threshold)")
|
||||||
|
|
||||||
|
|
||||||
# ============================================================
|
# ============================================================
|
||||||
# 打分工具(与 evaluation.py 保持一致)
|
# 打分工具(与 evaluation.py 保持一致)
|
||||||
# ============================================================
|
# ============================================================
|
||||||
|
|||||||
Reference in New Issue
Block a user