diff --git a/代码/code/infer.py b/代码/code/infer.py index 825b7be..b8a3667 100644 --- a/代码/code/infer.py +++ b/代码/code/infer.py @@ -40,9 +40,10 @@ CONFIG = { "signid_mode": "clamp", # "clamp" 或 "modulo":处理超界 sign id 的方式 "sync_timing": False, # bench 里设 True,做 torch.cuda.synchronize 真实计时 "filter_test_users": True, # 只处理含测试样本的用户(跳过会被丢弃的用户,省算力) - # 实测(A800):sdpa+loop=15.1s 最快;flex/dense/compile/小batch 都更慢。 - # attn: "sdpa"(稠密mask,默认/已验证) / "flex"(FlexAttention,慢) / "varlen"(嵌套张量变长flash) - "attn": "sdpa", + # 实测(A800,本地5451用户):sdpa=15.15s,varlen=10.28s(快32%,AUC不变), + # flex/compile/小batch 都更慢。默认 varlen。 + # attn: "varlen"(嵌套张量变长flash,默认) / "sdpa"(稠密mask) / "flex"(FlexAttention) + "attn": "varlen", "vectorize_moe": False, # True=稠密向量化MoE;False=原逐expert循环(默认,已验证更快) "compile": False, # 是否 torch.compile(实测慢5×,勿开) }