From 7e0876c671ae21d4eded5e381ce7a150a8279ead Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=88=98=E8=88=AA=E5=AE=87?= <3364451258@qq.com> Date: Sat, 13 Jun 2026 13:05:14 +0800 Subject: [PATCH] =?UTF-8?q?revert:=20RepEncoder=20=E6=89=B9=E9=87=8F=20emb?= =?UTF-8?q?edding=20=E6=9F=A5=E8=A1=A8=EF=BC=8894.3s=20vs=2092.5s=EF=BC=8C?= =?UTF-8?q?=E7=95=A5=E6=85=A2=EF=BC=89?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 回退到稳定版:FP16 + Flash Attention + inference_mode(57.45 分) --- 代码/code/infer.py | 25 ++++--------------------- 1 file changed, 4 insertions(+), 21 deletions(-) diff --git a/代码/code/infer.py b/代码/code/infer.py index dc51204..d2109c5 100644 --- a/代码/code/infer.py +++ b/代码/code/infer.py @@ -257,33 +257,16 @@ class RepEncoder(nn.Module): self.linear = nn.Linear(in_features=slot_num * emb_dim, out_features=d_model) def forward(self, batch): + pooled_embs = [] max_idx = self.emb.num_embeddings - 1 target_dtype = self.input_norm.weight.dtype # 后续层 dtype(FP16 时为 torch.float16) - - # 批量收集所有 slot 的 values,一次 embedding 查表(减少 28 → 1 次 kernel launch) - all_values = [] - all_offsets = [] - slot_boundaries = [0] # 记录每个 slot 在 all_values 中的起止位置 for i in range(self.slot_num): values, offsets = batch[i + 1] offsets = offsets.to(values.device) - values = values.clamp(0, max_idx) - all_values.append(values) - all_offsets.append(offsets) - slot_boundaries.append(slot_boundaries[-1] + values.size(0)) - - # 一次批量 embedding 查表 - values_cat = torch.cat(all_values) - embs_cat = self.emb(values_cat).to(target_dtype) - - # 按 slot 拆分并 segment_reduce - pooled_embs = [] - for i in range(self.slot_num): - start, end = slot_boundaries[i], slot_boundaries[i + 1] - slot_embs = embs_cat[start:end] - res = torch.segment_reduce(slot_embs, reduce='sum', offsets=all_offsets[i], initial=0) + values = values.clamp(0, max_idx) # 超出 vocab_size 的 sign id 截断,避免越界 + sign_emb = self.emb(values).to(target_dtype) + res = torch.segment_reduce(sign_emb, reduce='sum', offsets=offsets, initial=0) pooled_embs.append(res) - fused_embs = torch.cat(pooled_embs, dim=1) norm_emb = self.input_norm(fused_embs) rep_emb = self.linear(norm_emb)