diff --git a/代码/code/infer.py b/代码/code/infer.py index b8a3667..0a8d6e4 100644 --- a/代码/code/infer.py +++ b/代码/code/infer.py @@ -40,10 +40,10 @@ CONFIG = { "signid_mode": "clamp", # "clamp" 或 "modulo":处理超界 sign id 的方式 "sync_timing": False, # bench 里设 True,做 torch.cuda.synchronize 真实计时 "filter_test_users": True, # 只处理含测试样本的用户(跳过会被丢弃的用户,省算力) - # 实测(A800,本地5451用户):sdpa=15.15s,varlen=10.28s(快32%,AUC不变), - # flex/compile/小batch 都更慢。默认 varlen。 - # attn: "varlen"(嵌套张量变长flash,默认) / "sdpa"(稠密mask) / "flex"(FlexAttention) - "attn": "varlen", + # 实测:varlen 本地快(10.28s)但评测端慢(148s,嵌套张量构造开销随batch数放大)→已退回。 + # sdpa 是评测端验证最快(89.96s/58.86)。flex/compile/小batch/varlen 在评测端都更差。 + # attn: "sdpa"(稠密mask,默认/评测最优) / "varlen"(本地快评测慢) / "flex"(慢) + "attn": "sdpa", "vectorize_moe": False, # True=稠密向量化MoE;False=原逐expert循环(默认,已验证更快) "compile": False, # 是否 torch.compile(实测慢5×,勿开) }