feat: 嵌套张量变长 flash 注意力(--attn varlen),统一 CONFIG.attn 分发
每用户当独立序列、is_causal 块对角因果,一个 flash 内核处理一 batch 内所有
用户,无稠密mask/无padding浪费/开销远低于FlexAttention。CONFIG.attn∈
{sdpa(默认),flex,varlen};bench --attn varlen;test_equiv 加 varlen 等价测试。
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
+3
-3
@@ -291,8 +291,8 @@ def _parse_args():
|
||||
help="逗号分隔的 keep_fp32_modules,如 linear,rep_encoder.input_norm")
|
||||
ap.add_argument("--feasign-none", action="store_true",
|
||||
help="不截断特征(max_feasign_per_slot=None)")
|
||||
ap.add_argument("--attn", choices=["auto", "flex", "sdpa"], default=None,
|
||||
help="注意力实现:flex=块对角FlexAttention, sdpa=稠密(原), auto=SM80自动")
|
||||
ap.add_argument("--attn", choices=["sdpa", "flex", "varlen"], default=None,
|
||||
help="注意力:sdpa=稠密(原), flex=FlexAttention, varlen=嵌套张量变长flash")
|
||||
ap.add_argument("--moe", choices=["dense", "loop"], default=None,
|
||||
help="MoE实现:dense=向量化(新), loop=逐expert循环(原)")
|
||||
ap.add_argument("--compile", action="store_true", help="开启 torch.compile")
|
||||
@@ -322,7 +322,7 @@ if __name__ == "__main__":
|
||||
if a.keep is not None:
|
||||
cfg["keep_fp32_modules"] = tuple(x for x in a.keep.split(",") if x)
|
||||
if a.attn is not None:
|
||||
cfg["use_flex_attn"] = {"auto": "auto", "flex": True, "sdpa": False}[a.attn]
|
||||
cfg["attn"] = a.attn
|
||||
if a.moe is not None:
|
||||
cfg["vectorize_moe"] = (a.moe == "dense")
|
||||
if a.compile:
|
||||
|
||||
Reference in New Issue
Block a user