feat: sparse_pool 选项 — (段×唯一)稀疏矩阵乘做池化,避免materialize[M,emb]
针对 profile 的 dedup展开(15%)+segment_reduce(6.6%)。段内高重复(slot19)塌缩 为单个带权项。CONFIG.sparse_pool;bench --sparse-pool;等价测试已加。默认关,待验证。 Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
@@ -299,6 +299,7 @@ def _parse_args():
|
|||||||
ap.add_argument("--compile", action="store_true", help="开启 torch.compile")
|
ap.add_argument("--compile", action="store_true", help="开启 torch.compile")
|
||||||
ap.add_argument("--emb-fp16", action="store_true", help="Embedding表转FP16(查表带宽减半,测AUC)")
|
ap.add_argument("--emb-fp16", action="store_true", help="Embedding表转FP16(查表带宽减半,测AUC)")
|
||||||
ap.add_argument("--dedup-emb", action="store_true", help="查表前对sign去重(减少大表随机访存)")
|
ap.add_argument("--dedup-emb", action="store_true", help="查表前对sign去重(减少大表随机访存)")
|
||||||
|
ap.add_argument("--sparse-pool", action="store_true", help="稀疏矩阵乘做池化(段内高重复时省)")
|
||||||
ap.add_argument("--profile", type=int, default=None, metavar="N",
|
ap.add_argument("--profile", type=int, default=None, metavar="N",
|
||||||
help="剖析前 N 个 batch,打印按 CUDA 耗时排序的算子表(定位瓶颈)")
|
help="剖析前 N 个 batch,打印按 CUDA 耗时排序的算子表(定位瓶颈)")
|
||||||
ap.add_argument("--rebuild", action="store_true", help="强制重建过滤缓存")
|
ap.add_argument("--rebuild", action="store_true", help="强制重建过滤缓存")
|
||||||
@@ -334,6 +335,8 @@ if __name__ == "__main__":
|
|||||||
cfg["emb_fp16"] = True
|
cfg["emb_fp16"] = True
|
||||||
if a.dedup_emb:
|
if a.dedup_emb:
|
||||||
cfg["dedup_embedding"] = True
|
cfg["dedup_embedding"] = True
|
||||||
|
if a.sparse_pool:
|
||||||
|
cfg["sparse_pool"] = True
|
||||||
if a.compile:
|
if a.compile:
|
||||||
cfg["compile"] = True
|
cfg["compile"] = True
|
||||||
if a.profile is not None:
|
if a.profile is not None:
|
||||||
|
|||||||
+20
-5
@@ -53,6 +53,7 @@ CONFIG = {
|
|||||||
"syncfree_mask": True, # True=用searchsorted构造因果mask(无同步);False=repeat_interleave(同步)
|
"syncfree_mask": True, # True=用searchsorted构造因果mask(无同步);False=repeat_interleave(同步)
|
||||||
"emb_fp16": True, # True=Embedding表转FP16(查表带宽减半,实测AUC 0.75932≈无损)
|
"emb_fp16": True, # True=Embedding表转FP16(查表带宽减半,实测AUC 0.75932≈无损)
|
||||||
"dedup_embedding": True, # True=查表前对sign去重(只查唯一值再展开),本地7.80->6.49s,AUC逐位等价
|
"dedup_embedding": True, # True=查表前对sign去重(只查唯一值再展开),本地7.80->6.49s,AUC逐位等价
|
||||||
|
"sparse_pool": False, # True=用(段×唯一)稀疏矩阵乘做池化,避免materialize整个[M,512](段内高重复时省)
|
||||||
"compile": False, # 是否 torch.compile(实测慢5×,勿开)
|
"compile": False, # 是否 torch.compile(实测慢5×,勿开)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -381,13 +382,27 @@ class RepEncoder(nn.Module):
|
|||||||
cat_values = self._signid(torch.cat(parts), max_idx)
|
cat_values = self._signid(torch.cat(parts), max_idx)
|
||||||
seg = torch.cat([torch.zeros(1, dtype=torch.long, device=cat_values.device),
|
seg = torch.cat([torch.zeros(1, dtype=torch.long, device=cat_values.device),
|
||||||
torch.cat(ends)]) # [28*N + 1]
|
torch.cat(ends)]) # [28*N + 1]
|
||||||
if CONFIG.get("dedup_embedding", False):
|
if CONFIG.get("sparse_pool", False):
|
||||||
# 去重:只对唯一 sign 查大表,再按逆索引展开(数学逐位等价,省随机访存)
|
# 稀疏池化:pooled = W @ emb_unique,W[段,唯一]=该段内该唯一sign出现次数。
|
||||||
|
# 段内高重复(slot19)塌缩成单个带权项,避免 materialize 整个 [M,emb]。
|
||||||
uniq, inv = torch.unique(cat_values, return_inverse=True)
|
uniq, inv = torch.unique(cat_values, return_inverse=True)
|
||||||
emb = self.emb(uniq).to(target_dtype)[inv]
|
emb_unique = self.emb(uniq).float() # 小表;sparse.mm 用 fp32 稳
|
||||||
|
M = cat_values.numel()
|
||||||
|
num_seg = seg.numel() - 1
|
||||||
|
seg_id = torch.searchsorted(
|
||||||
|
seg, torch.arange(M, device=cat_values.device), right=True) - 1
|
||||||
|
W = torch.sparse_coo_tensor(
|
||||||
|
torch.stack([seg_id, inv]),
|
||||||
|
torch.ones(M, device=cat_values.device, dtype=torch.float32),
|
||||||
|
size=(num_seg, uniq.numel())).coalesce()
|
||||||
|
pooled = torch.sparse.mm(W, emb_unique).to(target_dtype) # [28*N, emb]
|
||||||
else:
|
else:
|
||||||
emb = self.emb(cat_values).to(target_dtype)
|
if CONFIG.get("dedup_embedding", False):
|
||||||
pooled = torch.segment_reduce(emb, reduce='sum', offsets=seg, initial=0) # [28*N, emb]
|
uniq, inv = torch.unique(cat_values, return_inverse=True)
|
||||||
|
emb = self.emb(uniq).to(target_dtype)[inv]
|
||||||
|
else:
|
||||||
|
emb = self.emb(cat_values).to(target_dtype)
|
||||||
|
pooled = torch.segment_reduce(emb, reduce='sum', offsets=seg, initial=0) # [28*N, emb]
|
||||||
pooled = pooled.view(self.slot_num, N, self.emb_dim).permute(1, 0, 2).reshape(
|
pooled = pooled.view(self.slot_num, N, self.emb_dim).permute(1, 0, 2).reshape(
|
||||||
N, self.slot_num * self.emb_dim)
|
N, self.slot_num * self.emb_dim)
|
||||||
return self.linear(self.input_norm(pooled))
|
return self.linear(self.input_norm(pooled))
|
||||||
|
|||||||
@@ -86,6 +86,32 @@ def test_chunked_matches_dense_attention():
|
|||||||
print(f"[PASS] chunked SDPA == 稠密SDPA (max err={err:.2e}, dev={dev})")
|
print(f"[PASS] chunked SDPA == 稠密SDPA (max err={err:.2e}, dev={dev})")
|
||||||
|
|
||||||
|
|
||||||
|
def test_sparse_pool_matches():
|
||||||
|
torch.manual_seed(0)
|
||||||
|
dev = "cuda" if torch.cuda.is_available() else "cpu"
|
||||||
|
slot_num, emb_dim, d_model = 28, 512, 512
|
||||||
|
enc = infer.RepEncoder(vocab_size=200, emb_dim=emb_dim, slot_num=slot_num,
|
||||||
|
d_model=d_model).to(dev).eval()
|
||||||
|
N = 6
|
||||||
|
batch = {}
|
||||||
|
for s in range(1, slot_num + 1):
|
||||||
|
counts = torch.randint(0, 8, (N,))
|
||||||
|
# 故意制造段内重复:值域很小,重复率高
|
||||||
|
vals = torch.randint(0, 30, (int(counts.sum()),), device=dev)
|
||||||
|
offs = torch.cat([torch.zeros(1, dtype=torch.long), counts.cumsum(0)]).to(dev)
|
||||||
|
batch[s] = (vals, offs)
|
||||||
|
with torch.no_grad():
|
||||||
|
infer.CONFIG["sparse_pool"] = False
|
||||||
|
infer.CONFIG["dedup_embedding"] = True
|
||||||
|
ref = enc(batch)
|
||||||
|
infer.CONFIG["sparse_pool"] = True
|
||||||
|
new = enc(batch)
|
||||||
|
infer.CONFIG["sparse_pool"] = False
|
||||||
|
err = (ref - new).abs().max().item()
|
||||||
|
assert torch.allclose(ref, new, atol=2e-2, rtol=2e-2), f"sparse_pool 不等价 max err={err:.3e}"
|
||||||
|
print(f"[PASS] sparse_pool == segment_reduce (max err={err:.2e}, dev={dev})")
|
||||||
|
|
||||||
|
|
||||||
def test_syncfree_mask_matches():
|
def test_syncfree_mask_matches():
|
||||||
dev = "cuda" if torch.cuda.is_available() else "cpu"
|
dev = "cuda" if torch.cuda.is_available() else "cpu"
|
||||||
rep = infer.RepEncoder(vocab_size=100, emb_dim=8, slot_num=28, d_model=8)
|
rep = infer.RepEncoder(vocab_size=100, emb_dim=8, slot_num=28, d_model=8)
|
||||||
@@ -169,6 +195,7 @@ def test_flex_matches_dense_attention():
|
|||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
test_moe_dense_matches_loop()
|
test_moe_dense_matches_loop()
|
||||||
test_fused_embedding_matches_perslot()
|
test_fused_embedding_matches_perslot()
|
||||||
|
test_sparse_pool_matches()
|
||||||
test_syncfree_mask_matches()
|
test_syncfree_mask_matches()
|
||||||
test_chunked_matches_dense_attention()
|
test_chunked_matches_dense_attention()
|
||||||
test_varlen_matches_dense_attention()
|
test_varlen_matches_dense_attention()
|
||||||
|
|||||||
Reference in New Issue
Block a user