revert+feat: triton退回contiguous(去contiguous非连续读更慢) + embedding_bag默认开(消unique同步)

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
OwnerSunshine530
2026-06-17 13:54:31 +08:00
parent 6114c78354
commit 6bb51a1057
+6 -7
View File
@@ -107,15 +107,14 @@ def _triton_varlen_attn(q, k, v, meta):
_, H, S, Dh = q.shape _, H, S, Dh = q.shape
cu, blk_seq, blk_inseq, total_blocks = meta cu, blk_seq, blk_inseq, total_blocks = meta
BLOCK_M = CONFIG.get("triton_block_m", 64) BLOCK_M = CONFIG.get("triton_block_m", 64)
# 不强制 contiguouskernel 用实际 stride 读非连续的 q/k/v(来自 qkv split+permute)。 # contiguous 后连续访存更快(实测去 contiguous 用 stride 读反而慢:非连续跨步读 > 一次性 clone)。
# q,k,v split 同源、stride 相同(k,v 含各自 storage_offsetTriton 用其 data_ptr 自动处理)。
out = torch.empty((1, H, S, Dh), device=q.device, dtype=torch.float16) out = torch.empty((1, H, S, Dh), device=q.device, dtype=torch.float16)
sqh, sqs, sqd = q.stride(1), q.stride(2), q.stride(3) qc = q.contiguous(); kc = k.contiguous(); vc = v.contiguous()
soh, sos, sod = out.stride(1), out.stride(2), out.stride(3) sh, ss, sd = S * Dh, Dh, 1
grid = (total_blocks, H) grid = (total_blocks, H)
_varlen_flash_fwd[grid]( _varlen_flash_fwd[grid](
q, k, v, out, cu, blk_seq, blk_inseq, qc, kc, vc, out, cu, blk_seq, blk_inseq,
sqh, sqs, sqd, soh, sos, sod, 1.0 / math.sqrt(Dh), cu.numel() - 1, sh, ss, sd, sh, ss, sd, 1.0 / math.sqrt(Dh), cu.numel() - 1,
BLOCK_M=BLOCK_M, BLOCK_N=64, D=Dh, BLOCK_M=BLOCK_M, BLOCK_N=64, D=Dh,
) )
return out return out
@@ -147,7 +146,7 @@ CONFIG = {
"fuse_embedding": True, # True=28个slot的查表+池化融合为1次(减per-batch kernel启动) "fuse_embedding": True, # True=28个slot的查表+池化融合为1次(减per-batch kernel启动)
"syncfree_mask": True, # True=用searchsorted构造因果mask(无同步)False=repeat_interleave(同步) "syncfree_mask": True, # True=用searchsorted构造因果mask(无同步)False=repeat_interleave(同步)
"emb_fp16": True, # True=Embedding表转FP16(查表带宽减半,实测AUC 0.75932≈无损) "emb_fp16": True, # True=Embedding表转FP16(查表带宽减半,实测AUC 0.75932≈无损)
"use_embedding_bag": False, # True=用 F.embedding_bag 融合查表+池化(单kernel,免[M,512]中间),攻最大块 "use_embedding_bag": True, # F.embedding_bag 融合查表+池化(单kernel,消dedup的unique同步,AUC≈无损)
"dedup_embedding": True, # True=查表前对sign去重(只查唯一值再展开),本地7.80->6.49s,AUC逐位等价 "dedup_embedding": True, # True=查表前对sign去重(只查唯一值再展开),本地7.80->6.49s,AUC逐位等价
"sparse_pool": False, # True=用(段×唯一)稀疏矩阵乘做池化,避免materialize整个[M,512](段内高重复时省) "sparse_pool": False, # True=用(段×唯一)稀疏矩阵乘做池化,避免materialize整个[M,512](段内高重复时省)
"compile": False, # 是否 torch.compile(实测慢5×,勿开) "compile": False, # 是否 torch.compile(实测慢5×,勿开)