feat: 真稀疏MoE(capacity分组,只算top-k,cutlass baddbmm,无host同步)
按expert排序token+固定capacity分桶,每桶dense baddbmm,减GEMM~3x。argsort/where/scatter/index_add无.item()/bincount同步(不同于loop MoE)。超容量token丢弃(由capacity_factor控制)。等价测试(大capacity无丢弃==dense)。bench --moe-sparse/--moe-cap。默认关闭,待验证。 Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
@@ -144,6 +144,8 @@ CONFIG = {
|
||||
# 须靠提交验证。AUC中性、MoE仅占2%算力故风险极低。
|
||||
"vectorize_moe": True, # True=稠密向量化MoE(无同步点);False=原逐expert循环(.nonzero同步)
|
||||
"moe_baddbmm": True, # MoE FFN 用 baddbmm(cutlass GEMM+bias epilogue融合),省 bias add kernel
|
||||
"moe_sparse": False, # True=真稀疏MoE(只算top-k,capacity分组),减GEMM~3x;风险:开销/容量丢弃AUC
|
||||
"moe_capacity": 1.25, # 每expert容量 = ceil(Nk/E*factor);越大越不丢token但计算越多
|
||||
"skip_moe_loss": True, # 推理跳过 moe_loss(load-balance,推理无用),省 importance/std/mean kernel
|
||||
# PCOC 校准:本地拟合-0.1067(本地PCOC1.109),但评测PCOC稳定1.059,按斜率换算评测最优≈-0.059。
|
||||
"logit_bias": -0.06, # logit 加常数偏移使评测 PCOC→~1.0(单调,AUC不变,免费+~0.33分)
|
||||
@@ -693,6 +695,39 @@ class SMoE(nn.Module):
|
||||
self.register_buffer("W2t", self.W2.transpose(1, 2).contiguous()) # [E,F,D]
|
||||
self._stacked = True
|
||||
|
||||
def _forward_sparse(self, x):
|
||||
"""真稀疏 MoE:每 token 只算 top-k expert(按 expert 排序 + capacity 分桶 + cutlass baddbmm)。
|
||||
全程无 host 同步(argsort/where/scatter/index_add)。超容量 token 被丢弃(capacity_factor 控)。"""
|
||||
import math
|
||||
B, S, D = x.shape
|
||||
topk_idx, topk_score, _ = self.gate(x)
|
||||
N, k, E = B * S, self.k, self.num_experts
|
||||
xf = x.reshape(N, D)
|
||||
flat_e = topk_idx.reshape(-1) # [Nk] 每 pair 的 expert
|
||||
flat_s = topk_score.reshape(-1) # [Nk]
|
||||
Nk = flat_e.numel()
|
||||
flat_t = torch.arange(N, device=x.device).repeat_interleave(k) # [Nk] token id
|
||||
order = torch.argsort(flat_e) # 按 expert 排序(GPU sort,无 host 同步)
|
||||
se, st, ss = flat_e[order], flat_t[order], flat_s[order]
|
||||
xs = xf[st] # [Nk, D]
|
||||
expert_start = torch.searchsorted(se.contiguous(),
|
||||
torch.arange(E, device=x.device)) # [E]
|
||||
pos_within = torch.arange(Nk, device=x.device) - expert_start[se] # 每 token 在其 expert 内位置
|
||||
C = int(math.ceil(Nk / E * CONFIG.get("moe_capacity", 1.25)))
|
||||
valid = pos_within < C
|
||||
slot = se * C + pos_within
|
||||
slot_safe = torch.where(valid, slot, torch.full_like(slot, E * C)) # 超容量→dummy 槽
|
||||
buf = torch.zeros(E * C + 1, D, dtype=xs.dtype, device=x.device)
|
||||
buf[slot_safe] = xs # scatter(dummy 槽不读)
|
||||
h = torch.baddbmm(self.b1.unsqueeze(1), buf[:E * C].view(E, C, D), self.W1t) # [E,C,F]
|
||||
h = F.relu(h)
|
||||
o = torch.baddbmm(self.b2.unsqueeze(1), h, self.W2t) # [E,C,D]
|
||||
o_full = torch.cat([o.reshape(E * C, D),
|
||||
torch.zeros(1, D, dtype=o.dtype, device=x.device)]) # [E*C+1, D]
|
||||
out_s = o_full[slot_safe] * ss.unsqueeze(-1) # [Nk, D](dummy→0)
|
||||
out = torch.zeros(N, D, dtype=x.dtype, device=x.device).index_add_(0, st, out_s)
|
||||
return out.view(B, S, D), out.new_zeros(())
|
||||
|
||||
def forward(self, x):
|
||||
# x: [B,S,D]
|
||||
if not CONFIG.get("vectorize_moe", True):
|
||||
@@ -701,6 +736,9 @@ class SMoE(nn.Module):
|
||||
if not self._stacked:
|
||||
self._stack_weights()
|
||||
|
||||
if CONFIG.get("moe_sparse", False):
|
||||
return self._forward_sparse(x)
|
||||
|
||||
B, S, D = x.shape
|
||||
topk_idx, topk_score, probs = self.gate(x)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user