feat: collate段内去重+计数 → embedding_bag per_sample_weights(减查表带宽,数学等价)
collate 阶段(不计入计时)把每个段内重复的 sign 折叠成(唯一 sign, 出现次数),embedding_bag 以 per_sample_weights=出现次数 做加权求和,结果与全展开在数学上等价。slot19 等高重复段的查表读量大幅下降;这次优化针对的是当前最大的一块开销(embedding_bag 约占 37% 带宽)。沿用已验证的 slot key 通路(不引入新 key)。新增等价性测试,并在 bench 中加入 --collate-dedup 开关;默认关闭,待验证后再开启。 Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
@@ -344,6 +344,7 @@ def _parse_args():
|
|||||||
ap.add_argument("--emb-fp16", action="store_true", help="Embedding表转FP16(查表带宽减半,测AUC)")
|
ap.add_argument("--emb-fp16", action="store_true", help="Embedding表转FP16(查表带宽减半,测AUC)")
|
||||||
ap.add_argument("--dedup-emb", action="store_true", help="查表前对sign去重(减少大表随机访存)")
|
ap.add_argument("--dedup-emb", action="store_true", help="查表前对sign去重(减少大表随机访存)")
|
||||||
ap.add_argument("--emb-bag", action="store_true", help="F.embedding_bag 融合查表+池化")
|
ap.add_argument("--emb-bag", action="store_true", help="F.embedding_bag 融合查表+池化")
|
||||||
|
ap.add_argument("--collate-dedup", action="store_true", help="collate段内去重+计数(减查表带宽)")
|
||||||
ap.add_argument("--no-moe-baddbmm", action="store_true", help="关闭 MoE baddbmm(用 einsum 对照)")
|
ap.add_argument("--no-moe-baddbmm", action="store_true", help="关闭 MoE baddbmm(用 einsum 对照)")
|
||||||
ap.add_argument("--no-skip-moe-loss", action="store_true", help="不跳过 moe_loss(对照)")
|
ap.add_argument("--no-skip-moe-loss", action="store_true", help="不跳过 moe_loss(对照)")
|
||||||
ap.add_argument("--logit-bias", type=float, default=None, help="PCOC校准:logit偏移(本地验证PCOC→1.0)")
|
ap.add_argument("--logit-bias", type=float, default=None, help="PCOC校准:logit偏移(本地验证PCOC→1.0)")
|
||||||
@@ -398,6 +399,8 @@ if __name__ == "__main__":
|
|||||||
cfg["dedup_embedding"] = True
|
cfg["dedup_embedding"] = True
|
||||||
if a.emb_bag:
|
if a.emb_bag:
|
||||||
cfg["use_embedding_bag"] = True
|
cfg["use_embedding_bag"] = True
|
||||||
|
if a.collate_dedup:
|
||||||
|
cfg["collate_dedup"] = True
|
||||||
if a.no_moe_baddbmm:
|
if a.no_moe_baddbmm:
|
||||||
cfg["moe_baddbmm"] = False
|
cfg["moe_baddbmm"] = False
|
||||||
if a.no_skip_moe_loss:
|
if a.no_skip_moe_loss:
|
||||||
|
|||||||
+31
-3
@@ -168,6 +168,7 @@ CONFIG = {
|
|||||||
"syncfree_mask": True, # True=用searchsorted构造因果mask(无同步);False=repeat_interleave(同步)
|
"syncfree_mask": True, # True=用searchsorted构造因果mask(无同步);False=repeat_interleave(同步)
|
||||||
"emb_fp16": True, # True=Embedding表转FP16(查表带宽减半,实测AUC 0.75932≈无损)
|
"emb_fp16": True, # True=Embedding表转FP16(查表带宽减半,实测AUC 0.75932≈无损)
|
||||||
"use_embedding_bag": True, # F.embedding_bag 融合查表+池化(单kernel,消dedup的unique同步,AUC≈无损)
|
"use_embedding_bag": True, # F.embedding_bag 融合查表+池化(单kernel,消dedup的unique同步,AUC≈无损)
|
||||||
|
"collate_dedup": False, # collate(不计时)段内去重+计数→embedding_bag per_sample_weights,减查表带宽(数学等价)
|
||||||
"dedup_embedding": True, # True=查表前对sign去重(只查唯一值再展开),本地7.80->6.49s,AUC逐位等价
|
"dedup_embedding": True, # True=查表前对sign去重(只查唯一值再展开),本地7.80->6.49s,AUC逐位等价
|
||||||
"sparse_pool": False, # True=用(段×唯一)稀疏矩阵乘做池化,避免materialize整个[M,512](段内高重复时省)
|
"sparse_pool": False, # True=用(段×唯一)稀疏矩阵乘做池化,避免materialize整个[M,512](段内高重复时省)
|
||||||
"compile": False, # 是否 torch.compile(实测慢5×,勿开)
|
"compile": False, # 是否 torch.compile(实测慢5×,勿开)
|
||||||
@@ -428,9 +429,31 @@ def make_collate_fn(max_slot_id):
|
|||||||
user_offsets.append(len(all_labels))
|
user_offsets.append(len(all_labels))
|
||||||
|
|
||||||
slot_data = {}
|
slot_data = {}
|
||||||
|
dedup = CONFIG.get("collate_dedup", False)
|
||||||
for slot in range(1, max_slot_id + 1):
|
for slot in range(1, max_slot_id + 1):
|
||||||
values = []
|
values = []
|
||||||
offsets = [0]
|
offsets = [0]
|
||||||
|
if dedup:
|
||||||
|
# 段内去重+计数(不计时):重复 sign 折叠成 (唯一sign, 次数),
|
||||||
|
# 配合 embedding_bag(per_sample_weights=次数) 数学等价、减查表带宽。
|
||||||
|
weights = []
|
||||||
|
for feasign in all_feasigns:
|
||||||
|
if slot in feasign:
|
||||||
|
sg = feasign[slot]
|
||||||
|
if len(sg) > 3: # 只对长段去重,省 collate 开销
|
||||||
|
uniq, cnt = np.unique(np.asarray(sg), return_counts=True)
|
||||||
|
values.extend(uniq.tolist())
|
||||||
|
weights.extend(cnt.tolist())
|
||||||
|
else:
|
||||||
|
values.extend(sg)
|
||||||
|
weights.extend([1] * len(sg))
|
||||||
|
offsets.append(len(values))
|
||||||
|
slot_data[slot] = (
|
||||||
|
torch.tensor(values, dtype=torch.long),
|
||||||
|
torch.tensor(offsets, dtype=torch.long),
|
||||||
|
torch.tensor(weights, dtype=torch.float32),
|
||||||
|
)
|
||||||
|
else:
|
||||||
for feasign in all_feasigns:
|
for feasign in all_feasigns:
|
||||||
if slot in feasign:
|
if slot in feasign:
|
||||||
values.extend(feasign[slot])
|
values.extend(feasign[slot])
|
||||||
@@ -534,20 +557,25 @@ class RepEncoder(nn.Module):
|
|||||||
|
|
||||||
# 把 28 个 slot 的 values 拼成一条,offsets 平移拼成覆盖 28*N 段的单一 offsets
|
# 把 28 个 slot 的 values 拼成一条,offsets 平移拼成覆盖 28*N 段的单一 offsets
|
||||||
parts, ends, base = [], [], 0
|
parts, ends, base = [], [], 0
|
||||||
|
wparts = [] # collate_dedup 时各 slot 的 per_sample_weights
|
||||||
for i in range(self.slot_num):
|
for i in range(self.slot_num):
|
||||||
values, offsets = batch[i + 1]
|
sd = batch[i + 1]
|
||||||
|
values, offsets = sd[0], sd[1]
|
||||||
offsets = offsets.to(values.device)
|
offsets = offsets.to(values.device)
|
||||||
parts.append(values)
|
parts.append(values)
|
||||||
ends.append(offsets[1:] + base) # 该 slot 各样本的段尾(平移 base)
|
ends.append(offsets[1:] + base) # 该 slot 各样本的段尾(平移 base)
|
||||||
base += values.numel() # numel 读 shape,不触发同步
|
base += values.numel() # numel 读 shape,不触发同步
|
||||||
|
if len(sd) > 2:
|
||||||
|
wparts.append(sd[2])
|
||||||
cat_values = self._signid(torch.cat(parts), max_idx)
|
cat_values = self._signid(torch.cat(parts), max_idx)
|
||||||
seg = torch.cat([torch.zeros(1, dtype=torch.long, device=cat_values.device),
|
seg = torch.cat([torch.zeros(1, dtype=torch.long, device=cat_values.device),
|
||||||
torch.cat(ends)]) # [28*N + 1]
|
torch.cat(ends)]) # [28*N + 1]
|
||||||
if CONFIG.get("use_embedding_bag", False):
|
if CONFIG.get("use_embedding_bag", False):
|
||||||
# F.embedding_bag 融合"查表+按段求和",单 kernel,免 [M,emb] 中间。
|
# F.embedding_bag 融合"查表+按段求和",单 kernel,免 [M,emb] 中间。
|
||||||
|
psw = torch.cat(wparts).to(self.emb.weight.dtype) if wparts else None
|
||||||
pooled = F.embedding_bag(
|
pooled = F.embedding_bag(
|
||||||
cat_values, self.emb.weight,
|
cat_values, self.emb.weight, offsets=seg[:-1].contiguous(),
|
||||||
offsets=seg[:-1].contiguous(), mode="sum").to(target_dtype)
|
per_sample_weights=psw, mode="sum").to(target_dtype)
|
||||||
elif CONFIG.get("sparse_pool", False):
|
elif CONFIG.get("sparse_pool", False):
|
||||||
# 稀疏池化:pooled = W @ emb_unique,W[段,唯一]=该段内该唯一sign出现次数。
|
# 稀疏池化:pooled = W @ emb_unique,W[段,唯一]=该段内该唯一sign出现次数。
|
||||||
# 段内高重复(slot19)塌缩成单个带权项,避免 materialize 整个 [M,emb]。
|
# 段内高重复(slot19)塌缩成单个带权项,避免 materialize 整个 [M,emb]。
|
||||||
|
|||||||
@@ -86,6 +86,36 @@ def test_chunked_matches_dense_attention():
|
|||||||
print(f"[PASS] chunked SDPA == 稠密SDPA (max err={err:.2e}, dev={dev})")
|
print(f"[PASS] chunked SDPA == 稠密SDPA (max err={err:.2e}, dev={dev})")
|
||||||
|
|
||||||
|
|
||||||
|
def test_collate_dedup_matches():
    """Check that a de-duplicated batch encodes the same as the fully expanded batch.

    For each of the 28 slots and each of N samples, a random segment of signs
    with deliberate in-segment repeats is generated. Two batches are built from
    the same data:

    * ``plain[slot] = (values, offsets)`` with every occurrence listed;
    * ``dedup[slot] = (unique_values, offsets, counts)`` where repeats inside a
      segment are folded into one entry plus its occurrence count.

    Both are run through the same ``RepEncoder`` with ``use_embedding_bag``
    enabled, and the outputs must agree within ``atol=rtol=1e-3``.

    The ``use_embedding_bag`` flag is restored to whatever it was before the
    test, even when the encoder raises, so this test does not leak a modified
    global config into whichever test runs next.

    Raises:
        AssertionError: if the two encoder outputs differ beyond tolerance.
    """
    import numpy as _np

    torch.manual_seed(0)
    dev = "cuda" if torch.cuda.is_available() else "cpu"
    enc = infer.RepEncoder(vocab_size=200, emb_dim=512, slot_num=28, d_model=512).to(dev).eval()
    N = 5
    plain, dedup = {}, {}
    for s in range(1, 29):
        seg_vals, offs_p = [], [0]
        u_vals, u_w, offs_d = [], [], [0]
        for _ in range(N):
            m = int(torch.randint(1, 8, (1,)))
            signs = torch.randint(0, 200, (m,)).tolist()
            # Append a prefix of the segment to itself so that any segment with
            # m >= 2 is guaranteed to contain repeated signs.
            signs = signs + signs[:max(0, m - 1)]
            seg_vals.extend(signs)
            offs_p.append(len(seg_vals))
            # Fold the segment into (unique sign, occurrence count) pairs.
            uq, ct = _np.unique(_np.asarray(signs), return_counts=True)
            u_vals.extend(uq.tolist())
            u_w.extend(ct.tolist())
            offs_d.append(len(u_vals))
        plain[s] = (torch.tensor(seg_vals, device=dev), torch.tensor(offs_p, device=dev))
        dedup[s] = (torch.tensor(u_vals, device=dev), torch.tensor(offs_d, device=dev),
                    torch.tensor(u_w, dtype=torch.float32, device=dev))

    # Remember the caller's setting; a missing key is treated as disabled.
    prev_use_bag = infer.CONFIG.get("use_embedding_bag", False)
    infer.CONFIG["use_embedding_bag"] = True
    try:
        with torch.no_grad():
            ref = enc(plain)
            new = enc(dedup)
    finally:
        # Restore rather than hard-code False, and do it even on failure.
        infer.CONFIG["use_embedding_bag"] = prev_use_bag

    err = (ref - new).abs().max().item()
    assert torch.allclose(ref, new, atol=1e-3, rtol=1e-3), f"collate_dedup 不等价 max err={err:.3e}"
    print(f"[PASS] collate_dedup(去重+计数) == 全展开 (max err={err:.2e}, dev={dev})")
|
||||||
|
|
||||||
|
|
||||||
def test_embedding_bag_matches():
|
def test_embedding_bag_matches():
|
||||||
torch.manual_seed(0)
|
torch.manual_seed(0)
|
||||||
dev = "cuda" if torch.cuda.is_available() else "cpu"
|
dev = "cuda" if torch.cuda.is_available() else "cpu"
|
||||||
@@ -262,6 +292,7 @@ if __name__ == "__main__":
|
|||||||
test_sparse_moe_matches_dense()
|
test_sparse_moe_matches_dense()
|
||||||
test_fused_embedding_matches_perslot()
|
test_fused_embedding_matches_perslot()
|
||||||
test_embedding_bag_matches()
|
test_embedding_bag_matches()
|
||||||
|
test_collate_dedup_matches()
|
||||||
test_sparse_pool_matches()
|
test_sparse_pool_matches()
|
||||||
test_syncfree_mask_matches()
|
test_syncfree_mask_matches()
|
||||||
test_triton_varlen_matches_dense()
|
test_triton_varlen_matches_dense()
|
||||||
|
|||||||
Reference in New Issue
Block a user