feat: F.embedding_bag 融合查表+池化(单kernel,免[M,512]中间) — 攻最大块(dedup index25%+segment11%=36%)
triton版profile:attention已优化出top,新大头=embedding池化36%+MoE22%+add18%。 embedding_bag一个kernel做查表+按段求和。等价测试+bench --emb-bag。默认关待验证。 Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
+7
-1
@@ -145,6 +145,7 @@ CONFIG = {
|
||||
"fuse_embedding": True, # True=28个slot的查表+池化融合为1次(减per-batch kernel启动)
|
||||
"syncfree_mask": True, # True=用searchsorted构造因果mask(无同步);False=repeat_interleave(同步)
|
||||
"emb_fp16": True, # True=Embedding表转FP16(查表带宽减半,实测AUC 0.75932≈无损)
|
||||
"use_embedding_bag": False, # True=用 F.embedding_bag 融合查表+池化(单kernel,免[M,512]中间),攻最大块
|
||||
"dedup_embedding": True, # True=查表前对sign去重(只查唯一值再展开),本地7.80->6.49s,AUC逐位等价
|
||||
"sparse_pool": False, # True=用(段×唯一)稀疏矩阵乘做池化,避免materialize整个[M,512](段内高重复时省)
|
||||
"compile": False, # 是否 torch.compile(实测慢5×,勿开)
|
||||
@@ -520,7 +521,12 @@ class RepEncoder(nn.Module):
|
||||
cat_values = self._signid(torch.cat(parts), max_idx)
|
||||
seg = torch.cat([torch.zeros(1, dtype=torch.long, device=cat_values.device),
|
||||
torch.cat(ends)]) # [28*N + 1]
|
||||
if CONFIG.get("sparse_pool", False):
|
||||
if CONFIG.get("use_embedding_bag", False):
|
||||
# F.embedding_bag 融合"查表+按段求和",单 kernel,免 [M,emb] 中间。
|
||||
pooled = F.embedding_bag(
|
||||
cat_values, self.emb.weight,
|
||||
offsets=seg[:-1].contiguous(), mode="sum").to(target_dtype)
|
||||
elif CONFIG.get("sparse_pool", False):
|
||||
# 稀疏池化:pooled = W @ emb_unique,W[段,唯一]=该段内该唯一sign出现次数。
|
||||
# 段内高重复(slot19)塌缩成单个带权项,避免 materialize 整个 [M,emb]。
|
||||
uniq, inv = torch.unique(cat_values, return_inverse=True)
|
||||
|
||||
Reference in New Issue
Block a user