From 2268fa6cf30433c505c5e00a48c9a0405fe42e8b Mon Sep 17 00:00:00 2001 From: OwnerSunshine530 Date: Mon, 15 Jun 2026 14:07:23 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20dedup=5Fembedding=20=E9=80=89=E9=A1=B9?= =?UTF-8?q?=20=E2=80=94=20=E6=9F=A5=E8=A1=A8=E5=89=8D=E5=AF=B9sign?= =?UTF-8?q?=E5=8E=BB=E9=87=8D(slot19=E7=AD=89=E9=AB=98=E9=87=8D=E5=A4=8D),?= =?UTF-8?q?=E5=87=8F=E5=B0=91=E5=A4=A7=E8=A1=A8=E9=9A=8F=E6=9C=BA=E8=AE=BF?= =?UTF-8?q?=E5=AD=98?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit profile显示embedding查表现为头号瓶颈(32%)。torch.unique去重后只查唯一sign 再按逆索引展开,数学逐位等价(AUC不变),省最贵的大表随机gather。bench --dedup-emb。 Co-Authored-By: Claude Opus 4.8 --- 代码/code/bench.py | 3 +++ 代码/code/infer.py | 8 +++++++- 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/代码/code/bench.py b/代码/code/bench.py index c0c4f67..ea890b1 100644 --- a/代码/code/bench.py +++ b/代码/code/bench.py @@ -298,6 +298,7 @@ def _parse_args(): help="MoE实现:dense=向量化(新), loop=逐expert循环(原)") ap.add_argument("--compile", action="store_true", help="开启 torch.compile") ap.add_argument("--emb-fp16", action="store_true", help="Embedding表转FP16(查表带宽减半,测AUC)") + ap.add_argument("--dedup-emb", action="store_true", help="查表前对sign去重(减少大表随机访存)") ap.add_argument("--profile", type=int, default=None, metavar="N", help="剖析前 N 个 batch,打印按 CUDA 耗时排序的算子表(定位瓶颈)") ap.add_argument("--rebuild", action="store_true", help="强制重建过滤缓存") @@ -331,6 +332,8 @@ if __name__ == "__main__": cfg["vectorize_moe"] = (a.moe == "dense") if a.emb_fp16: cfg["emb_fp16"] = True + if a.dedup_emb: + cfg["dedup_embedding"] = True if a.compile: cfg["compile"] = True if a.profile is not None: diff --git a/代码/code/infer.py b/代码/code/infer.py index 8397e07..ac078e6 100644 --- a/代码/code/infer.py +++ b/代码/code/infer.py @@ -52,6 +52,7 @@ CONFIG = { "fuse_embedding": True, # True=28个slot的查表+池化融合为1次(减per-batch kernel启动) "syncfree_mask": True, # True=用searchsorted构造因果mask(无同步);False=repeat_interleave(同步) "emb_fp16": True, # True=Embedding表转FP16(查表带宽减半,实测AUC 0.75932≈无损) + "dedup_embedding": False, # True=查表前对sign去重(只查唯一值再展开),减少大表随机访存。数学等价 "compile": False, # 是否 torch.compile(实测慢5×,勿开) } @@ -380,7 +381,12 @@ class RepEncoder(nn.Module): cat_values = self._signid(torch.cat(parts), max_idx) seg = torch.cat([torch.zeros(1, dtype=torch.long, device=cat_values.device), torch.cat(ends)]) # [28*N + 1] - emb = self.emb(cat_values).to(target_dtype) + if CONFIG.get("dedup_embedding", False): + # 去重:只对唯一 sign 查大表,再按逆索引展开(数学逐位等价,省随机访存) + uniq, inv = torch.unique(cat_values, return_inverse=True) + emb = self.emb(uniq).to(target_dtype)[inv] + else: + emb = self.emb(cat_values).to(target_dtype) pooled = torch.segment_reduce(emb, reduce='sum', offsets=seg, initial=0) # [28*N, emb] pooled = pooled.view(self.slot_num, N, self.emb_dim).permute(1, 0, 2).reshape( N, self.slot_num * self.emb_dim)