feat: 分块SDPA注意力(--attn chunked),按用户边界切块降O(S²)
每块~chunk_users个用户、块内因果SDPA(评测端已验证、无嵌套开销),sum(块S²) 远小于总S²。仅1次同步读切分边界。之前本地bs=16快13%被MoE同步吃掉,现MoE 同步已消除,切块红利应全露出。CONFIG.attn=chunked/chunk_users;等价测试已加。 Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
+5
-2
@@ -291,8 +291,9 @@ def _parse_args():
|
||||
help="逗号分隔的 keep_fp32_modules,如 linear,rep_encoder.input_norm")
|
||||
ap.add_argument("--feasign-none", action="store_true",
|
||||
help="不截断特征(max_feasign_per_slot=None)")
|
||||
ap.add_argument("--attn", choices=["sdpa", "flex", "varlen"], default=None,
|
||||
help="注意力:sdpa=稠密(原), flex=FlexAttention, varlen=嵌套张量变长flash")
|
||||
ap.add_argument("--attn", choices=["sdpa", "chunked", "flex", "varlen"], default=None,
|
||||
help="注意力:sdpa=稠密, chunked=按用户分块SDPA, flex/varlen=对照")
|
||||
ap.add_argument("--chunk-users", type=int, default=None, help="chunked 模式每块用户数")
|
||||
ap.add_argument("--moe", choices=["dense", "loop"], default=None,
|
||||
help="MoE实现:dense=向量化(新), loop=逐expert循环(原)")
|
||||
ap.add_argument("--compile", action="store_true", help="开启 torch.compile")
|
||||
@@ -324,6 +325,8 @@ if __name__ == "__main__":
|
||||
cfg["keep_fp32_modules"] = tuple(x for x in a.keep.split(",") if x)
|
||||
if a.attn is not None:
|
||||
cfg["attn"] = a.attn
|
||||
if a.chunk_users is not None:
|
||||
cfg["chunk_users"] = a.chunk_users
|
||||
if a.moe is not None:
|
||||
cfg["vectorize_moe"] = (a.moe == "dense")
|
||||
if a.emb_fp16:
|
||||
|
||||
+32
-3
@@ -42,8 +42,9 @@ CONFIG = {
|
||||
"filter_test_users": True, # 只处理含测试样本的用户(跳过会被丢弃的用户,省算力)
|
||||
# 实测:varlen 本地快(10.28s)但评测端慢(148s,嵌套张量构造开销随batch数放大)→已退回。
|
||||
# sdpa 是评测端验证最快(89.96s/58.86)。flex/compile/小batch/varlen 在评测端都更差。
|
||||
# attn: "sdpa"(稠密mask,默认/评测最优) / "varlen"(本地快评测慢) / "flex"(慢)
|
||||
# attn: "sdpa"(稠密mask) / "chunked"(按用户分块SDPA,降O(S²)) / "varlen"(评测慢) / "flex"(慢)
|
||||
"attn": "sdpa",
|
||||
"chunk_users": 16, # chunked 模式下每块的用户数(切小拼接序列以降注意力O(S²))
|
||||
# 稠密MoE去掉了 model(batch) 内唯一的同步点(MoE循环的.nonzero())。若评测计时不
|
||||
# synchronize,去掉同步点可能让被计时的 model(batch) 大幅缩短。本地force-sync看不出,
|
||||
# 须靠提交验证。AUC中性、MoE仅占2%算力故风险极低。
|
||||
@@ -408,10 +409,19 @@ def _varlen_attention(q, k, v, user_offsets):
|
||||
|
||||
def scaled_dot_product(q, k, v, extension):
|
||||
"""注意力分发:
|
||||
- varlen_offsets → 嵌套张量变长 flash(每用户独立序列、块对角因果,开销低)。
|
||||
- chunks → 按用户分块的 SDPA(每块块内因果,降 O(S²),无嵌套开销)。
|
||||
- varlen_offsets → 嵌套张量变长 flash(评测端慢,仅对照)。
|
||||
- block_mask → FlexAttention 块对角因果。
|
||||
- mask(默认) → 标准 SDPA 稠密 mask(数学等价、已验证最快)。
|
||||
"""
|
||||
if extension is not None and extension.get("chunks") is not None:
|
||||
outs = []
|
||||
for s0, s1, m in extension["chunks"]:
|
||||
outs.append(F.scaled_dot_product_attention(
|
||||
q[:, :, s0:s1], k[:, :, s0:s1], v[:, :, s0:s1],
|
||||
attn_mask=m, dropout_p=0.0, is_causal=False))
|
||||
return torch.cat(outs, dim=2)
|
||||
|
||||
if extension is not None and extension.get("varlen_offsets") is not None:
|
||||
return _varlen_attention(q, k, v, extension["varlen_offsets"])
|
||||
|
||||
@@ -603,6 +613,23 @@ class CTRModel(nn.Module):
|
||||
out_mask = torch.tril((a == 0).to(torch.int32)).bool()
|
||||
return out_mask
|
||||
|
||||
def build_chunks(self, user_offsets, device):
|
||||
"""把拼接序列按用户边界切成每块 ~chunk_users 个用户,返回 [(s0,s1,mask), ...]。
|
||||
每块块内因果,注意力 O(块内S²) 远小于 O(总S²)。仅 1 次同步(读切分边界)。"""
|
||||
chunk_users = int(CONFIG.get("chunk_users", 16))
|
||||
B = user_offsets.numel() - 1 # 用户数(读 shape,无同步)
|
||||
idx = list(range(0, B + 1, chunk_users))
|
||||
if idx[-1] != B:
|
||||
idx.append(B)
|
||||
bounds = user_offsets[idx].tolist() # 1 次同步:取各块的 token 边界
|
||||
chunks = []
|
||||
for c in range(len(bounds) - 1):
|
||||
s0, s1 = bounds[c], bounds[c + 1]
|
||||
local_off = user_offsets[idx[c]:idx[c + 1] + 1] - s0 # 该块内的用户边界(GPU)
|
||||
m = self.causal_mask_syncfree(local_off, s1 - s0, device).unsqueeze(0).unsqueeze(0)
|
||||
chunks.append((s0, s1, m))
|
||||
return chunks
|
||||
|
||||
def causal_mask_syncfree(self, user_offsets, S, device):
|
||||
"""与 get_sequence_causal_mask 等价,但用 searchsorted 求每个位置的用户号,
|
||||
避免 repeat_interleave(张量repeats) 的隐式同步。"""
|
||||
@@ -628,7 +655,9 @@ class CTRModel(nn.Module):
|
||||
seq_input = self.rep_encoder(batch)
|
||||
user_offsets = batch["user_offsets"]
|
||||
attn = _resolve_attn(seq_input.device)
|
||||
if attn == "varlen":
|
||||
if attn == "chunked":
|
||||
extension = {"chunks": self.build_chunks(user_offsets, seq_input.device)}
|
||||
elif attn == "varlen":
|
||||
extension = {"varlen_offsets": user_offsets}
|
||||
elif attn == "flex":
|
||||
S = seq_input.shape[0] # rep_encoder 输出 [S, D],S=总 token 数
|
||||
|
||||
@@ -64,6 +64,28 @@ def test_moe_dense_matches_loop():
|
||||
print(f"[PASS] MoE 稠密向量化 == 逐expert循环 (max err={err:.2e}, dev={dev})")
|
||||
|
||||
|
||||
def test_chunked_matches_dense_attention():
|
||||
dev = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
rep = infer.RepEncoder(vocab_size=100, emb_dim=8, slot_num=28, d_model=8)
|
||||
seq = infer.TransformerEncoder(d_model=8, n_heads=2, num_layers=1, dim_ff=16)
|
||||
model = infer.CTRModel(rep, seq, d_model=8).to(dev)
|
||||
torch.manual_seed(0)
|
||||
H, Dh = 8, 64
|
||||
offs = _offsets([10, 25, 7, 40, 18, 5, 33], dev) # 7 个用户
|
||||
S = int(offs[-1])
|
||||
q = torch.randn(1, H, S, Dh, device=dev)
|
||||
k = torch.randn(1, H, S, Dh, device=dev)
|
||||
v = torch.randn(1, H, S, Dh, device=dev)
|
||||
with torch.no_grad():
|
||||
dense = infer.scaled_dot_product(q, k, v, {"mask": _dense_causal_mask(offs)[None, None]})
|
||||
infer.CONFIG["chunk_users"] = 3 # 每块 3 个用户
|
||||
chunks = model.build_chunks(offs, torch.device(dev))
|
||||
chunked = infer.scaled_dot_product(q, k, v, {"chunks": chunks})
|
||||
err = (dense - chunked).abs().max().item()
|
||||
assert torch.allclose(dense, chunked, atol=1e-4, rtol=1e-4), f"chunked 不等价 max err={err:.3e}"
|
||||
print(f"[PASS] chunked SDPA == 稠密SDPA (max err={err:.2e}, dev={dev})")
|
||||
|
||||
|
||||
def test_syncfree_mask_matches():
|
||||
dev = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
rep = infer.RepEncoder(vocab_size=100, emb_dim=8, slot_num=28, d_model=8)
|
||||
@@ -148,6 +170,7 @@ if __name__ == "__main__":
|
||||
test_moe_dense_matches_loop()
|
||||
test_fused_embedding_matches_perslot()
|
||||
test_syncfree_mask_matches()
|
||||
test_chunked_matches_dense_attention()
|
||||
test_varlen_matches_dense_attention()
|
||||
test_flex_matches_dense_attention()
|
||||
print("[DONE] 等价测试结束")
|
||||
|
||||
Reference in New Issue
Block a user