From 9128b60e9d57c9f07671214d86cd7cb38f1667f7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=88=98=E8=88=AA=E5=AE=87?= <3364451258@qq.com> Date: Sat, 13 Jun 2026 12:36:25 +0800 Subject: [PATCH] =?UTF-8?q?perf:=20RepEncoder=20=E6=89=B9=E9=87=8F=20embed?= =?UTF-8?q?ding=20=E6=9F=A5=E8=A1=A8=EF=BC=8828=20=E6=AC=A1=20kernel=20lau?= =?UTF-8?q?nch=20=E2=86=92=201=20=E6=AC=A1=EF=BC=89?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 所有 slot 的 sign id 合并为一次 embedding lookup,再按 slot 拆分做 segment_reduce。 数学等价,纯 GPU 算子优化。 --- 代码/code/infer.py | 25 +++++++++++++++++++++---- 1 file changed, 21 insertions(+), 4 deletions(-) diff --git a/代码/code/infer.py b/代码/code/infer.py index d2109c5..dc51204 100644 --- a/代码/code/infer.py +++ b/代码/code/infer.py @@ -257,16 +257,33 @@ class RepEncoder(nn.Module): self.linear = nn.Linear(in_features=slot_num * emb_dim, out_features=d_model) def forward(self, batch): - pooled_embs = [] max_idx = self.emb.num_embeddings - 1 target_dtype = self.input_norm.weight.dtype # 后续层 dtype(FP16 时为 torch.float16) + + # 批量收集所有 slot 的 values,一次 embedding 查表(减少 28 → 1 次 kernel launch) + all_values = [] + all_offsets = [] + slot_boundaries = [0] # 记录每个 slot 在 all_values 中的起止位置 for i in range(self.slot_num): values, offsets = batch[i + 1] offsets = offsets.to(values.device) - values = values.clamp(0, max_idx) # 超出 vocab_size 的 sign id 截断,避免越界 - sign_emb = self.emb(values).to(target_dtype) - res = torch.segment_reduce(sign_emb, reduce='sum', offsets=offsets, initial=0) + values = values.clamp(0, max_idx) + all_values.append(values) + all_offsets.append(offsets) + slot_boundaries.append(slot_boundaries[-1] + values.size(0)) + + # 一次批量 embedding 查表 + values_cat = torch.cat(all_values) + embs_cat = self.emb(values_cat).to(target_dtype) + + # 按 slot 拆分并 segment_reduce + pooled_embs = [] + for i in range(self.slot_num): + start, end = slot_boundaries[i], slot_boundaries[i + 1] + slot_embs = embs_cat[start:end] + res = torch.segment_reduce(slot_embs, reduce='sum', offsets=all_offsets[i], initial=0) pooled_embs.append(res) + fused_embs = torch.cat(pooled_embs, dim=1) norm_emb = self.input_norm(fused_embs) rep_emb = self.linear(norm_emb)