diff --git a/代码/code/infer.py b/代码/code/infer.py index 543b27b..74ba35e 100644 --- a/代码/code/infer.py +++ b/代码/code/infer.py @@ -259,11 +259,12 @@ class RepEncoder(nn.Module): def forward(self, batch): pooled_embs = [] max_idx = self.emb.num_embeddings - 1 + target_dtype = self.input_norm.weight.dtype # 后续层 dtype(FP16 时为 torch.float16) for i in range(self.slot_num): values, offsets = batch[i + 1] offsets = offsets.to(values.device) values = values.clamp(0, max_idx) # 超出 vocab_size 的 sign id 截断,避免越界 - sign_emb = self.emb(values) + sign_emb = self.emb(values).to(target_dtype) res = torch.segment_reduce(sign_emb, reduce='sum', offsets=offsets, initial=0) pooled_embs.append(res) fused_embs = torch.cat(pooled_embs, dim=1)