diff --git a/代码/code/bench.py b/代码/code/bench.py index 1cb9427..d922812 100644 --- a/代码/code/bench.py +++ b/代码/code/bench.py @@ -291,8 +291,8 @@ def _parse_args(): help="逗号分隔的 keep_fp32_modules,如 linear,rep_encoder.input_norm") ap.add_argument("--feasign-none", action="store_true", help="不截断特征(max_feasign_per_slot=None)") - ap.add_argument("--attn", choices=["auto", "flex", "sdpa"], default=None, - help="注意力实现:flex=块对角FlexAttention, sdpa=稠密(原), auto=SM80自动") + ap.add_argument("--attn", choices=["sdpa", "flex", "varlen"], default=None, + help="注意力:sdpa=稠密(原), flex=FlexAttention, varlen=嵌套张量变长flash") ap.add_argument("--moe", choices=["dense", "loop"], default=None, help="MoE实现:dense=向量化(新), loop=逐expert循环(原)") ap.add_argument("--compile", action="store_true", help="开启 torch.compile") @@ -322,7 +322,7 @@ if __name__ == "__main__": if a.keep is not None: cfg["keep_fp32_modules"] = tuple(x for x in a.keep.split(",") if x) if a.attn is not None: - cfg["use_flex_attn"] = {"auto": "auto", "flex": True, "sdpa": False}[a.attn] + cfg["attn"] = a.attn if a.moe is not None: cfg["vectorize_moe"] = (a.moe == "dense") if a.compile: diff --git a/代码/code/infer.py b/代码/code/infer.py index ebc7e09..825b7be 100644 --- a/代码/code/infer.py +++ b/代码/code/infer.py @@ -40,28 +40,27 @@ CONFIG = { "signid_mode": "clamp", # "clamp" 或 "modulo":处理超界 sign id 的方式 "sync_timing": False, # bench 里设 True,做 torch.cuda.synchronize 真实计时 "filter_test_users": True, # 只处理含测试样本的用户(跳过会被丢弃的用户,省算力) - # 实测:FlexAttention + 稠密MoE 在本模型上反而慢 5-6 倍(模型是开销瓶颈非算力瓶颈), - # 故默认回到已验证最快的 sdpa + loop;flex/dense 仅作 bench 对照选项。 - "use_flex_attn": False, # "auto"(SM80+用flex,否则SDPA) / True / False + # 实测(A800):sdpa+loop=15.1s 最快;flex/dense/compile/小batch 都更慢。 + # attn: "sdpa"(稠密mask,默认/已验证) / "flex"(FlexAttention,慢) / "varlen"(嵌套张量变长flash) + "attn": "sdpa", "vectorize_moe": False, # True=稠密向量化MoE;False=原逐expert循环(默认,已验证更快) - "compile": False, # 是否 torch.compile + "compile": False, # 是否 torch.compile(实测慢5×,勿开) } -def _use_flex(device): - """决定是否用 FlexAttention:auto 模式下仅在 SM80+(Ampere/A800)启用。""" - mode = CONFIG.get("use_flex_attn", "auto") - if not _HAS_FLEX or mode is False: - return False - if mode is True: - return True - if device is not None and device.type == "cuda": - try: - major, _ = torch.cuda.get_device_capability(device) - return major >= 8 - except Exception: - return False - return False +def _resolve_attn(device): + """解析实际使用的注意力实现。flex 需 SM80+ 且可用,否则回退 sdpa。""" + attn = CONFIG.get("attn", "sdpa") + if attn == "flex": + if not _HAS_FLEX: + return "sdpa" + if device is not None and device.type == "cuda": + try: + if torch.cuda.get_device_capability(device)[0] < 8: + return "sdpa" + except Exception: + return "sdpa" + return attn def _force_fp32_io(module): @@ -353,12 +352,35 @@ class RepEncoder(nn.Module): return rep_emb +def _varlen_attention(q, k, v, user_offsets): + """嵌套张量变长 flash 注意力:每个用户当独立序列、is_causal 块对角因果。 + 一个内核处理一 batch 内所有用户,无稠密 mask、无 padding 浪费、开销低。 + q,k,v: [1, H, S, Dh];user_offsets: [B+1](S 上的用户边界)。返回 [1, H, S, Dh]。 + """ + _, H, S, Dh = q.shape + offs = user_offsets.to(torch.int64) + # [1,H,S,Dh] -> [S,H,Dh] + qv = q.squeeze(0).transpose(0, 1).contiguous() + kv = k.squeeze(0).transpose(0, 1).contiguous() + vv = v.squeeze(0).transpose(0, 1).contiguous() + # 按用户边界做 jagged 嵌套张量:[B, ragged, H, Dh] -> [B, H, ragged, Dh] + qn = torch.nested.nested_tensor_from_jagged(qv, offsets=offs).transpose(1, 2) + kn = torch.nested.nested_tensor_from_jagged(kv, offsets=offs).transpose(1, 2) + vn = torch.nested.nested_tensor_from_jagged(vv, offsets=offs).transpose(1, 2) + out = F.scaled_dot_product_attention(qn, kn, vn, is_causal=True) # [B,H,ragged,Dh] + out = out.transpose(1, 2).values() # [S, H, Dh] + return out.transpose(0, 1).unsqueeze(0).contiguous() # [1, H, S, Dh] + + def scaled_dot_product(q, k, v, extension): """注意力分发: - - 若 extension 带 block_mask → FlexAttention 块对角因果(每用户只在自己序列内 - 做因果注意力,避免对 ~14000 长拼接序列做 O(S²) 稠密注意力,计算量砍数十倍)。 - - 否则 → 标准 SDPA(稠密 mask,数学等价、用于回退/对照)。 + - varlen_offsets → 嵌套张量变长 flash(每用户独立序列、块对角因果,开销低)。 + - block_mask → FlexAttention 块对角因果。 + - mask(默认) → 标准 SDPA 稠密 mask(数学等价、已验证最快)。 """ + if extension is not None and extension.get("varlen_offsets") is not None: + return _varlen_attention(q, k, v, extension["varlen_offsets"]) + if extension is not None and extension.get("block_mask") is not None: return flex_attention(q, k, v, block_mask=extension["block_mask"]) @@ -562,7 +584,10 @@ class CTRModel(nn.Module): def forward(self, batch): seq_input = self.rep_encoder(batch) user_offsets = batch["user_offsets"] - if _use_flex(seq_input.device): + attn = _resolve_attn(seq_input.device) + if attn == "varlen": + extension = {"varlen_offsets": user_offsets} + elif attn == "flex": S = seq_input.shape[0] # rep_encoder 输出 [S, D],S=总 token 数 extension = {"block_mask": self.build_block_mask(user_offsets, S)} else: @@ -652,8 +677,7 @@ def load_model(ckpt_path, device='cuda:0'): model.to(dev) model.eval() - use_flex = _use_flex(dev) - print(f"[INFO] attention={'FlexAttention(block-causal)' if use_flex else 'SDPA(dense)'}, " + print(f"[INFO] attention={_resolve_attn(dev)}, " f"moe={'dense' if CONFIG.get('vectorize_moe', True) else 'loop'}") if CONFIG.get("compile", False): diff --git a/代码/code/tests/test_equiv.py b/代码/code/tests/test_equiv.py index 522d3e6..5d362fc 100644 --- a/代码/code/tests/test_equiv.py +++ b/代码/code/tests/test_equiv.py @@ -64,11 +64,32 @@ def test_moe_dense_matches_loop(): print(f"[PASS] MoE 稠密向量化 == 逐expert循环 (max err={err:.2e}, dev={dev})") +def test_varlen_matches_dense_attention(): + if not torch.cuda.is_available(): + print("[SKIP] varlen 等价测试(需 CUDA)") + return + torch.manual_seed(0) + dev = "cuda" + H, Dh = 8, 64 + offs = _offsets([10, 25, 7, 40, 18], dev) + S = int(offs[-1]) + q = torch.randn(1, H, S, Dh, device=dev, dtype=torch.float16) + k = torch.randn(1, H, S, Dh, device=dev, dtype=torch.float16) + v = torch.randn(1, H, S, Dh, device=dev, dtype=torch.float16) + with torch.no_grad(): + dense = infer.scaled_dot_product(q, k, v, {"mask": _dense_causal_mask(offs)[None, None]}) + varlen = infer.scaled_dot_product(q, k, v, {"varlen_offsets": offs}) + err = (dense.float() - varlen.float()).abs().max().item() + assert torch.allclose(dense.float(), varlen.float(), atol=2e-2, rtol=2e-2), \ + f"varlen 不等价 max err={err:.3e}" + print(f"[PASS] varlen(嵌套张量) == 稠密SDPA (max err={err:.2e})") + + def test_flex_matches_dense_attention(): ok = (torch.cuda.is_available() and infer._HAS_FLEX and torch.cuda.get_device_capability()[0] >= 8) if not ok: - print("[SKIP] FlexAttention 等价测试(需 CUDA SM80+,当前环境不满足)") + print("[SKIP] FlexAttention 等价测试(需 CUDA SM80+)") return torch.manual_seed(0) dev = "cuda" @@ -88,5 +109,6 @@ def test_flex_matches_dense_attention(): if __name__ == "__main__": test_moe_dense_matches_loop() + test_varlen_matches_dense_attention() test_flex_matches_dense_attention() print("[DONE] 等价测试结束")