def _forward_sparse(self, x):
    """Sparse MoE forward: run every token through only its top-k experts.

    (token, expert) pairs are grouped by expert, packed into a fixed
    ``[E, C, D]`` buffer (``C`` = per-expert capacity) and pushed through the
    two stacked FFN matmuls with ``torch.baddbmm``. Pairs that exceed an
    expert's capacity are dropped, i.e. they contribute zero to the output.
    Only tensor ops are used (sort / searchsorted / where / indexed scatter /
    index_add_), with the intent of avoiding host synchronisation.

    The capacity factor is read from ``CONFIG["moe_capacity"]`` (default 1.25).

    Args:
        x: input tensor, unpacked below as ``[B, S, D]``.

    Returns:
        ``(out, aux_loss)``: ``out`` has shape ``[B, S, D]``; ``aux_loss`` is a
        zero scalar tensor (no load-balance loss is computed on this path).
    """
    import math

    B, S, D = x.shape
    topk_idx, topk_score, _ = self.gate(x)
    N, k, E = B * S, self.k, self.num_experts
    xf = x.reshape(N, D)
    flat_e = topk_idx.reshape(-1)    # [Nk] expert id of every (token, expert) pair
    flat_s = topk_score.reshape(-1)  # [Nk] gate score of every pair
    Nk = flat_e.numel()
    flat_t = torch.arange(N, device=x.device).repeat_interleave(k)  # [Nk] token id

    # Group pairs by expert. The sort must be stable: ties keep their original
    # (token) order, so which pairs fall beyond the capacity -- and are
    # dropped -- is deterministic instead of depending on the sort backend.
    se, order = torch.sort(flat_e, stable=True)
    st, ss = flat_t[order], flat_s[order]
    xs = xf[st]  # [Nk, D] inputs in expert-grouped order

    # First sorted position of each expert, then each pair's rank within it.
    expert_start = torch.searchsorted(se, torch.arange(E, device=x.device))  # [E]
    pos_within = torch.arange(Nk, device=x.device) - expert_start[se]

    C = int(math.ceil(Nk / E * CONFIG.get("moe_capacity", 1.25)))
    valid = pos_within < C
    slot = se * C + pos_within
    # Over-capacity pairs are redirected to one extra dummy slot (index E*C);
    # without this their slot index would spill into the next expert's rows.
    slot_safe = torch.where(valid, slot, torch.full_like(slot, E * C))

    buf = torch.zeros(E * C + 1, D, dtype=xs.dtype, device=x.device)
    buf[slot_safe] = xs  # scatter; the dummy row is never fed to the experts

    h = torch.baddbmm(self.b1.unsqueeze(1), buf[:E * C].view(E, C, D), self.W1t)  # [E, C, F]
    h = F.relu(h)
    o = torch.baddbmm(self.b2.unsqueeze(1), h, self.W2t)  # [E, C, D]

    # Append a zero row so gathering through the dummy slot yields zeros.
    o_full = torch.cat([o.reshape(E * C, D),
                        torch.zeros(1, D, dtype=o.dtype, device=x.device)])  # [E*C+1, D]
    out_s = o_full[slot_safe] * ss.unsqueeze(-1)  # [Nk, D], dropped pairs -> 0

    # Sum the k weighted expert outputs back onto their source tokens.
    out = torch.zeros(N, D, dtype=x.dtype, device=x.device).index_add_(0, st, out_s)
    return out.view(B, S, D), out.new_zeros(())
def test_sparse_moe_matches_dense():
    """Sparse MoE must match the dense path when no token can be dropped.

    With ``num_experts=8`` and a capacity factor of 8.0 the per-expert
    capacity equals the total number of (token, expert) pairs, so nothing is
    dropped and both paths should agree up to floating-point tolerance.
    """
    torch.manual_seed(0)
    dev = "cuda" if torch.cuda.is_available() else "cpu"
    m = infer.SMoE(d_model=512, dim_ff=1024, num_experts=8, k=2).to(dev).eval()
    x = torch.randn(1, 200, 512, device=dev)

    # Snapshot every global switch this test touches so it can be restored
    # exactly, even if a forward pass raises; otherwise a failure here would
    # leave the sparse path enabled for all following tests.
    missing = object()
    touched = ("vectorize_moe", "moe_sparse", "moe_capacity")
    saved = {key: infer.CONFIG.get(key, missing) for key in touched}
    try:
        with torch.no_grad():
            # The sparse dispatch in SMoE.forward sits behind the vectorized
            # path; pin it so the comparison really exercises sparse vs dense.
            infer.CONFIG["vectorize_moe"] = True
            infer.CONFIG["moe_sparse"] = False
            ref, _ = m(x)
            infer.CONFIG["moe_sparse"] = True
            infer.CONFIG["moe_capacity"] = 8.0  # large enough that no token is dropped
            new, _ = m(x)
    finally:
        for key, value in saved.items():
            if value is missing:
                infer.CONFIG.pop(key, None)
            else:
                infer.CONFIG[key] = value

    err = (ref - new).abs().max().item()
    assert torch.allclose(ref, new, atol=1e-3, rtol=1e-3), f"sparse MoE 不等价 max err={err:.3e}"
    print(f"[PASS] sparse MoE(大capacity) == dense (max err={err:.2e}, dev={dev})")