feat: dedup_embedding 选项 — 查表前对sign去重(slot19等高重复),减少大表随机访存

profile显示embedding查表现为头号瓶颈(32%)。torch.unique去重后只查唯一sign
再按逆索引展开,数学逐位等价(AUC不变),省最贵的大表随机gather。bench --dedup-emb。

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
OwnerSunshine530
2026-06-15 14:07:23 +08:00
parent 7f9cab05b5
commit 2268fa6cf3
2 changed files with 10 additions and 1 deletions
+3
View File
@@ -298,6 +298,7 @@ def _parse_args():
help="MoE实现:dense=向量化(新), loop=逐expert循环(原)")
ap.add_argument("--compile", action="store_true", help="开启 torch.compile")
ap.add_argument("--emb-fp16", action="store_true", help="Embedding表转FP16(查表带宽减半,测AUC)")
ap.add_argument("--dedup-emb", action="store_true", help="查表前对sign去重(减少大表随机访存)")
ap.add_argument("--profile", type=int, default=None, metavar="N",
help="剖析前 N 个 batch,打印按 CUDA 耗时排序的算子表(定位瓶颈)")
ap.add_argument("--rebuild", action="store_true", help="强制重建过滤缓存")
@@ -331,6 +332,8 @@ if __name__ == "__main__":
cfg["vectorize_moe"] = (a.moe == "dense")
if a.emb_fp16:
cfg["emb_fp16"] = True
if a.dedup_emb:
cfg["dedup_embedding"] = True
if a.compile:
cfg["compile"] = True
if a.profile is not None: