feat: dedup_embedding 选项 — 查表前对sign去重(slot19等高重复),减少大表随机访存
profile显示embedding查表现为头号瓶颈(32%)。torch.unique去重后只查唯一sign 再按逆索引展开,数学逐位等价(AUC不变),省最贵的大表随机gather。bench --dedup-emb。 Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
@@ -298,6 +298,7 @@ def _parse_args():
|
|||||||
help="MoE实现:dense=向量化(新), loop=逐expert循环(原)")
|
help="MoE实现:dense=向量化(新), loop=逐expert循环(原)")
|
||||||
ap.add_argument("--compile", action="store_true", help="开启 torch.compile")
|
ap.add_argument("--compile", action="store_true", help="开启 torch.compile")
|
||||||
ap.add_argument("--emb-fp16", action="store_true", help="Embedding表转FP16(查表带宽减半,测AUC)")
|
ap.add_argument("--emb-fp16", action="store_true", help="Embedding表转FP16(查表带宽减半,测AUC)")
|
||||||
|
ap.add_argument("--dedup-emb", action="store_true", help="查表前对sign去重(减少大表随机访存)")
|
||||||
ap.add_argument("--profile", type=int, default=None, metavar="N",
|
ap.add_argument("--profile", type=int, default=None, metavar="N",
|
||||||
help="剖析前 N 个 batch,打印按 CUDA 耗时排序的算子表(定位瓶颈)")
|
help="剖析前 N 个 batch,打印按 CUDA 耗时排序的算子表(定位瓶颈)")
|
||||||
ap.add_argument("--rebuild", action="store_true", help="强制重建过滤缓存")
|
ap.add_argument("--rebuild", action="store_true", help="强制重建过滤缓存")
|
||||||
@@ -331,6 +332,8 @@ if __name__ == "__main__":
|
|||||||
cfg["vectorize_moe"] = (a.moe == "dense")
|
cfg["vectorize_moe"] = (a.moe == "dense")
|
||||||
if a.emb_fp16:
|
if a.emb_fp16:
|
||||||
cfg["emb_fp16"] = True
|
cfg["emb_fp16"] = True
|
||||||
|
if a.dedup_emb:
|
||||||
|
cfg["dedup_embedding"] = True
|
||||||
if a.compile:
|
if a.compile:
|
||||||
cfg["compile"] = True
|
cfg["compile"] = True
|
||||||
if a.profile is not None:
|
if a.profile is not None:
|
||||||
|
|||||||
@@ -52,6 +52,7 @@ CONFIG = {
|
|||||||
"fuse_embedding": True, # True=28个slot的查表+池化融合为1次(减per-batch kernel启动)
|
"fuse_embedding": True, # True=28个slot的查表+池化融合为1次(减per-batch kernel启动)
|
||||||
"syncfree_mask": True, # True=用searchsorted构造因果mask(无同步);False=repeat_interleave(同步)
|
"syncfree_mask": True, # True=用searchsorted构造因果mask(无同步);False=repeat_interleave(同步)
|
||||||
"emb_fp16": True, # True=Embedding表转FP16(查表带宽减半,实测AUC 0.75932≈无损)
|
"emb_fp16": True, # True=Embedding表转FP16(查表带宽减半,实测AUC 0.75932≈无损)
|
||||||
|
"dedup_embedding": False, # True=查表前对sign去重(只查唯一值再展开),减少大表随机访存。数学等价
|
||||||
"compile": False, # 是否 torch.compile(实测慢5×,勿开)
|
"compile": False, # 是否 torch.compile(实测慢5×,勿开)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -380,6 +381,11 @@ class RepEncoder(nn.Module):
|
|||||||
cat_values = self._signid(torch.cat(parts), max_idx)
|
cat_values = self._signid(torch.cat(parts), max_idx)
|
||||||
seg = torch.cat([torch.zeros(1, dtype=torch.long, device=cat_values.device),
|
seg = torch.cat([torch.zeros(1, dtype=torch.long, device=cat_values.device),
|
||||||
torch.cat(ends)]) # [28*N + 1]
|
torch.cat(ends)]) # [28*N + 1]
|
||||||
|
if CONFIG.get("dedup_embedding", False):
|
||||||
|
# 去重:只对唯一 sign 查大表,再按逆索引展开(数学逐位等价,省随机访存)
|
||||||
|
uniq, inv = torch.unique(cat_values, return_inverse=True)
|
||||||
|
emb = self.emb(uniq).to(target_dtype)[inv]
|
||||||
|
else:
|
||||||
emb = self.emb(cat_values).to(target_dtype)
|
emb = self.emb(cat_values).to(target_dtype)
|
||||||
pooled = torch.segment_reduce(emb, reduce='sum', offsets=seg, initial=0) # [28*N, emb]
|
pooled = torch.segment_reduce(emb, reduce='sum', offsets=seg, initial=0) # [28*N, emb]
|
||||||
pooled = pooled.view(self.slot_num, N, self.emb_dim).permute(1, 0, 2).reshape(
|
pooled = pooled.view(self.slot_num, N, self.emb_dim).permute(1, 0, 2).reshape(
|
||||||
|
|||||||
Reference in New Issue
Block a user