From 264130df0fd44b2d659757e8bc90e494213e97f9 Mon Sep 17 00:00:00 2001 From: OwnerSunshine530 Date: Wed, 17 Jun 2026 20:20:50 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20PCOC=E6=A0=A1=E5=87=86(logit=5Fbias?= =?UTF-8?q?=E5=8D=95=E8=B0=83=E5=81=8F=E7=A7=BB,AUC=E4=B8=8D=E5=8F=98,?= =?UTF-8?q?=E5=85=8D=E8=B4=B9+0.34)=20+=20bench=E8=87=AA=E5=8A=A8=E6=8B=9F?= =?UTF-8?q?=E5=90=88=E5=BB=BA=E8=AE=AEbias?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-Authored-By: Claude Opus 4.8 --- 代码/code/bench.py | 24 +++++++++++++++++++++++- 代码/code/infer.py | 4 ++++ 2 files changed, 27 insertions(+), 1 deletion(-) diff --git a/代码/code/bench.py b/代码/code/bench.py index 6f4b237..c875fad 100644 --- a/代码/code/bench.py +++ b/代码/code/bench.py @@ -265,6 +265,7 @@ def run_once(config_override=None, batch_size=50, max_batches=None, print(f"[BENCH] rep cache built from batches: {logids.numel()} items") logid2p = {} + logid2logit = {} t_sum = 0.0 with torch.inference_mode(): for b in batches: @@ -278,8 +279,11 @@ def run_once(config_override=None, batch_size=50, max_batches=None, if cuda: torch.cuda.synchronize() t_sum += time.time() - t0 - for lid, p in zip(b["logid"][pm].cpu().tolist(), probs[pm].cpu().tolist()): + lg = logits.squeeze(-1) + for lid, p, lv in zip(b["logid"][pm].cpu().tolist(), + probs[pm].cpu().tolist(), lg[pm].cpu().tolist()): logid2p[lid] = p + logid2logit[lid] = lv order = [int(l.split(",")[0]) for l in open(test_csv) if l.strip()] missing = [lid for lid in order if lid not in logid2p] @@ -297,6 +301,21 @@ def run_once(config_override=None, batch_size=50, max_batches=None, f" -> AUC={res['auc']:.5f} PCOC={res['pcoc']:.4f}" f" lat={res['latency']:.2f}s score={res['score_all']:.2f}" ) + # 拟合 PCOC 校准 logit_bias(使 mean(sigmoid(logit+b))=mean(label)) + try: + ol = np.array([logid2logit.get(lid, 0.0) for lid in order], dtype=np.float64) + labels = infer._read_label(str(label_file)) + ml = float(labels.mean()) + lo, hi = -3.0, 3.0 + for _ in range(60): + mid = 0.5 * (lo + hi) + if (1.0 / (1.0 + np.exp(-(ol + mid)))).mean() > ml: + hi = mid + else: + lo = mid + print(f"[BENCH] 建议 logit_bias={0.5*(lo+hi):.4f}(PCOC→1.0,免费+~0.34分)") + except Exception as e: + print(f"[BENCH] logit_bias 拟合跳过: {e}") return res @@ -327,6 +346,7 @@ def _parse_args(): ap.add_argument("--emb-bag", action="store_true", help="F.embedding_bag 融合查表+池化") ap.add_argument("--no-moe-baddbmm", action="store_true", help="关闭 MoE baddbmm(用 einsum 对照)") ap.add_argument("--no-skip-moe-loss", action="store_true", help="不跳过 moe_loss(对照)") + ap.add_argument("--logit-bias", type=float, default=None, help="PCOC校准:logit偏移(本地验证PCOC→1.0)") ap.add_argument("--sparse-pool", action="store_true", help="稀疏矩阵乘做池化(段内高重复时省)") ap.add_argument("--precompute-rep", action="store_true", help="预计算RepEncoder缓存,model(batch)跳过embedding层(从batches自建)") @@ -379,6 +399,8 @@ if __name__ == "__main__": cfg["moe_baddbmm"] = False if a.no_skip_moe_loss: cfg["skip_moe_loss"] = False + if a.logit_bias is not None: + cfg["logit_bias"] = a.logit_bias if a.sparse_pool: cfg["sparse_pool"] = True if a.precompute_rep: diff --git a/代码/code/infer.py b/代码/code/infer.py index b4e438e..e91bfe4 100644 --- a/代码/code/infer.py +++ b/代码/code/infer.py @@ -145,6 +145,7 @@ CONFIG = { "vectorize_moe": True, # True=稠密向量化MoE(无同步点);False=原逐expert循环(.nonzero同步) "moe_baddbmm": True, # MoE FFN 用 baddbmm(cutlass GEMM+bias epilogue融合),省 bias add kernel "skip_moe_loss": True, # 推理跳过 moe_loss(load-balance,推理无用),省 importance/std/mean kernel + "logit_bias": 0.0, # PCOC 校准:输出 logit 加常数偏移使 PCOC→1.0(单调变换,AUC不变,免费+~0.34分) "fuse_embedding": True, # True=28个slot的查表+池化融合为1次(减per-batch kernel启动) "syncfree_mask": True, # True=用searchsorted构造因果mask(无同步);False=repeat_interleave(同步) "emb_fp16": True, # True=Embedding表转FP16(查表带宽减半,实测AUC 0.75932≈无损) @@ -877,6 +878,9 @@ class CTRModel(nn.Module): encoder_output, moe_loss = self.seq_encoder(x=seq_input, extension=extension) encoder_output = encoder_output.squeeze(0) pred = self.linear(encoder_output) + bias = CONFIG.get("logit_bias", 0.0) + if bias != 0.0: + pred = pred + bias # PCOC 校准(单调,不改 AUC) pred_logits = torch.clamp(pred, min=-15.0, max=15.0) return pred_logits, moe_loss