# Review notes for this patch. Everything above the "diff --git" header is
# free-form commentary and is not part of the patch payload.
#
# File touched: 代码/code/infer.py (two hunks).
#
# Hunk 1 (@@ -107,15 +107,14 @@), inside _triton_varlen_attn(q, k, v, meta):
#   * Removes the two comment lines that said contiguity was deliberately NOT
#     forced and that the kernel read the non-contiguous q/k/v through their
#     real strides.
#   * Adds a comment stating that contiguous inputs were measured to be
#     faster than strided reads of the non-contiguous tensors.
#   * Drops the stride triples read from q.stride(1..3) and out.stride(1..3).
#     Instead it makes qc/kc/vc via .contiguous() and passes one hard-coded
#     triple, sh, ss, sd = S * Dh, Dh, 1, to _varlen_flash_fwd twice: once in
#     the slots that previously held the q strides and once in the slots that
#     previously held the out strides.
#   * NOTE(review): H, S, Dh are unpacked from q.shape only. The hard-coded
#     triple is correct for kc/vc only if k and v have the same trailing
#     (H, S, Dh) shape as q -- their shapes are not visible here; confirm.
#   * NOTE(review): out is still allocated as torch.float16 whatever q.dtype
#     is; this patch does not change that -- confirm it is intended.
#
# Hunk 2 (@@ -147,7 +146,7 @@), inside CONFIG:
#   * Flips "use_embedding_bag" from False to True and rewrites its trailing
#     comment (new text: F.embedding_bag fuses lookup + pooling in a single
#     kernel, removes the unique() sync of dedup, AUC roughly unchanged).
#   * NOTE(review): "dedup_embedding" stays True although the new comment
#     says embedding_bag removes the dedup sync. How the two flags interact
#     is not visible in this patch -- confirm against the consuming code.
#
# NOTE(review): the patch below is stored on one physical line. The line
# breaks and indentation of the hunk bodies appear to have been lost, so it
# presumably will not apply as-is -- confirm against the original patch.
diff --git a/代码/code/infer.py b/代码/code/infer.py index 8c4b1d2..4dce083 100644 --- a/代码/code/infer.py +++ b/代码/code/infer.py @@ -107,15 +107,14 @@ def _triton_varlen_attn(q, k, v, meta): _, H, S, Dh = q.shape cu, blk_seq, blk_inseq, total_blocks = meta BLOCK_M = CONFIG.get("triton_block_m", 64) - # 不强制 contiguous:kernel 用实际 stride 读非连续的 q/k/v(来自 qkv split+permute)。 - # q,k,v split 同源、stride 相同(k,v 含各自 storage_offset,Triton 用其 data_ptr 自动处理)。 + # contiguous 后连续访存更快(实测去 contiguous 用 stride 读反而慢:非连续跨步读 > 一次性 clone)。 out = torch.empty((1, H, S, Dh), device=q.device, dtype=torch.float16) - sqh, sqs, sqd = q.stride(1), q.stride(2), q.stride(3) - soh, sos, sod = out.stride(1), out.stride(2), out.stride(3) + qc = q.contiguous(); kc = k.contiguous(); vc = v.contiguous() + sh, ss, sd = S * Dh, Dh, 1 grid = (total_blocks, H) _varlen_flash_fwd[grid]( - q, k, v, out, cu, blk_seq, blk_inseq, - sqh, sqs, sqd, soh, sos, sod, 1.0 / math.sqrt(Dh), cu.numel() - 1, + qc, kc, vc, out, cu, blk_seq, blk_inseq, + sh, ss, sd, sh, ss, sd, 1.0 / math.sqrt(Dh), cu.numel() - 1, BLOCK_M=BLOCK_M, BLOCK_N=64, D=Dh, ) return out @@ -147,7 +146,7 @@ CONFIG = { "fuse_embedding": True, # True=28个slot的查表+池化融合为1次(减per-batch kernel启动) "syncfree_mask": True, # True=用searchsorted构造因果mask(无同步);False=repeat_interleave(同步) "emb_fp16": True, # True=Embedding表转FP16(查表带宽减半,实测AUC 0.75932≈无损) - "use_embedding_bag": False, # True=用 F.embedding_bag 融合查表+池化(单kernel,免[M,512]中间),攻最大块 + "use_embedding_bag": True, # F.embedding_bag 融合查表+池化(单kernel,消dedup的unique同步,AUC≈无损) "dedup_embedding": True, # True=查表前对sign去重(只查唯一值再展开),本地7.80->6.49s,AUC逐位等价 "sparse_pool": False, # True=用(段×唯一)稀疏矩阵乘做池化,避免materialize整个[M,512](段内高重复时省) "compile": False, # 是否 torch.compile(实测慢5×,勿开)