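"""Triton autotuner with a persistent best-config cache.

Benchmarked best configs are pickled under ~/.triton/best_config_cache/,
keyed by GPU model and kernel name, so tuning results survive across
process restarts; in distributed runs only rank 0 writes the cache file.
"""
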
import builtins
import os
import pickle
import time

import torch
import torch.distributed as dist
from triton.runtime.autotuner import Autotuner
|
|
class CustomAutotuner(Autotuner):
    """Autotuner that persists its best-config cache to disk."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Cache file is keyed by GPU model and kernel name, e.g.
        # ~/.triton/best_config_cache/NVIDIA_A100-SXM4-80GB/<kernel>.pkl
        self.best_config_cache_path = os.path.expanduser(
            os.path.join(
                "~",
                ".triton",
                "best_config_cache",
                torch.cuda.get_device_name(0).replace(" ", "_"),
                self.base_fn.__name__ + ".pkl",
            )
        )
        # Warm-start the in-memory cache from a previous run, if present.
        if os.path.exists(self.best_config_cache_path):
            with open(self.best_config_cache_path, "rb") as f:
                self.cache = pickle.load(f)
|
    def run(self, *args, **kwargs):
        self.nargs = dict(zip(self.arg_names, args))
        used_cached_result = True
        if len(self.configs) > 1:
            all_args = {**self.nargs, **kwargs}
            _args = []
            for name in self.arg_names:
                if name in all_args:
                    _args.append(all_args[name])
            # The cache key combines the values of the user-declared key
            # arguments with the dtypes of all tensor arguments.
            key = [_args[i] for i in self.key_idx]
            for arg in _args:
                if hasattr(arg, "dtype"):
                    key.append(str(arg.dtype))
            key = tuple(key)
            if key not in self.cache:
                # Cache miss: benchmark all surviving configs and keep the fastest.
                used_cached_result = False
                pruned_configs = self.prune_configs(kwargs)
                bench_start = time.time()
                timings = {config: self._bench(*args, config=config, **kwargs) for config in pruned_configs}
                bench_end = time.time()
                self.bench_time = bench_end - bench_start
                self.cache[key] = builtins.min(timings, key=timings.get)
                self.pre_hook(args, reset_only=True)
                self.configs_timings = timings
                # Persist the updated cache; in distributed runs only rank 0 writes.
                if not dist.is_initialized() or dist.get_rank() == 0:
                    best_config_cache_dir = os.path.dirname(self.best_config_cache_path)
                    os.makedirs(best_config_cache_dir, exist_ok=True)
                    with open(self.best_config_cache_path, "wb") as f:
                        pickle.dump(self.cache, f)
            config = self.cache[key]
        else:
            config = self.configs[0]
        self.best_config = config
        if os.getenv("TRITON_PRINT_AUTOTUNING", None) == "1" and not used_cached_result:
            print(
                f"Triton autotuning for function {self.base_fn.__name__} finished after "
                f"{self.bench_time:.2f}s; best config selected: {self.best_config};"
            )
        if config.pre_hook is not None:
            config.pre_hook({**self.nargs, **kwargs, **config.all_kwargs()})
        ret = self.fn.run(
            *args,
            **kwargs,
            **config.all_kwargs(),
        )
        self.nargs = None
        return ret
|
|
def custom_autotune(
    configs,
    key,
    prune_configs_by=None,
    reset_to_zero=None,
    restore_value=None,
    pre_hook=None,
    post_hook=None,
    warmup=25,
    rep=100,
    use_cuda_graph=False,
):
    """Drop-in replacement for `triton.autotune` that returns a CustomAutotuner."""

    def decorator(fn):
        return CustomAutotuner(
            fn,
            fn.arg_names,
            configs,
            key,
            reset_to_zero,
            restore_value,
            pre_hook=pre_hook,
            post_hook=post_hook,
            prune_configs_by=prune_configs_by,
            warmup=warmup,
            rep=rep,
            use_cuda_graph=use_cuda_graph,
        )

    return decorator
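

# Example usage (a minimal sketch; the kernel, its arguments, and the config
# values below are illustrative assumptions, not part of this module):
#
#     import triton
#     import triton.language as tl
#
#     @custom_autotune(
#         configs=[
#             triton.Config({"BLOCK_SIZE": 128}, num_warps=4),
#             triton.Config({"BLOCK_SIZE": 256}, num_warps=8),
#         ],
#         key=["n_elements"],
#     )
#     @triton.jit
#     def _copy_kernel(x_ptr, y_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
#         pid = tl.program_id(axis=0)
#         offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
#         mask = offsets < n_elements
#         tl.store(y_ptr + offsets, tl.load(x_ptr + offsets, mask=mask), mask=mask)
#
# On the first run the kernel is benchmarked over both configs and the winner
# is written to the on-disk cache; later runs (and later processes) reuse it.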