revert+feat: triton退回contiguous(去contiguous非连续读更慢) + embedding_bag默认开(消unique同步)
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
+6
-7
@@ -107,15 +107,14 @@ def _triton_varlen_attn(q, k, v, meta):
|
|||||||
_, H, S, Dh = q.shape
|
_, H, S, Dh = q.shape
|
||||||
cu, blk_seq, blk_inseq, total_blocks = meta
|
cu, blk_seq, blk_inseq, total_blocks = meta
|
||||||
BLOCK_M = CONFIG.get("triton_block_m", 64)
|
BLOCK_M = CONFIG.get("triton_block_m", 64)
|
||||||
# 不强制 contiguous:kernel 用实际 stride 读非连续的 q/k/v(来自 qkv split+permute)。
|
# contiguous 后连续访存更快(实测去 contiguous 用 stride 读反而慢:非连续跨步读 > 一次性 clone)。
|
||||||
# q,k,v split 同源、stride 相同(k,v 含各自 storage_offset,Triton 用其 data_ptr 自动处理)。
|
|
||||||
out = torch.empty((1, H, S, Dh), device=q.device, dtype=torch.float16)
|
out = torch.empty((1, H, S, Dh), device=q.device, dtype=torch.float16)
|
||||||
sqh, sqs, sqd = q.stride(1), q.stride(2), q.stride(3)
|
qc = q.contiguous(); kc = k.contiguous(); vc = v.contiguous()
|
||||||
soh, sos, sod = out.stride(1), out.stride(2), out.stride(3)
|
sh, ss, sd = S * Dh, Dh, 1
|
||||||
grid = (total_blocks, H)
|
grid = (total_blocks, H)
|
||||||
_varlen_flash_fwd[grid](
|
_varlen_flash_fwd[grid](
|
||||||
q, k, v, out, cu, blk_seq, blk_inseq,
|
qc, kc, vc, out, cu, blk_seq, blk_inseq,
|
||||||
sqh, sqs, sqd, soh, sos, sod, 1.0 / math.sqrt(Dh), cu.numel() - 1,
|
sh, ss, sd, sh, ss, sd, 1.0 / math.sqrt(Dh), cu.numel() - 1,
|
||||||
BLOCK_M=BLOCK_M, BLOCK_N=64, D=Dh,
|
BLOCK_M=BLOCK_M, BLOCK_N=64, D=Dh,
|
||||||
)
|
)
|
||||||
return out
|
return out
|
||||||
@@ -147,7 +146,7 @@ CONFIG = {
|
|||||||
"fuse_embedding": True, # True=28个slot的查表+池化融合为1次(减per-batch kernel启动)
|
"fuse_embedding": True, # True=28个slot的查表+池化融合为1次(减per-batch kernel启动)
|
||||||
"syncfree_mask": True, # True=用searchsorted构造因果mask(无同步);False=repeat_interleave(同步)
|
"syncfree_mask": True, # True=用searchsorted构造因果mask(无同步);False=repeat_interleave(同步)
|
||||||
"emb_fp16": True, # True=Embedding表转FP16(查表带宽减半,实测AUC 0.75932≈无损)
|
"emb_fp16": True, # True=Embedding表转FP16(查表带宽减半,实测AUC 0.75932≈无损)
|
||||||
"use_embedding_bag": False, # True=用 F.embedding_bag 融合查表+池化(单kernel,免[M,512]中间),攻最大块
|
"use_embedding_bag": True, # F.embedding_bag 融合查表+池化(单kernel,消dedup的unique同步,AUC≈无损)
|
||||||
"dedup_embedding": True, # True=查表前对sign去重(只查唯一值再展开),本地7.80->6.49s,AUC逐位等价
|
"dedup_embedding": True, # True=查表前对sign去重(只查唯一值再展开),本地7.80->6.49s,AUC逐位等价
|
||||||
"sparse_pool": False, # True=用(段×唯一)稀疏矩阵乘做池化,避免materialize整个[M,512](段内高重复时省)
|
"sparse_pool": False, # True=用(段×唯一)稀疏矩阵乘做池化,避免materialize整个[M,512](段内高重复时省)
|
||||||
"compile": False, # 是否 torch.compile(实测慢5×,勿开)
|
"compile": False, # 是否 torch.compile(实测慢5×,勿开)
|
||||||
|
|||||||
Reference in New Issue
Block a user