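feat: Triton varlen causal flash attention (block-diagonal, single kernel; removes per-chunk calls and mask-construction overhead)

Each program handles one (query block of a user segment, head) pair and
iterates only over keys in that segment up to and including the block
(causal), using online softmax with fp16 reads/writes and fp32 accumulation.
CONFIG.attn=triton enables it (the default remains chunked);
_triton_block_meta computes the block→segment mapping once per batch and
reuses it across all 8 layers; _resolve_attn falls back to chunked when
triton is unavailable or on CPU. Adds an equivalence test and
bench --attn triton.
Mathematically equivalent (same family as FlashAttention, permitted by the
rules); the network architecture is unchanged.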
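
The scheme described above can be sketched as a dependency-free reference in plain Python. This is an illustration under assumptions, not the kernel from this commit: `block_meta` and `varlen_causal_attention` are hypothetical names, it handles a single head, it loops over rows instead of operating on tiles, and it does not model the fp16 I/O with fp32 accumulation.

```python
import math

def block_meta(seg_lens, block):
    """Hypothetical helper: map each query block to (segment start, block start, block end)."""
    meta, start = [], 0
    for n in seg_lens:
        for q0 in range(start, start + n, block):
            meta.append((start, q0, min(q0 + block, start + n)))
        start += n
    return meta

def varlen_causal_attention(q, k, v, seg_lens, block=2):
    """Single-head reference; q, k, v are lists of d-dim rows packed segment after segment."""
    d = len(q[0])
    scale = 1.0 / math.sqrt(d)
    out = [None] * len(q)
    # One iteration of the outer loop plays the role of one kernel program.
    for seg_start, q0, q1 in block_meta(seg_lens, block):
        for i in range(q0, q1):
            m, l, acc = -math.inf, 0.0, [0.0] * d  # running max, normaliser, weighted sum
            # Visit key blocks of the same segment only, stopping at the causal limit i.
            for k0 in range(seg_start, i + 1, block):
                k1 = min(k0 + block, i + 1)
                scores = [scale * sum(a * b for a, b in zip(q[i], k[j]))
                          for j in range(k0, k1)]
                m_new = max(m, max(scores))
                alpha = math.exp(m - m_new)  # rescales old statistics; 0.0 on the first block
                p = [math.exp(s - m_new) for s in scores]
                l = l * alpha + sum(p)
                acc = [a * alpha + sum(pj * v[j][c] for pj, j in zip(p, range(k0, k1)))
                       for c, a in enumerate(acc)]
                m = m_new
            out[i] = [a / l for a in acc]
    return out
```

Because the mapping returned by `block_meta` depends only on the segment lengths, it can be computed once per batch and shared by every layer, which is the reuse the message attributes to `_triton_block_meta`.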

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
OwnerSunshine530
2026-06-17 00:14:53 +08:00
parent a5ee660523
commit cdc2dd490b
3 changed files with 131 additions and 5 deletions
+2 -2
@@ -315,8 +315,8 @@ def _parse_args():
                     help="comma-separated keep_fp32_modules, e.g. linear,rep_encoder.input_norm")
     ap.add_argument("--feasign-none", action="store_true",
                     help="do not truncate features (max_feasign_per_slot=None)")
-    ap.add_argument("--attn", choices=["sdpa", "chunked", "flex", "varlen"], default=None,
-                    help="attention: sdpa=dense, chunked=per-user chunked SDPA, flex/varlen=for comparison")
+    ap.add_argument("--attn", choices=["sdpa", "chunked", "triton", "flex", "varlen"], default=None,
+                    help="attention: sdpa=dense, chunked=chunked SDPA, triton=varlen flash kernel, flex/varlen=for comparison")
     ap.add_argument("--chunk-users", type=int, default=None, help="users per chunk in chunked mode")
     ap.add_argument("--moe", choices=["dense", "loop"], default=None,
                     help="MoE implementation: dense=vectorized (new), loop=per-expert loop (original)")
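
How a requested `--attn` value might be resolved into the backend that actually runs can be sketched as follows. This is a guess at the shape of the fallback, not the code of `_resolve_attn`: the function name `resolve_attn`, its parameters, and the rule that `None` means "use the configured default" are all assumptions made for illustration.

```python
import importlib.util

def resolve_attn(requested, device_type, triton_available=None, config_default="chunked"):
    """Hypothetical stand-in for _resolve_attn: pick the attention backend to run."""
    # Assumption: an unset CLI flag defers to the configured default.
    choice = requested if requested is not None else config_default
    if triton_available is None:
        # Probe for the package without importing it.
        triton_available = importlib.util.find_spec("triton") is not None
    # The Triton kernel needs both the package and a GPU; otherwise use chunked SDPA.
    if choice == "triton" and (device_type != "cuda" or not triton_available):
        return "chunked"
    return choice
```

For example, `resolve_attn("triton", "cpu", triton_available=True)` returns `"chunked"`, while other choices such as `"sdpa"` pass through unchanged.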