From 788ca96d503fe4adfdf4218f8fc528b44639e33b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=88=98=E8=88=AA=E5=AE=87?= <3364451258@qq.com> Date: Sat, 13 Jun 2026 14:05:19 +0800 Subject: [PATCH] =?UTF-8?q?revert:=20=E7=A7=BB=E9=99=A4=20INT8=20=E9=87=8F?= =?UTF-8?q?=E5=8C=96=E5=92=8C=20k=3D1=20=E8=A1=A5=E5=81=BF=EF=BC=8C?= =?UTF-8?q?=E5=9B=9E=E5=88=B0=E7=A8=B3=E5=AE=9A=E7=89=88=2058.49?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- 代码/code/infer.py | 51 ++++++++++++++++------------------------------ 1 file changed, 17 insertions(+), 34 deletions(-) diff --git a/代码/code/infer.py b/代码/code/infer.py index 3b13cd5..9bca191 100644 --- a/代码/code/infer.py +++ b/代码/code/infer.py @@ -320,7 +320,7 @@ class TopKGate(nn.Module): return topk_idx, topk_score, probs class SMoE(nn.Module): - def __init__(self, d_model, dim_ff, num_experts, k=1): + def __init__(self, d_model, dim_ff, num_experts, k=2): super().__init__() self.num_experts = num_experts self.k = k @@ -329,34 +329,34 @@ class SMoE(nn.Module): Expert(d_model, dim_ff) for _ in range(num_experts) ]) - self.gate = TopKGate(d_model, num_experts, k=2) # gate 内部用 k=2 获取补偿权重 + self.gate = TopKGate(d_model, num_experts, k=k) def forward(self, x): # x: [B,S,D] B, S, D = x.shape topk_idx, topk_score, probs = self.gate(x) - # topk_idx: [B, S, 2], topk_score: [B, S, 2] - - # 仅路由到 Top-1 expert,但用 (p1+p2) 作为权重补偿 - route_idx = topk_idx[:, :, :1] # [B, S, 1] — 只取 top-1 - weight_sum = topk_score.sum(dim=-1) # [B, S] — p1 + p2 作为总权重 out = torch.zeros_like(x) - x_flat = x.reshape(-1, D) - idx_flat = route_idx.reshape(-1) # [B*S] - weight_flat = weight_sum.reshape(-1) # [B*S] - out_flat = out.reshape(-1, D) + # flatten + x_flat = x.reshape(-1, D) # [B*S, D] + idx_flat = topk_idx.reshape(-1, self.k) # [B*S, k] + score_flat = topk_score.reshape(-1, self.k) + out_flat = out.reshape(-1, D) # 提前 reshape,避免循环内重复 for i in range(self.num_experts): - mask = (idx_flat == i) # [B*S] - token_idx = mask.nonzero(as_tuple=True)[0] + # 找到被路由到 expert i 的 token + mask = (idx_flat == i) # [B*S, k] + + token_idx, k_idx = mask.nonzero(as_tuple=True) if token_idx.numel() == 0: continue - selected_x = x_flat[token_idx] - expert_out = self.experts[i](selected_x) - out_flat[token_idx] = expert_out * weight_flat[token_idx].unsqueeze(-1) + + selected_x = x_flat[token_idx] # [N, D] + expert_out = self.experts[i](selected_x) # [N, D] + weight = score_flat[token_idx, k_idx].unsqueeze(-1) + out_flat[token_idx] += expert_out * weight importance = probs.sum(dim=(0,1)) # [E] moe_loss = (importance.std() / (importance.mean() + 1e-6)) @@ -383,7 +383,7 @@ class TransformerEncoder(nn.Module): self.act = getattr(F, act) self.attention_fn = attention_fn self.moe = nn.ModuleList([ - SMoE(d_model, dim_ff, num_experts=8, k=1) # Top-1 路由 + (p1+p2) 权重补偿 + SMoE(d_model, dim_ff, num_experts=8, k=2) for _ in range(num_layers) ]) @@ -500,23 +500,6 @@ def load_model(ckpt_path, device='cuda:0'): model = model.half() model.rep_encoder.emb = model.rep_encoder.emb.to(torch.float32) print("[INFO] Model converted to FP16 (embedding kept in FP32)") - - # === INT8 动态量化:所有 Linear 层权重 INT8,matmul 2x 加速 === - try: - from torch.ao.quantization import quantize_dynamic - # 排除 embedding 层,仅量化 Linear - model.seq_encoder = quantize_dynamic( - model.seq_encoder, {nn.Linear}, dtype=torch.qint8 - ) - model.linear = quantize_dynamic( - model.linear, {nn.Linear}, dtype=torch.qint8 - ) - model.rep_encoder.linear = quantize_dynamic( - model.rep_encoder.linear, {nn.Linear}, dtype=torch.qint8 - ) - print("[INFO] INT8 dynamic quantization applied to Linear layers") - except Exception as e: - print(f"[WARNING] INT8 quantization failed ({e}), keeping FP16") else: print(f"[WARNING] Checkpoint {ckpt_path} not found, using random weights")