feat: sparse_pool 选项 — 用(段×唯一)稀疏矩阵乘做池化,避免 materialize 整个 [M,emb]

针对 profile 中 dedup 展开(15%)和 segment_reduce(6.6%)两项开销。段内高重复(如 slot19)的 sign 会塌缩为单个带权项。
新增开关 CONFIG.sparse_pool 和 bench 参数 --sparse-pool,并已添加等价性测试。默认关闭,待验证。

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
OwnerSunshine530
2026-06-15 15:15:13 +08:00
parent d5c327dc97
commit 6625666010
3 changed files with 50 additions and 5 deletions
+20 -5
View File
@@ -53,6 +53,7 @@ CONFIG = {
"syncfree_mask": True, # True=用searchsorted构造因果mask(无同步)False=repeat_interleave(同步)
"emb_fp16": True, # True=Embedding表转FP16(查表带宽减半,实测AUC 0.75932≈无损)
"dedup_embedding": True, # True=查表前对sign去重(只查唯一值再展开),本地7.80->6.49s,AUC逐位等价
"sparse_pool": False, # True=用(段×唯一)稀疏矩阵乘做池化,避免materialize整个[M,512](段内高重复时省)
"compile": False, # 是否 torch.compile(实测慢5×,勿开)
}
@@ -381,13 +382,27 @@ class RepEncoder(nn.Module):
cat_values = self._signid(torch.cat(parts), max_idx)
seg = torch.cat([torch.zeros(1, dtype=torch.long, device=cat_values.device),
torch.cat(ends)]) # [28*N + 1]
if CONFIG.get("dedup_embedding", False):
# 去重:只对唯一 sign 查大表,再按逆索引展开(数学逐位等价,省随机访存)
if CONFIG.get("sparse_pool", False):
# 稀疏池化:pooled = W @ emb_unique,其中 W[段,唯一] = 该段内该唯一 sign 的出现次数。
# 段内高重复(slot19)塌缩成单个带权项,避免 materialize 整个 [M,emb]。
uniq, inv = torch.unique(cat_values, return_inverse=True)
emb = self.emb(uniq).to(target_dtype)[inv]
emb_unique = self.emb(uniq).float() # 小表;sparse.mm 用 fp32 稳
M = cat_values.numel()
num_seg = seg.numel() - 1
seg_id = torch.searchsorted(
seg, torch.arange(M, device=cat_values.device), right=True) - 1
W = torch.sparse_coo_tensor(
torch.stack([seg_id, inv]),
torch.ones(M, device=cat_values.device, dtype=torch.float32),
size=(num_seg, uniq.numel())).coalesce()
pooled = torch.sparse.mm(W, emb_unique).to(target_dtype) # [28*N, emb]
else:
emb = self.emb(cat_values).to(target_dtype)
pooled = torch.segment_reduce(emb, reduce='sum', offsets=seg, initial=0) # [28*N, emb]
if CONFIG.get("dedup_embedding", False):
uniq, inv = torch.unique(cat_values, return_inverse=True)
emb = self.emb(uniq).to(target_dtype)[inv]
else:
emb = self.emb(cat_values).to(target_dtype)
pooled = torch.segment_reduce(emb, reduce='sum', offsets=seg, initial=0) # [28*N, emb]
pooled = pooled.view(self.slot_num, N, self.emb_dim).permute(1, 0, 2).reshape(
N, self.slot_num * self.emb_dim)
return self.linear(self.input_norm(pooled))