feat: 真稀疏MoE(capacity分组,只算top-k,cutlass baddbmm,无host同步)
按expert排序token+固定capacity分桶,每桶dense baddbmm,减GEMM~3x。argsort/where/ scatter/index_add无.item()/bincount同步(不同于loop MoE)。超容量token丢弃(capacity_factor控)。 等价测试(大capacity无丢弃==dense)。bench --moe-sparse/--moe-cap。默认关待验证。 Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
@@ -347,6 +347,8 @@ def _parse_args():
|
||||
ap.add_argument("--no-moe-baddbmm", action="store_true", help="关闭 MoE baddbmm(用 einsum 对照)")
|
||||
ap.add_argument("--no-skip-moe-loss", action="store_true", help="不跳过 moe_loss(对照)")
|
||||
ap.add_argument("--logit-bias", type=float, default=None, help="PCOC校准:logit偏移(本地验证PCOC→1.0)")
|
||||
ap.add_argument("--moe-sparse", action="store_true", help="真稀疏MoE(只算top-k,capacity分组)")
|
||||
ap.add_argument("--moe-cap", type=float, default=None, help="MoE capacity factor")
|
||||
ap.add_argument("--sparse-pool", action="store_true", help="稀疏矩阵乘做池化(段内高重复时省)")
|
||||
ap.add_argument("--precompute-rep", action="store_true",
|
||||
help="预计算RepEncoder缓存,model(batch)跳过embedding层(从batches自建)")
|
||||
@@ -401,6 +403,10 @@ if __name__ == "__main__":
|
||||
cfg["skip_moe_loss"] = False
|
||||
if a.logit_bias is not None:
|
||||
cfg["logit_bias"] = a.logit_bias
|
||||
if a.moe_sparse:
|
||||
cfg["moe_sparse"] = True
|
||||
if a.moe_cap is not None:
|
||||
cfg["moe_capacity"] = a.moe_cap
|
||||
if a.sparse_pool:
|
||||
cfg["sparse_pool"] = True
|
||||
if a.precompute_rep:
|
||||
|
||||
@@ -144,6 +144,8 @@ CONFIG = {
|
||||
# 须靠提交验证。AUC中性、MoE仅占2%算力故风险极低。
|
||||
"vectorize_moe": True, # True=稠密向量化MoE(无同步点);False=原逐expert循环(.nonzero同步)
|
||||
"moe_baddbmm": True, # MoE FFN 用 baddbmm(cutlass GEMM+bias epilogue融合),省 bias add kernel
|
||||
"moe_sparse": False, # True=真稀疏MoE(只算top-k,capacity分组),减GEMM~3x;风险:开销/容量丢弃AUC
|
||||
"moe_capacity": 1.25, # 每expert容量 = ceil(Nk/E*factor);越大越不丢token但计算越多
|
||||
"skip_moe_loss": True, # 推理跳过 moe_loss(load-balance,推理无用),省 importance/std/mean kernel
|
||||
# PCOC 校准:本地拟合-0.1067(本地PCOC1.109),但评测PCOC稳定1.059,按斜率换算评测最优≈-0.059。
|
||||
"logit_bias": -0.06, # logit 加常数偏移使评测 PCOC→~1.0(单调,AUC不变,免费+~0.33分)
|
||||
@@ -693,6 +695,39 @@ class SMoE(nn.Module):
|
||||
self.register_buffer("W2t", self.W2.transpose(1, 2).contiguous()) # [E,F,D]
|
||||
self._stacked = True
|
||||
|
||||
def _forward_sparse(self, x):
|
||||
"""真稀疏 MoE:每 token 只算 top-k expert(按 expert 排序 + capacity 分桶 + cutlass baddbmm)。
|
||||
全程无 host 同步(argsort/where/scatter/index_add)。超容量 token 被丢弃(capacity_factor 控)。"""
|
||||
import math
|
||||
B, S, D = x.shape
|
||||
topk_idx, topk_score, _ = self.gate(x)
|
||||
N, k, E = B * S, self.k, self.num_experts
|
||||
xf = x.reshape(N, D)
|
||||
flat_e = topk_idx.reshape(-1) # [Nk] 每 pair 的 expert
|
||||
flat_s = topk_score.reshape(-1) # [Nk]
|
||||
Nk = flat_e.numel()
|
||||
flat_t = torch.arange(N, device=x.device).repeat_interleave(k) # [Nk] token id
|
||||
order = torch.argsort(flat_e) # 按 expert 排序(GPU sort,无 host 同步)
|
||||
se, st, ss = flat_e[order], flat_t[order], flat_s[order]
|
||||
xs = xf[st] # [Nk, D]
|
||||
expert_start = torch.searchsorted(se.contiguous(),
|
||||
torch.arange(E, device=x.device)) # [E]
|
||||
pos_within = torch.arange(Nk, device=x.device) - expert_start[se] # 每 token 在其 expert 内位置
|
||||
C = int(math.ceil(Nk / E * CONFIG.get("moe_capacity", 1.25)))
|
||||
valid = pos_within < C
|
||||
slot = se * C + pos_within
|
||||
slot_safe = torch.where(valid, slot, torch.full_like(slot, E * C)) # 超容量→dummy 槽
|
||||
buf = torch.zeros(E * C + 1, D, dtype=xs.dtype, device=x.device)
|
||||
buf[slot_safe] = xs # scatter(dummy 槽不读)
|
||||
h = torch.baddbmm(self.b1.unsqueeze(1), buf[:E * C].view(E, C, D), self.W1t) # [E,C,F]
|
||||
h = F.relu(h)
|
||||
o = torch.baddbmm(self.b2.unsqueeze(1), h, self.W2t) # [E,C,D]
|
||||
o_full = torch.cat([o.reshape(E * C, D),
|
||||
torch.zeros(1, D, dtype=o.dtype, device=x.device)]) # [E*C+1, D]
|
||||
out_s = o_full[slot_safe] * ss.unsqueeze(-1) # [Nk, D](dummy→0)
|
||||
out = torch.zeros(N, D, dtype=x.dtype, device=x.device).index_add_(0, st, out_s)
|
||||
return out.view(B, S, D), out.new_zeros(())
|
||||
|
||||
def forward(self, x):
|
||||
# x: [B,S,D]
|
||||
if not CONFIG.get("vectorize_moe", True):
|
||||
@@ -701,6 +736,9 @@ class SMoE(nn.Module):
|
||||
if not self._stacked:
|
||||
self._stack_weights()
|
||||
|
||||
if CONFIG.get("moe_sparse", False):
|
||||
return self._forward_sparse(x)
|
||||
|
||||
B, S, D = x.shape
|
||||
topk_idx, topk_score, probs = self.gate(x)
|
||||
|
||||
|
||||
@@ -192,6 +192,25 @@ def test_varlen_matches_dense_attention():
|
||||
print(f"[PASS] varlen(嵌套张量) == 稠密SDPA (max err={err:.2e})")
|
||||
|
||||
|
||||
def test_sparse_moe_matches_dense():
|
||||
# 大 capacity(无丢弃)下,稀疏 MoE 应与 dense 数学等价
|
||||
torch.manual_seed(0)
|
||||
dev = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
m = infer.SMoE(d_model=512, dim_ff=1024, num_experts=8, k=2).to(dev).eval()
|
||||
x = torch.randn(1, 200, 512, device=dev)
|
||||
with torch.no_grad():
|
||||
infer.CONFIG["moe_sparse"] = False
|
||||
ref, _ = m(x)
|
||||
infer.CONFIG["moe_sparse"] = True
|
||||
infer.CONFIG["moe_capacity"] = 8.0 # 足够大,不丢 token
|
||||
new, _ = m(x)
|
||||
infer.CONFIG["moe_sparse"] = False
|
||||
infer.CONFIG["moe_capacity"] = 1.25
|
||||
err = (ref - new).abs().max().item()
|
||||
assert torch.allclose(ref, new, atol=1e-3, rtol=1e-3), f"sparse MoE 不等价 max err={err:.3e}"
|
||||
print(f"[PASS] sparse MoE(大capacity) == dense (max err={err:.2e}, dev={dev})")
|
||||
|
||||
|
||||
def test_fused_embedding_matches_perslot():
|
||||
torch.manual_seed(0)
|
||||
dev = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
@@ -240,6 +259,7 @@ def test_flex_matches_dense_attention():
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_moe_dense_matches_loop()
|
||||
test_sparse_moe_matches_dense()
|
||||
test_fused_embedding_matches_perslot()
|
||||
test_embedding_bag_matches()
|
||||
test_sparse_pool_matches()
|
||||
|
||||
Reference in New Issue
Block a user