perf: RepEncoder 融合 28-slot 查表+池化为单次(减per-batch kernel启动,无新增同步)

延续 dense MoE 的胜因(消 per-batch 开销在评测端被放大见效)。28次embedding
+28次segment_reduce 融合为1次;用 numel 读shape避免同步;base累加无同步。
保留 _rep_forward_perslot 作等价对照。CONFIG.fuse_embedding 默认 True。

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
OwnerSunshine530
2026-06-15 11:50:11 +08:00
parent 48f9003a1e
commit 928de22a9b
2 changed files with 66 additions and 13 deletions
+41 -13
View File
@@ -48,6 +48,7 @@ CONFIG = {
# synchronize,去掉同步点可能让被计时的 model(batch) 大幅缩短。本地force-sync看不出,
# 须靠提交验证。AUC中性、MoE仅占2%算力故风险极低。
"vectorize_moe": True, # True=稠密向量化MoE(无同步点)False=原逐expert循环(.nonzero同步)
"fuse_embedding": True, # True=28个slot的查表+池化融合为1次(减per-batch kernel启动)
"compile": False, # 是否 torch.compile(实测慢5×,勿开)
}
@@ -327,6 +328,22 @@ def move_batch_to_device(batch, device):
return batch
def _rep_forward_perslot(enc, batch):
"""原始逐 slot 实现(保留作数值等价对照/回退)。"""
pooled_embs = []
max_idx = enc.emb.num_embeddings - 1
target_dtype = enc.input_norm.weight.dtype
for i in range(enc.slot_num):
values, offsets = batch[i + 1]
offsets = offsets.to(values.device)
values = enc._signid(values, max_idx)
sign_emb = enc.emb(values).to(target_dtype)
res = torch.segment_reduce(sign_emb, reduce='sum', offsets=offsets, initial=0)
pooled_embs.append(res)
fused_embs = torch.cat(pooled_embs, dim=1)
return enc.linear(enc.input_norm(fused_embs))
class RepEncoder(nn.Module):
def __init__(self, vocab_size, emb_dim, padding_idx=0, slot_num=0, d_model=0):
super().__init__()
@@ -336,24 +353,35 @@ class RepEncoder(nn.Module):
self.input_norm = nn.LayerNorm(slot_num * emb_dim)
self.linear = nn.Linear(in_features=slot_num * emb_dim, out_features=d_model)
def _signid(self, values, max_idx):
if CONFIG["signid_mode"] == "modulo":
return values % self.emb.num_embeddings # 取模哈希(与训练一致时用)
return values.clamp(0, max_idx) # 超界 sign id 截断
def forward(self, batch):
pooled_embs = []
if not CONFIG.get("fuse_embedding", True):
return _rep_forward_perslot(self, batch)
max_idx = self.emb.num_embeddings - 1
target_dtype = self.input_norm.weight.dtype # 后续层 dtypeFP16 时为 torch.float16
target_dtype = self.input_norm.weight.dtype
N = batch[1][1].numel() - 1 # 样本数(slot1 的 offsets 段数)
# 把 28 个 slot 的 values 拼成一条,offsets 平移拼成覆盖 28*N 段的单一 offsets
parts, ends, base = [], [], 0
for i in range(self.slot_num):
values, offsets = batch[i + 1]
offsets = offsets.to(values.device)
if CONFIG["signid_mode"] == "modulo":
values = values % self.emb.num_embeddings # 取模哈希(与训练一致时用
else:
values = values.clamp(0, max_idx) # 超出 vocab_size 的 sign id 截断,避免越界
sign_emb = self.emb(values).to(target_dtype)
res = torch.segment_reduce(sign_emb, reduce='sum', offsets=offsets, initial=0)
pooled_embs.append(res)
fused_embs = torch.cat(pooled_embs, dim=1)
norm_emb = self.input_norm(fused_embs)
rep_emb = self.linear(norm_emb)
return rep_emb
parts.append(values)
ends.append(offsets[1:] + base) # 该 slot 各样本的段尾(平移 base
base += values.numel() # numel 读 shape,不触发同步
cat_values = self._signid(torch.cat(parts), max_idx)
seg = torch.cat([torch.zeros(1, dtype=torch.long, device=cat_values.device),
torch.cat(ends)]) # [28*N + 1]
emb = self.emb(cat_values).to(target_dtype)
pooled = torch.segment_reduce(emb, reduce='sum', offsets=seg, initial=0) # [28*N, emb]
pooled = pooled.view(self.slot_num, N, self.emb_dim).permute(1, 0, 2).reshape(
N, self.slot_num * self.emb_dim)
return self.linear(self.input_norm(pooled))
def _varlen_attention(q, k, v, user_offsets):