diff --git a/代码/code/infer.py b/代码/code/infer.py index 18740ec..ae18e08 100644 --- a/代码/code/infer.py +++ b/代码/code/infer.py @@ -115,16 +115,14 @@ def _triton_varlen_attn(q, k, v, meta): cu, blk_seq, blk_inseq, total_blocks = meta BLOCK_M = CONFIG.get("triton_block_m", 64) # contiguous 后连续访存更快(实测去 contiguous 用 stride 读反而慢:非连续跨步读 > 一次性 clone)。 + # contiguous 输出(实测:为消调用方 clone 改跨步写,评测反而更慢 35.85>34.64,已退回) + out = torch.empty((1, H, S, Dh), device=q.device, dtype=torch.float16) qc = q.contiguous(); kc = k.contiguous(); vc = v.contiguous() sh, ss, sd = S * Dh, Dh, 1 - # out 存储为 [1,S,H,Dh],以 [1,H,S,Dh] 视图返回 → 调用方 permute(0,2,1,3) 得连续、reshape 免 clone - out_storage = torch.empty((1, S, H, Dh), device=q.device, dtype=torch.float16) - out = out_storage.permute(0, 2, 1, 3) # [1,H,S,Dh] 视图 - soh, sos, sod = out.stride(1), out.stride(2), out.stride(3) # = Dh, H*Dh, 1 grid = (total_blocks, H) _varlen_flash_fwd[grid]( qc, kc, vc, out, cu, blk_seq, blk_inseq, - sh, ss, sd, soh, sos, sod, 1.0 / math.sqrt(Dh), cu.numel() - 1, + sh, ss, sd, sh, ss, sd, 1.0 / math.sqrt(Dh), cu.numel() - 1, BLOCK_M=BLOCK_M, BLOCK_N=64, D=Dh, ) return out @@ -154,7 +152,8 @@ CONFIG = { # 须靠提交验证。AUC中性、MoE仅占2%算力故风险极低。 "vectorize_moe": True, # True=稠密向量化MoE(无同步点);False=原逐expert循环(.nonzero同步) "moe_baddbmm": True, # MoE FFN 用 baddbmm(cutlass GEMM+bias epilogue融合),省 bias add kernel - "moe_fused_weight": True, # top-k加权用scatter+mul+sum(在[E,N,D]上),省permute大clone+gather;数学等价 + # 评测净负:scatter+mul+sum 物化[E,N,D]大中间张量(访存)>省的clone。退回 gather 路径。 + "moe_fused_weight": False, # True=top-k加权用scatter+mul+sum(评测慢,勿开) # 真稀疏MoE实测评测净负:lat 34.64->37.64s(本地快15%但argsort/scatter开销评测放大,如varlen) # +容量丢弃降AUC(0.7525->0.7507)。已退回 dense。 "moe_sparse": False, # True=真稀疏MoE(评测净负,勿开)