fix: Phase B 实测回归(flex+dense慢5-6x),默认回退 sdpa+loop;bench 加 --profile

实测 A800:sdpa+loop=15.15s,flex+dense=98s,+compile=82s。模型是开销瓶颈
非算力瓶颈(30TFLOP应0.15s却跑15s),FlexAttention解决的算力问题非此处瓶颈、
反增开销。默认改回已验证最快的 sdpa+loop。新增 bench --profile 用 torch.profiler
定位真正的开销来源(算子级)。

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
OwnerSunshine530
2026-06-15 00:25:53 +08:00
parent c1d8b91fb2
commit 9eaf5f5511
2 changed files with 51 additions and 3 deletions
+46
View File
@@ -160,6 +160,47 @@ def run_diag(rebuild=False):
f"超界sign占比={over}/{tot}={(over / max(tot, 1)):.2%}")
def run_profile(config_override=None, n=20, batch_size=50, rebuild=False):
    """Profile the first batches of local inference with torch.profiler.

    Builds the test dataset, collects batches, runs one warm-up forward pass,
    then runs the collected batches under ``torch.profiler.profile`` and
    prints the operator table (top 25 rows) sorted by total CUDA time, or by
    total CPU time when the model is not on a CUDA device.

    Args:
        config_override: Optional dict merged into ``infer.CONFIG`` (the
            module-level config is mutated in place). ``None`` means no
            overrides.
        n: Stop collecting once at least ``n`` batches are held. The check
            runs after appending, so at least one batch is collected whenever
            the loader is non-empty.
        batch_size: Batch size handed to the ``DataLoader``.
        rebuild: Forwarded to ``_get_data``.

    Returns:
        None. The profiler table is printed to stdout. If the loader yields no
        batch at all, a message is printed and the function returns before the
        model is loaded.
    """
    if config_override is None:
        config_override = {}
    infer.CONFIG.update(config_override)

    cur = Path(__file__).parent
    ref = cur / "dataset"
    item_dict, user_seq = _get_data(cur, ref, rebuild=rebuild)
    test_logids = infer.load_logids_from_file(ref / "test.csv")
    ds = infer.CTRTestSeqDataset(
        test_logids_ordered=list(test_logids), item_dict=item_dict,
        user_seq=user_seq, max_feasign_per_slot={1: 2}, max_ctx_len=None)
    loader = DataLoader(ds, batch_size=batch_size, shuffle=False, num_workers=0,
                        collate_fn=infer.make_collate_fn(ds.max_slot_id))

    # Collect the batches up front (placed on CPU) so that data loading is not
    # part of the profiled region.
    batches = []
    for b in loader:
        batches.append(infer.move_batch_to_device(b, torch.device("cpu")))
        if len(batches) >= n:
            break

    # Drop the dataset objects before loading the model
    # (presumably to lower peak memory -- TODO confirm).
    del item_dict, user_seq, ds, loader
    import gc
    gc.collect()

    # Without this guard an empty loader would crash below with a bare
    # IndexError on batches[0], and only after paying for the model load.
    if not batches:
        print("没有可用于剖析的 batch(数据集为空?),跳过 profile。")
        return

    model, dev = infer.load_model(ckpt_path=None)
    cuda = (dev.type == "cuda")

    from torch.profiler import profile, ProfilerActivity
    acts = [ProfilerActivity.CPU] + ([ProfilerActivity.CUDA] if cuda else [])

    with torch.inference_mode():
        # Warm-up pass (triggers any first-time compilation) kept outside the
        # profiled region.
        warm = infer.move_batch_to_device(batches[0], dev)
        model(warm)
        if cuda:
            torch.cuda.synchronize()

        with profile(activities=acts) as prof:
            for b in batches:
                b = infer.move_batch_to_device(b, dev)
                model(b)
            if cuda:
                # Wait for queued GPU work before the profiler context exits.
                torch.cuda.synchronize()

    sort_key = "cuda_time_total" if cuda else "cpu_time_total"
    print(prof.key_averages().table(sort_by=sort_key, row_limit=25))
def run_once(config_override=None, batch_size=50, max_batches=None,
max_feasign_per_slot=None, rebuild=False):
"""跑一次本地推理并打分。返回 infer._cal_score 的结果 dict。"""
@@ -255,6 +296,8 @@ def _parse_args():
ap.add_argument("--moe", choices=["dense", "loop"], default=None,
help="MoE实现:dense=向量化(新), loop=逐expert循环(原)")
ap.add_argument("--compile", action="store_true", help="开启 torch.compile")
ap.add_argument("--profile", type=int, default=None, metavar="N",
help="剖析前 N 个 batch,打印按 CUDA 耗时排序的算子表(定位瓶颈)")
ap.add_argument("--rebuild", action="store_true", help="强制重建过滤缓存")
return ap.parse_args()
@@ -284,5 +327,8 @@ if __name__ == "__main__":
cfg["vectorize_moe"] = (a.moe == "dense")
if a.compile:
cfg["compile"] = True
if a.profile is not None:
run_profile(cfg, n=a.profile, batch_size=a.bs, rebuild=a.rebuild)
sys.exit(0)
mf = None if a.feasign_none else {1: 2}
run_once(cfg, batch_size=a.bs, max_batches=a.smoke, max_feasign_per_slot=mf, rebuild=a.rebuild)
+5 -3
View File
@@ -40,9 +40,11 @@ CONFIG = {
"signid_mode": "clamp", # "clamp" 或 "modulo":处理超界 sign id 的方式
"sync_timing": False, # bench 里设 True,做 torch.cuda.synchronize 真实计时
"filter_test_users": True, # 只处理含测试样本的用户(跳过会被丢弃的用户,省算力)
"use_flex_attn": "auto", # "auto"(SM80+用flex,否则SDPA) / True / False
"vectorize_moe": True, # True=稠密向量化MoE(无Python循环/同步);False=原逐expert循环
"compile": False, # 是否 torch.compile(图理干净后再开)
# 实测:FlexAttention + 稠密MoE 在本模型上反而慢 5-6 倍(模型是开销瓶颈非算力瓶颈),
# 故默认回到已验证最快的 sdpa + loop;flex/dense 仅作 bench 对照选项。
"use_flex_attn": False, # "auto"(SM80+用flex,否则SDPA) / True / False
"vectorize_moe": False, # True=稠密向量化MoEFalse=原逐expert循环(默认,已验证更快)
"compile": False, # 是否 torch.compile
}