feat: 分块SDPA注意力(--attn chunked),按用户边界切块降O(S²)
每块约 chunk_users 个用户,块内做因果 SDPA(评测端已验证,无嵌套张量开销),使 sum(块S²) 远小于总 S²。仅需 1 次同步来读取切分边界。此前本地 bs=16 时快 13% 的收益被 MoE 同步吃掉;现 MoE 同步已消除,切块红利应能完全体现。新增 CONFIG.attn=chunked 与 CONFIG.chunk_users;等价性测试已加。

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
+5
-2
@@ -291,8 +291,9 @@ def _parse_args():
|
||||
help="逗号分隔的 keep_fp32_modules,如 linear,rep_encoder.input_norm")
|
||||
ap.add_argument("--feasign-none", action="store_true",
|
||||
help="不截断特征(max_feasign_per_slot=None)")
|
||||
-ap.add_argument("--attn", choices=["sdpa", "flex", "varlen"], default=None,
|
||||
-help="注意力:sdpa=稠密(原), flex=FlexAttention, varlen=嵌套张量变长flash")
|
||||
+ap.add_argument("--attn", choices=["sdpa", "chunked", "flex", "varlen"], default=None,
|
||||
+help="注意力:sdpa=稠密, chunked=按用户分块SDPA, flex/varlen=对照")
|
||||
+ap.add_argument("--chunk-users", type=int, default=None, help="chunked 模式每块用户数")
|
||||
ap.add_argument("--moe", choices=["dense", "loop"], default=None,
|
||||
help="MoE实现:dense=向量化(新), loop=逐expert循环(原)")
|
||||
ap.add_argument("--compile", action="store_true", help="开启 torch.compile")
|
||||
@@ -324,6 +325,8 @@ if __name__ == "__main__":
|
||||
cfg["keep_fp32_modules"] = tuple(x for x in a.keep.split(",") if x)
|
||||
if a.attn is not None:
|
||||
cfg["attn"] = a.attn
|
||||
+if a.chunk_users is not None:
|
||||
+cfg["chunk_users"] = a.chunk_users
|
||||
if a.moe is not None:
|
||||
cfg["vectorize_moe"] = (a.moe == "dense")
|
||||
if a.emb_fp16:
|
||||
|
||||
Reference in New Issue
Block a user