feat: 分块SDPA注意力(--attn chunked),按用户边界切块降O(S²)

每块~chunk_users个用户、块内因果SDPA(评测端已验证、无嵌套开销),sum(块S²)
远小于总S²。仅1次同步读切分边界。之前本地bs=16快13%被MoE同步吃掉,现MoE
同步已消除,切块红利应全露出。CONFIG.attn=chunked/chunk_users;等价测试已加。

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
OwnerSunshine530
2026-06-15 13:13:13 +08:00
parent 1249bbdbbc
commit 3d28f61a98
3 changed files with 60 additions and 5 deletions
+5 -2
View File
@@ -291,8 +291,9 @@ def _parse_args():
help="逗号分隔的 keep_fp32_modules,如 linear,rep_encoder.input_norm")
ap.add_argument("--feasign-none", action="store_true",
help="不截断特征(max_feasign_per_slot=None")
ap.add_argument("--attn", choices=["sdpa", "flex", "varlen"], default=None,
help="注意力:sdpa=稠密(原), flex=FlexAttention, varlen=嵌套张量变长flash")
ap.add_argument("--attn", choices=["sdpa", "chunked", "flex", "varlen"], default=None,
help="注意力:sdpa=稠密, chunked=按用户分块SDPA, flex/varlen=对照")
ap.add_argument("--chunk-users", type=int, default=None, help="chunked 模式每块用户数")
ap.add_argument("--moe", choices=["dense", "loop"], default=None,
help="MoE实现:dense=向量化(新), loop=逐expert循环(原)")
ap.add_argument("--compile", action="store_true", help="开启 torch.compile")
@@ -324,6 +325,8 @@ if __name__ == "__main__":
cfg["keep_fp32_modules"] = tuple(x for x in a.keep.split(",") if x)
if a.attn is not None:
cfg["attn"] = a.attn
if a.chunk_users is not None:
cfg["chunk_users"] = a.chunk_users
if a.moe is not None:
cfg["vectorize_moe"] = (a.moe == "dense")
if a.emb_fp16: