From 47545efd43eb22e31fd2e441cd3241e6f4ae762d Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E5=88=98=E8=88=AA=E5=AE=87?= <3364451258@qq.com>
Date: Fri, 12 Jun 2026 21:22:06 +0800
Subject: [PATCH] =?UTF-8?q?fix:=20RepEncoder=20forward=20=E4=B8=AD=20Embed?=
 =?UTF-8?q?ding=20FP32=20=E8=BE=93=E5=87=BA=E6=98=BE=E5=BC=8F=E8=BD=AC?=
 =?UTF-8?q?=E4=B8=BA=E5=90=8E=E7=BB=AD=E5=B1=82=20dtype?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

修复 FP16 量化后 dtype 不匹配:Embedding 保留 FP32 时,forward 输出需 .to(target_dtype) 对齐后续 LayerNorm/Linear
---
 代码/code/infer.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/代码/code/infer.py b/代码/code/infer.py
index 543b27b..74ba35e 100644
--- a/代码/code/infer.py
+++ b/代码/code/infer.py
@@ -259,11 +259,12 @@ class RepEncoder(nn.Module):
     def forward(self, batch):
         pooled_embs = []
         max_idx = self.emb.num_embeddings - 1
+        target_dtype = self.input_norm.weight.dtype  # 后续层 dtype(FP16 时为 torch.float16)
         for i in range(self.slot_num):
             values, offsets = batch[i + 1]
             offsets = offsets.to(values.device)
             values = values.clamp(0, max_idx)  # 超出 vocab_size 的 sign id 截断,避免越界
-            sign_emb = self.emb(values)
+            sign_emb = self.emb(values).to(target_dtype)
             res = torch.segment_reduce(sign_emb, reduce='sum', offsets=offsets, initial=0)
             pooled_embs.append(res)
         fused_embs = torch.cat(pooled_embs, dim=1)