fix: RepEncoder forward 中 Embedding FP32 输出显式转为后续层 dtype
修复 FP16 量化后 dtype 不匹配:Embedding 保留 FP32 时,forward 输出需 .to(target_dtype) 对齐后续 LayerNorm/Linear
This commit is contained in:
+2
-1
@@ -259,11 +259,12 @@ class RepEncoder(nn.Module):
|
||||
def forward(self, batch):
|
||||
pooled_embs = []
|
||||
max_idx = self.emb.num_embeddings - 1
|
||||
+ target_dtype = self.input_norm.weight.dtype # 后续层 dtype(FP16 时为 torch.float16)
|
||||
for i in range(self.slot_num):
|
||||
values, offsets = batch[i + 1]
|
||||
offsets = offsets.to(values.device)
|
||||
values = values.clamp(0, max_idx) # 超出 vocab_size 的 sign id 截断,避免越界
|
||||
- sign_emb = self.emb(values)
|
||||
+ sign_emb = self.emb(values).to(target_dtype)
|
||||
res = torch.segment_reduce(sign_emb, reduce='sum', offsets=offsets, initial=0)
|
||||
pooled_embs.append(res)
|
||||
fused_embs = torch.cat(pooled_embs, dim=1)
|
||||
|
||||
Reference in New Issue
Block a user