From 69d49cd282d5607ef0872a9229d5743fa51f2a17 Mon Sep 17 00:00:00 2001 From: OwnerSunshine530 Date: Fri, 19 Jun 2026 20:56:27 +0800 Subject: [PATCH] =?UTF-8?q?revert:=20MoE=E5=8A=A0=E6=9D=83+attention?= =?UTF-8?q?=E8=BE=93=E5=87=BA=E5=B8=83=E5=B1=80=E4=B8=A4=E5=88=80(?= =?UTF-8?q?=E8=AF=84=E6=B5=8B=E5=87=80=E8=B4=9F35.85>34.64,=E5=A4=A7?= =?UTF-8?q?=E4=B8=AD=E9=97=B4=E5=BC=A0=E9=87=8F/=E8=B7=A8=E6=AD=A5?= =?UTF-8?q?=E5=86=99=E4=BB=A3=E4=BB=B7>=E7=9C=81=E7=9A=84clone)=E3=80=82?= =?UTF-8?q?=E4=BF=9D=E7=95=99=E6=B6=88=E5=90=8C=E6=AD=A5=E5=88=80=E5=8D=95?= =?UTF-8?q?=E7=8B=AC=E6=B5=8B?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-Authored-By: Claude Opus 4.8 --- 代码/code/infer.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/代码/code/infer.py b/代码/code/infer.py index 18740ec..ae18e08 100644 --- a/代码/code/infer.py +++ b/代码/code/infer.py @@ -115,16 +115,14 @@ def _triton_varlen_attn(q, k, v, meta): cu, blk_seq, blk_inseq, total_blocks = meta BLOCK_M = CONFIG.get("triton_block_m", 64) # contiguous 后连续访存更快(实测去 contiguous 用 stride 读反而慢:非连续跨步读 > 一次性 clone)。 + # contiguous 输出(实测:为消调用方 clone 改跨步写,评测反而更慢 35.85>34.64,已退回) + out = torch.empty((1, H, S, Dh), device=q.device, dtype=torch.float16) qc = q.contiguous(); kc = k.contiguous(); vc = v.contiguous() sh, ss, sd = S * Dh, Dh, 1 - # out 存储为 [1,S,H,Dh],以 [1,H,S,Dh] 视图返回 → 调用方 permute(0,2,1,3) 得连续、reshape 免 clone - out_storage = torch.empty((1, S, H, Dh), device=q.device, dtype=torch.float16) - out = out_storage.permute(0, 2, 1, 3) # [1,H,S,Dh] 视图 - soh, sos, sod = out.stride(1), out.stride(2), out.stride(3) # = Dh, H*Dh, 1 grid = (total_blocks, H) _varlen_flash_fwd[grid]( qc, kc, vc, out, cu, blk_seq, blk_inseq, - sh, ss, sd, soh, sos, sod, 1.0 / math.sqrt(Dh), cu.numel() - 1, + sh, ss, sd, sh, ss, sd, 1.0 / math.sqrt(Dh), cu.numel() - 1, BLOCK_M=BLOCK_M, BLOCK_N=64, D=Dh, ) return out @@ -154,7 +152,8 @@ CONFIG = { # 须靠提交验证。AUC中性、MoE仅占2%算力故风险极低。 "vectorize_moe": True, # True=稠密向量化MoE(无同步点);False=原逐expert循环(.nonzero同步) "moe_baddbmm": True, # MoE FFN 用 baddbmm(cutlass GEMM+bias epilogue融合),省 bias add kernel - "moe_fused_weight": True, # top-k加权用scatter+mul+sum(在[E,N,D]上),省permute大clone+gather;数学等价 + # 评测净负:scatter+mul+sum 物化[E,N,D]大中间张量(访存)>省的clone。退回 gather 路径。 + "moe_fused_weight": False, # True=top-k加权用scatter+mul+sum(评测慢,勿开) # 真稀疏MoE实测评测净负:lat 34.64->37.64s(本地快15%但argsort/scatter开销评测放大,如varlen) # +容量丢弃降AUC(0.7525->0.7507)。已退回 dense。 "moe_sparse": False, # True=真稀疏MoE(评测净负,勿开)