feat: sparse_pool 选项 — (段×唯一)稀疏矩阵乘做池化,避免materialize[M,emb]

针对 profile 的 dedup展开(15%)+segment_reduce(6.6%)。段内高重复(slot19)塌缩
为单个带权项。CONFIG.sparse_pool;bench --sparse-pool;等价测试已加。默认关,待验证。

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
OwnerSunshine530
2026-06-15 15:15:13 +08:00
parent d5c327dc97
commit 6625666010
3 changed files with 50 additions and 5 deletions
+3
View File
@@ -299,6 +299,7 @@ def _parse_args():
ap.add_argument("--compile", action="store_true", help="开启 torch.compile")
ap.add_argument("--emb-fp16", action="store_true", help="Embedding表转FP16(查表带宽减半,测AUC)")
ap.add_argument("--dedup-emb", action="store_true", help="查表前对sign去重(减少大表随机访存)")
ap.add_argument("--sparse-pool", action="store_true", help="稀疏矩阵乘做池化(段内高重复时省)")
ap.add_argument("--profile", type=int, default=None, metavar="N",
help="剖析前 N 个 batch,打印按 CUDA 耗时排序的算子表(定位瓶颈)")
ap.add_argument("--rebuild", action="store_true", help="强制重建过滤缓存")
@@ -334,6 +335,8 @@ if __name__ == "__main__":
cfg["emb_fp16"] = True
if a.dedup_emb:
cfg["dedup_embedding"] = True
if a.sparse_pool:
cfg["sparse_pool"] = True
if a.compile:
cfg["compile"] = True
if a.profile is not None: