fix: Phase B 实测回归(flex+dense慢5-6x),默认回退 sdpa+loop;bench 加 --profile
实测 A800:sdpa+loop=15.15s,flex+dense=98s,+compile=82s。模型是开销瓶颈 非算力瓶颈(30TFLOP应0.15s却跑15s),FlexAttention解决的算力问题非此处瓶颈、 反增开销。默认改回已验证最快的 sdpa+loop。新增 bench --profile 用 torch.profiler 定位真正的开销来源(算子级)。 Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
@@ -160,6 +160,47 @@ def run_diag(rebuild=False):
|
||||
f"超界sign占比={over}/{tot}={(over / max(tot, 1)):.2%}")
|
||||
|
||||
|
||||
def run_profile(config_override=None, n=20, batch_size=50, rebuild=False):
|
||||
"""用 torch.profiler 剖析前 n 个 batch,打印按 CUDA 耗时排序的算子表,定位真正瓶颈。"""
|
||||
if config_override is None:
|
||||
config_override = {}
|
||||
infer.CONFIG.update(config_override)
|
||||
cur = Path(__file__).parent
|
||||
ref = cur / "dataset"
|
||||
item_dict, user_seq = _get_data(cur, ref, rebuild=rebuild)
|
||||
test_logids = infer.load_logids_from_file(ref / "test.csv")
|
||||
ds = infer.CTRTestSeqDataset(
|
||||
test_logids_ordered=list(test_logids), item_dict=item_dict,
|
||||
user_seq=user_seq, max_feasign_per_slot={1: 2}, max_ctx_len=None)
|
||||
loader = DataLoader(ds, batch_size=batch_size, shuffle=False, num_workers=0,
|
||||
collate_fn=infer.make_collate_fn(ds.max_slot_id))
|
||||
batches = []
|
||||
for b in loader:
|
||||
batches.append(infer.move_batch_to_device(b, torch.device("cpu")))
|
||||
if len(batches) >= n:
|
||||
break
|
||||
del item_dict, user_seq, ds, loader
|
||||
import gc
|
||||
gc.collect()
|
||||
model, dev = infer.load_model(ckpt_path=None)
|
||||
cuda = (dev.type == "cuda")
|
||||
from torch.profiler import profile, ProfilerActivity
|
||||
acts = [ProfilerActivity.CPU] + ([ProfilerActivity.CUDA] if cuda else [])
|
||||
with torch.inference_mode():
|
||||
warm = infer.move_batch_to_device(batches[0], dev) # 预热(触发任何首次编译)
|
||||
model(warm)
|
||||
if cuda:
|
||||
torch.cuda.synchronize()
|
||||
with profile(activities=acts) as prof:
|
||||
for b in batches:
|
||||
b = infer.move_batch_to_device(b, dev)
|
||||
model(b)
|
||||
if cuda:
|
||||
torch.cuda.synchronize()
|
||||
sort_key = "cuda_time_total" if cuda else "cpu_time_total"
|
||||
print(prof.key_averages().table(sort_by=sort_key, row_limit=25))
|
||||
|
||||
|
||||
def run_once(config_override=None, batch_size=50, max_batches=None,
|
||||
max_feasign_per_slot=None, rebuild=False):
|
||||
"""跑一次本地推理并打分。返回 infer._cal_score 的结果 dict。"""
|
||||
@@ -255,6 +296,8 @@ def _parse_args():
|
||||
ap.add_argument("--moe", choices=["dense", "loop"], default=None,
|
||||
help="MoE实现:dense=向量化(新), loop=逐expert循环(原)")
|
||||
ap.add_argument("--compile", action="store_true", help="开启 torch.compile")
|
||||
ap.add_argument("--profile", type=int, default=None, metavar="N",
|
||||
help="剖析前 N 个 batch,打印按 CUDA 耗时排序的算子表(定位瓶颈)")
|
||||
ap.add_argument("--rebuild", action="store_true", help="强制重建过滤缓存")
|
||||
return ap.parse_args()
|
||||
|
||||
@@ -284,5 +327,8 @@ if __name__ == "__main__":
|
||||
cfg["vectorize_moe"] = (a.moe == "dense")
|
||||
if a.compile:
|
||||
cfg["compile"] = True
|
||||
if a.profile is not None:
|
||||
run_profile(cfg, n=a.profile, batch_size=a.bs, rebuild=a.rebuild)
|
||||
sys.exit(0)
|
||||
mf = None if a.feasign_none else {1: 2}
|
||||
run_once(cfg, batch_size=a.bs, max_batches=a.smoke, max_feasign_per_slot=mf, rebuild=a.rebuild)
|
||||
|
||||
+5
-3
@@ -40,9 +40,11 @@ CONFIG = {
|
||||
"signid_mode": "clamp", # "clamp" 或 "modulo":处理超界 sign id 的方式
|
||||
"sync_timing": False, # bench 里设 True,做 torch.cuda.synchronize 真实计时
|
||||
"filter_test_users": True, # 只处理含测试样本的用户(跳过会被丢弃的用户,省算力)
|
||||
"use_flex_attn": "auto", # "auto"(SM80+用flex,否则SDPA) / True / False
|
||||
"vectorize_moe": True, # True=稠密向量化MoE(无Python循环/同步);False=原逐expert循环
|
||||
"compile": False, # 是否 torch.compile(图理干净后再开)
|
||||
# 实测:FlexAttention + 稠密MoE 在本模型上反而慢 5-6 倍(模型是开销瓶颈非算力瓶颈),
|
||||
# 故默认回到已验证最快的 sdpa + loop;flex/dense 仅作 bench 对照选项。
|
||||
"use_flex_attn": False, # "auto"(SM80+用flex,否则SDPA) / True / False
|
||||
"vectorize_moe": False, # True=稠密向量化MoE;False=原逐expert循环(默认,已验证更快)
|
||||
"compile": False, # 是否 torch.compile
|
||||
}
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user