diff --git a/代码/code/bench.py b/代码/code/bench.py index 272197c..e86967e 100644 --- a/代码/code/bench.py +++ b/代码/code/bench.py @@ -11,9 +11,18 @@ bench.run_once({"fp16": False, "expert_merge": False}) # FP32 参考跑 bench.run_once({"signid_mode": "modulo"}) # 取模 vs clamp """ +import os +import sys import time from pathlib import Path +# baseline 把依赖装在 --target 目录(非默认 site-packages),在 kernel 里 import +# 之前必须先把它加到 sys.path,否则 import torch 会 ModuleNotFoundError。 +for _p in ("/home/aistudio/external-libraries", "/home/aistudio/libraries", + os.path.abspath("../libraries"), os.path.abspath("./libraries")): + if os.path.isdir(_p) and _p not in sys.path: + sys.path.insert(0, _p) + import torch from torch.utils.data import DataLoader