Spaces:
Runtime error
Runtime error
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. | |
# | |
# This source code is licensed under the BSD license found in the | |
# LICENSE file in the root directory of this source tree. | |
import glob | |
import os | |
import shlex | |
import subprocess | |
import sys | |
import torch | |
import xformers | |
# CI wrapper around a single xFormers benchmark.
#
# Usage: <this script> BENCHMARK_SCRIPT BENCHMARK_FN
#   BENCHMARK_SCRIPT  file name inside xformers/benchmarks/ to execute
#   BENCHMARK_FN      forwarded to the benchmark as --fn; also the name of the
#                     sub-directory of $XFORMERS_BENCHMARKS_CACHE holding CSVs
#
# The benchmark runs with --fail_if_regression and, when a cached reference CSV
# for the same GPU exists, with --compare against it. Afterwards the files
# produced by this run (named after the commit label) replace that reference.
# The script exits with the benchmark's return code.

# Build failed - return early
if not xformers._has_cpp_library:
    print("xFormers wasn't built correctly - can't run benchmarks")
    sys.exit(0)

if len(sys.argv) < 3:
    # Readable message (exit status 1) instead of an IndexError traceback.
    sys.exit(f"usage: {sys.argv[0]} BENCHMARK_SCRIPT BENCHMARK_FN")

benchmark_script = os.path.join("xformers", "benchmarks", sys.argv[1])
benchmark_fn = sys.argv[2]
# First 8 characters of the current commit hash. It is passed to the benchmark
# as --label, and the rename step at the end expects it to appear in the names
# of the result files the benchmark writes.
label = subprocess.check_output(["git", "rev-parse", "HEAD"], text=True).strip()[:8]
cmd = [
    sys.executable,
    benchmark_script,
    "--label",
    label,
    "--fn",
    benchmark_fn,
    "--fail_if_regression",
    "--quiet",
]

# Name of the current GPU with spaces, dashes and dots replaced by underscores;
# used to match reference files recorded on the same hardware.
env = torch.cuda.get_device_name(torch.cuda.current_device()).translate(
    str.maketrans(" -.", "___")
)

# Figure out the name of the baseline. The directory part is escaped so glob
# metacharacters in the cache path or benchmark name are matched literally.
cache_dir = os.path.join(os.environ["XFORMERS_BENCHMARKS_CACHE"], benchmark_fn)
pattern = os.path.join(glob.escape(cache_dir), "*.csv")
# Only compare to (and later replace) benchmark data on same hardware;
# references recorded on other GPUs are left untouched.
ref_names = [s for s in glob.glob(pattern) if env in os.path.basename(s)]
baseline_names = sorted({os.path.basename(s)[: -len(".csv")] for s in ref_names})
if baseline_names:
    if len(baseline_names) > 1:
        raise RuntimeError(
            f"Supplied more than one reference for this benchmark: {','.join(baseline_names)}"
        )
    cmd += ["--compare", baseline_names[0]]

# Flush so this line is not reordered after the child's output when piped.
print("EXEC:", shlex.join(cmd), flush=True)
retcode = 0
try:
    subprocess.check_call(cmd)
except subprocess.CalledProcessError as e:
    retcode = e.returncode

# Result files written by this run carry the commit label in their file name.
new_results = [f for f in glob.glob(pattern) if label in os.path.basename(f)]
if new_results:
    # Remove original benchmark files (unless this run just overwrote one).
    for f in set(ref_names).difference(new_results):
        os.remove(f)
    # Rename new ones as 'ref'. Only the file name is rewritten, never the
    # directory part, and os.replace also overwrites an existing target.
    for f in new_results:
        folder, name = os.path.split(f)
        os.replace(f, os.path.join(folder, name.replace(label, "reference")))
else:
    # Nothing was produced (e.g. the benchmark crashed early): keep the old
    # reference so the next run still has a baseline to compare against.
    print("No new benchmark results found - keeping the existing reference")
sys.exit(retcode)