feat: Triton varlen因果flash attention(块对角,单kernel,消逐块调用+mask构造开销)
每program处理(用户段query块,head),只遍历段内<=该块的key(因果),在线softmax, fp16读写fp32累加。CONFIG.attn=triton(默认仍chunked);_triton_block_meta每batch算一次 block→段映射8层复用;_resolve_attn在无triton/CPU时回退chunked。等价测试+bench --attn triton。 数学等价(FlashAttention同类,规则允许),不改组网。 Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
+2
-2
@@ -315,8 +315,8 @@ def _parse_args():
|
||||
help="逗号分隔的 keep_fp32_modules,如 linear,rep_encoder.input_norm")
|
||||
ap.add_argument("--feasign-none", action="store_true",
|
||||
help="不截断特征(max_feasign_per_slot=None)")
|
||||
ap.add_argument("--attn", choices=["sdpa", "chunked", "flex", "varlen"], default=None,
|
||||
help="注意力:sdpa=稠密, chunked=按用户分块SDPA, flex/varlen=对照")
|
||||
ap.add_argument("--attn", choices=["sdpa", "chunked", "triton", "flex", "varlen"], default=None,
|
||||
help="注意力:sdpa=稠密, chunked=分块SDPA, triton=varlen flash kernel, flex/varlen=对照")
|
||||
ap.add_argument("--chunk-users", type=int, default=None, help="chunked 模式每块用户数")
|
||||
ap.add_argument("--moe", choices=["dense", "loop"], default=None,
|
||||
help="MoE实现:dense=向量化(新), loop=逐expert循环(原)")
|
||||
|
||||
Reference in New Issue
Block a user