feat/auc-recovery-plan #1
+41
-13
@@ -48,6 +48,7 @@ CONFIG = {
|
|||||||
# synchronize,去掉同步点可能让被计时的 model(batch) 大幅缩短。本地force-sync看不出,
|
# synchronize,去掉同步点可能让被计时的 model(batch) 大幅缩短。本地force-sync看不出,
|
||||||
# 须靠提交验证。AUC中性、MoE仅占2%算力故风险极低。
|
# 须靠提交验证。AUC中性、MoE仅占2%算力故风险极低。
|
||||||
"vectorize_moe": True, # True=稠密向量化MoE(无同步点);False=原逐expert循环(.nonzero同步)
|
"vectorize_moe": True, # True=稠密向量化MoE(无同步点);False=原逐expert循环(.nonzero同步)
|
||||||
|
"fuse_embedding": True, # True=28个slot的查表+池化融合为1次(减per-batch kernel启动)
|
||||||
"compile": False, # 是否 torch.compile(实测慢5×,勿开)
|
"compile": False, # 是否 torch.compile(实测慢5×,勿开)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -327,6 +328,22 @@ def move_batch_to_device(batch, device):
|
|||||||
return batch
|
return batch
|
||||||
|
|
||||||
|
|
||||||
|
def _rep_forward_perslot(enc, batch):
|
||||||
|
"""原始逐 slot 实现(保留作数值等价对照/回退)。"""
|
||||||
|
pooled_embs = []
|
||||||
|
max_idx = enc.emb.num_embeddings - 1
|
||||||
|
target_dtype = enc.input_norm.weight.dtype
|
||||||
|
for i in range(enc.slot_num):
|
||||||
|
values, offsets = batch[i + 1]
|
||||||
|
offsets = offsets.to(values.device)
|
||||||
|
values = enc._signid(values, max_idx)
|
||||||
|
sign_emb = enc.emb(values).to(target_dtype)
|
||||||
|
res = torch.segment_reduce(sign_emb, reduce='sum', offsets=offsets, initial=0)
|
||||||
|
pooled_embs.append(res)
|
||||||
|
fused_embs = torch.cat(pooled_embs, dim=1)
|
||||||
|
return enc.linear(enc.input_norm(fused_embs))
|
||||||
|
|
||||||
|
|
||||||
class RepEncoder(nn.Module):
|
class RepEncoder(nn.Module):
|
||||||
def __init__(self, vocab_size, emb_dim, padding_idx=0, slot_num=0, d_model=0):
|
def __init__(self, vocab_size, emb_dim, padding_idx=0, slot_num=0, d_model=0):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@@ -336,24 +353,35 @@ class RepEncoder(nn.Module):
|
|||||||
self.input_norm = nn.LayerNorm(slot_num * emb_dim)
|
self.input_norm = nn.LayerNorm(slot_num * emb_dim)
|
||||||
self.linear = nn.Linear(in_features=slot_num * emb_dim, out_features=d_model)
|
self.linear = nn.Linear(in_features=slot_num * emb_dim, out_features=d_model)
|
||||||
|
|
||||||
|
def _signid(self, values, max_idx):
|
||||||
|
if CONFIG["signid_mode"] == "modulo":
|
||||||
|
return values % self.emb.num_embeddings # 取模哈希(与训练一致时用)
|
||||||
|
return values.clamp(0, max_idx) # 超界 sign id 截断
|
||||||
|
|
||||||
def forward(self, batch):
|
def forward(self, batch):
|
||||||
pooled_embs = []
|
if not CONFIG.get("fuse_embedding", True):
|
||||||
|
return _rep_forward_perslot(self, batch)
|
||||||
|
|
||||||
max_idx = self.emb.num_embeddings - 1
|
max_idx = self.emb.num_embeddings - 1
|
||||||
target_dtype = self.input_norm.weight.dtype # 后续层 dtype(FP16 时为 torch.float16)
|
target_dtype = self.input_norm.weight.dtype
|
||||||
|
N = batch[1][1].numel() - 1 # 样本数(slot1 的 offsets 段数)
|
||||||
|
|
||||||
|
# 把 28 个 slot 的 values 拼成一条,offsets 平移拼成覆盖 28*N 段的单一 offsets
|
||||||
|
parts, ends, base = [], [], 0
|
||||||
for i in range(self.slot_num):
|
for i in range(self.slot_num):
|
||||||
values, offsets = batch[i + 1]
|
values, offsets = batch[i + 1]
|
||||||
offsets = offsets.to(values.device)
|
offsets = offsets.to(values.device)
|
||||||
if CONFIG["signid_mode"] == "modulo":
|
parts.append(values)
|
||||||
values = values % self.emb.num_embeddings # 取模哈希(与训练一致时用)
|
ends.append(offsets[1:] + base) # 该 slot 各样本的段尾(平移 base)
|
||||||
else:
|
base += values.numel() # numel 读 shape,不触发同步
|
||||||
values = values.clamp(0, max_idx) # 超出 vocab_size 的 sign id 截断,避免越界
|
cat_values = self._signid(torch.cat(parts), max_idx)
|
||||||
sign_emb = self.emb(values).to(target_dtype)
|
seg = torch.cat([torch.zeros(1, dtype=torch.long, device=cat_values.device),
|
||||||
res = torch.segment_reduce(sign_emb, reduce='sum', offsets=offsets, initial=0)
|
torch.cat(ends)]) # [28*N + 1]
|
||||||
pooled_embs.append(res)
|
emb = self.emb(cat_values).to(target_dtype)
|
||||||
fused_embs = torch.cat(pooled_embs, dim=1)
|
pooled = torch.segment_reduce(emb, reduce='sum', offsets=seg, initial=0) # [28*N, emb]
|
||||||
norm_emb = self.input_norm(fused_embs)
|
pooled = pooled.view(self.slot_num, N, self.emb_dim).permute(1, 0, 2).reshape(
|
||||||
rep_emb = self.linear(norm_emb)
|
N, self.slot_num * self.emb_dim)
|
||||||
return rep_emb
|
return self.linear(self.input_norm(pooled))
|
||||||
|
|
||||||
|
|
||||||
def _varlen_attention(q, k, v, user_offsets):
|
def _varlen_attention(q, k, v, user_offsets):
|
||||||
|
|||||||
@@ -85,6 +85,30 @@ def test_varlen_matches_dense_attention():
|
|||||||
print(f"[PASS] varlen(嵌套张量) == 稠密SDPA (max err={err:.2e})")
|
print(f"[PASS] varlen(嵌套张量) == 稠密SDPA (max err={err:.2e})")
|
||||||
|
|
||||||
|
|
||||||
|
def test_fused_embedding_matches_perslot():
|
||||||
|
torch.manual_seed(0)
|
||||||
|
dev = "cuda" if torch.cuda.is_available() else "cpu"
|
||||||
|
slot_num, emb_dim, d_model = 28, 512, 512
|
||||||
|
enc = infer.RepEncoder(vocab_size=10000, emb_dim=emb_dim, slot_num=slot_num,
|
||||||
|
d_model=d_model).to(dev).eval()
|
||||||
|
# 造一个 N=6 样本的 batch:每 slot 每样本 0~4 个 sign(含空 slot 边界)
|
||||||
|
N = 6
|
||||||
|
batch = {}
|
||||||
|
for s in range(1, slot_num + 1):
|
||||||
|
counts = torch.randint(0, 5, (N,))
|
||||||
|
vals = torch.randint(0, 10000, (int(counts.sum()),), device=dev)
|
||||||
|
offs = torch.cat([torch.zeros(1, dtype=torch.long), counts.cumsum(0)]).to(dev)
|
||||||
|
batch[s] = (vals, offs)
|
||||||
|
with torch.no_grad():
|
||||||
|
infer.CONFIG["fuse_embedding"] = False
|
||||||
|
ref = enc(batch)
|
||||||
|
infer.CONFIG["fuse_embedding"] = True
|
||||||
|
new = enc(batch)
|
||||||
|
err = (ref - new).abs().max().item()
|
||||||
|
assert torch.allclose(ref, new, atol=1e-4, rtol=1e-4), f"embedding融合不等价 max err={err:.3e}"
|
||||||
|
print(f"[PASS] embedding 融合 == 逐slot (max err={err:.2e}, dev={dev})")
|
||||||
|
|
||||||
|
|
||||||
def test_flex_matches_dense_attention():
|
def test_flex_matches_dense_attention():
|
||||||
ok = (torch.cuda.is_available() and infer._HAS_FLEX
|
ok = (torch.cuda.is_available() and infer._HAS_FLEX
|
||||||
and torch.cuda.get_device_capability()[0] >= 8)
|
and torch.cuda.get_device_capability()[0] >= 8)
|
||||||
@@ -109,6 +133,7 @@ def test_flex_matches_dense_attention():
|
|||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
test_moe_dense_matches_loop()
|
test_moe_dense_matches_loop()
|
||||||
|
test_fused_embedding_matches_perslot()
|
||||||
test_varlen_matches_dense_attention()
|
test_varlen_matches_dense_attention()
|
||||||
test_flex_matches_dense_attention()
|
test_flex_matches_dense_attention()
|
||||||
print("[DONE] 等价测试结束")
|
print("[DONE] 等价测试结束")
|
||||||
|
|||||||
Reference in New Issue
Block a user