feat: collate段内去重+计数 → embedding_bag per_sample_weights(减查表带宽,数学等价)

在 collate 阶段(不计入计时)把每个段内重复的 sign 折叠成(唯一 sign, 出现次数),embedding_bag 以 per_sample_weights=出现次数 做加权求和,与逐个查表再求和数学等价。
slot19 等高重复段的查表读取量大幅下降。本次针对占比最大的开销(embedding_bag 占 37% 带宽)。沿用已验证的 slot key 通路(不引入新 key)。
新增等价性测试,并为 bench 增加 --collate-dedup 选项。默认关闭,待验证后再开启。

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
OwnerSunshine530
2026-06-20 14:46:48 +08:00
parent 9461d97173
commit cc4acca875
3 changed files with 73 additions and 11 deletions
+39 -11
View File
@@ -168,6 +168,7 @@ CONFIG = {
"syncfree_mask": True, # True=用searchsorted构造因果mask(无同步);False=repeat_interleave(同步)
"emb_fp16": True, # True=Embedding表转FP16(查表带宽减半,实测AUC 0.75932≈无损)
"use_embedding_bag": True, # F.embedding_bag 融合查表+池化(单kernel,消dedup的unique同步,AUC≈无损)
"collate_dedup": False, # collate(不计时)段内去重+计数→embedding_bag per_sample_weights,减查表带宽(数学等价)
"dedup_embedding": True, # True=查表前对sign去重(只查唯一值再展开),本地7.80->6.49s,AUC逐位等价
"sparse_pool": False, # True=用(段×唯一)稀疏矩阵乘做池化,避免materialize整个[M,512](段内高重复时省)
"compile": False, # 是否 torch.compile(实测慢5×,勿开)
@@ -428,17 +429,39 @@ def make_collate_fn(max_slot_id):
user_offsets.append(len(all_labels))
slot_data = {}
dedup = CONFIG.get("collate_dedup", False)
for slot in range(1, max_slot_id + 1):
values = []
offsets = [0]
for feasign in all_feasigns:
if slot in feasign:
values.extend(feasign[slot])
offsets.append(len(values))
slot_data[slot] = (
torch.tensor(values, dtype=torch.long),
torch.tensor(offsets, dtype=torch.long),
)
if dedup:
# 段内去重+计数(不计时):重复 sign 折叠成 (唯一sign, 次数)
# 配合 embedding_bag(per_sample_weights=次数) 数学等价、减查表带宽。
weights = []
for feasign in all_feasigns:
if slot in feasign:
sg = feasign[slot]
if len(sg) > 3: # 只对长段去重,省 collate 开销
uniq, cnt = np.unique(np.asarray(sg), return_counts=True)
values.extend(uniq.tolist())
weights.extend(cnt.tolist())
else:
values.extend(sg)
weights.extend([1] * len(sg))
offsets.append(len(values))
slot_data[slot] = (
torch.tensor(values, dtype=torch.long),
torch.tensor(offsets, dtype=torch.long),
torch.tensor(weights, dtype=torch.float32),
)
else:
for feasign in all_feasigns:
if slot in feasign:
values.extend(feasign[slot])
offsets.append(len(values))
slot_data[slot] = (
torch.tensor(values, dtype=torch.long),
torch.tensor(offsets, dtype=torch.long),
)
result = {
'userid': torch.tensor(all_userids, dtype=torch.long),
@@ -534,20 +557,25 @@ class RepEncoder(nn.Module):
# 把 28 个 slot 的 values 拼成一条,offsets 平移拼成覆盖 28*N 段的单一 offsets
parts, ends, base = [], [], 0
wparts = [] # collate_dedup 时各 slot 的 per_sample_weights
for i in range(self.slot_num):
values, offsets = batch[i + 1]
sd = batch[i + 1]
values, offsets = sd[0], sd[1]
offsets = offsets.to(values.device)
parts.append(values)
ends.append(offsets[1:] + base) # 该 slot 各样本的段尾(平移 base)
base += values.numel() # numel 读 shape,不触发同步
if len(sd) > 2:
wparts.append(sd[2])
cat_values = self._signid(torch.cat(parts), max_idx)
seg = torch.cat([torch.zeros(1, dtype=torch.long, device=cat_values.device),
torch.cat(ends)]) # [28*N + 1]
if CONFIG.get("use_embedding_bag", False):
# F.embedding_bag 融合"查表+按段求和",单 kernel,免 [M,emb] 中间。
psw = torch.cat(wparts).to(self.emb.weight.dtype) if wparts else None
pooled = F.embedding_bag(
cat_values, self.emb.weight,
offsets=seg[:-1].contiguous(), mode="sum").to(target_dtype)
cat_values, self.emb.weight, offsets=seg[:-1].contiguous(),
per_sample_weights=psw, mode="sum").to(target_dtype)
elif CONFIG.get("sparse_pool", False):
# 稀疏池化:pooled = W @ emb_unique,其中 W[段,唯一]=该段内该唯一sign出现次数。
# 段内高重复(slot19)塌缩成单个带权项,避免 materialize 整个 [M,emb]。