revert: RepEncoder 批量 embedding 查表(批量版 94.3s vs 逐 slot 版 92.5s,略慢)
回退到稳定版:FP16 + Flash Attention + inference_mode(57.45 分)
This commit is contained in:
+4
-21
@@ -257,33 +257,16 @@ class RepEncoder(nn.Module):
|
|||||||
self.linear = nn.Linear(in_features=slot_num * emb_dim, out_features=d_model)
|
self.linear = nn.Linear(in_features=slot_num * emb_dim, out_features=d_model)
|
||||||
|
|
||||||
def forward(self, batch):
|
def forward(self, batch):
|
||||||
|
pooled_embs = []
|
||||||
max_idx = self.emb.num_embeddings - 1
|
max_idx = self.emb.num_embeddings - 1
|
||||||
target_dtype = self.input_norm.weight.dtype # 后续层 dtype(FP16 时为 torch.float16)
|
target_dtype = self.input_norm.weight.dtype # 后续层 dtype(FP16 时为 torch.float16)
|
||||||
|
|
||||||
# 批量收集所有 slot 的 values,一次 embedding 查表(减少 28 → 1 次 kernel launch)
|
|
||||||
all_values = []
|
|
||||||
all_offsets = []
|
|
||||||
slot_boundaries = [0] # 记录每个 slot 在 all_values 中的起止位置
|
|
||||||
for i in range(self.slot_num):
|
for i in range(self.slot_num):
|
||||||
values, offsets = batch[i + 1]
|
values, offsets = batch[i + 1]
|
||||||
offsets = offsets.to(values.device)
|
offsets = offsets.to(values.device)
|
||||||
values = values.clamp(0, max_idx)
|
values = values.clamp(0, max_idx) # 超出 vocab_size 的 sign id 截断,避免越界
|
||||||
all_values.append(values)
|
sign_emb = self.emb(values).to(target_dtype)
|
||||||
all_offsets.append(offsets)
|
res = torch.segment_reduce(sign_emb, reduce='sum', offsets=offsets, initial=0)
|
||||||
slot_boundaries.append(slot_boundaries[-1] + values.size(0))
|
|
||||||
|
|
||||||
# 一次批量 embedding 查表
|
|
||||||
values_cat = torch.cat(all_values)
|
|
||||||
embs_cat = self.emb(values_cat).to(target_dtype)
|
|
||||||
|
|
||||||
# 按 slot 拆分并 segment_reduce
|
|
||||||
pooled_embs = []
|
|
||||||
for i in range(self.slot_num):
|
|
||||||
start, end = slot_boundaries[i], slot_boundaries[i + 1]
|
|
||||||
slot_embs = embs_cat[start:end]
|
|
||||||
res = torch.segment_reduce(slot_embs, reduce='sum', offsets=all_offsets[i], initial=0)
|
|
||||||
pooled_embs.append(res)
|
pooled_embs.append(res)
|
||||||
|
|
||||||
fused_embs = torch.cat(pooled_embs, dim=1)
|
fused_embs = torch.cat(pooled_embs, dim=1)
|
||||||
norm_emb = self.input_norm(fused_embs)
|
norm_emb = self.input_norm(fused_embs)
|
||||||
rep_emb = self.linear(norm_emb)
|
rep_emb = self.linear(norm_emb)
|
||||||
|
|||||||
Reference in New Issue
Block a user