|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import argparse |
|
import functools |
|
import os |
|
import sys |
|
from typing import Any, Callable, Optional |
|
|
|
from hydra._internal.utils import _run_hydra, get_args_parser |
|
from hydra.core.config_store import ConfigStore |
|
from hydra.types import TaskFunction |
|
from omegaconf import DictConfig, OmegaConf |
|
|
|
|
|
# Custom OmegaConf resolver so configs can compute products via interpolation,
# e.g. `${multiply:${a},${b}}`. `replace=True` keeps re-imports / module reloads
# from failing because the name "multiply" is already registered.
OmegaConf.register_new_resolver("multiply", lambda x, y: x * y, replace=True)
|
|
|
|
|
def _register_schema(schema: Any, cli_config_name: Optional[str], default_config_name: Optional[str]) -> None:
    """
    Store ``schema`` in Hydra's ``ConfigStore`` under the effective config name.

    The name given on the command line (``--config-name``) takes precedence over the
    ``config_name`` passed to :func:`hydra_runner`.

    Args:
        schema: Structured config type to register.
        cli_config_name: Value of ``--config-name`` from the command line, or None.
        default_config_name: ``config_name`` given to :func:`hydra_runner`, or None.

    Exits the process with status 1 (after writing an error to stderr) if the
    command-line name contains a directory component, or if no name is available.
    """
    if cli_config_name is not None:
        path, name = os.path.split(cli_config_name)
        if path != '':
            # The schema is stored by bare name, so a directory component cannot be honoured.
            sys.stderr.write(
                "ERROR Cannot set config file path using `--config-name` when "
                "using schema. Please set path using `--config-path` and file name using "
                "`--config-name` separately.\n"
            )
            sys.exit(1)
    else:
        name = default_config_name

    if name is None:
        # ConfigStore.store() expects a str name; fail with a clear message instead of an opaque error.
        sys.stderr.write(
            "ERROR Cannot register schema without a config name. Please set `config_name` in "
            "`hydra_runner` or pass `--config-name` on the command line.\n"
        )
        sys.exit(1)

    ConfigStore.instance().store(name=name, node=schema)


def hydra_runner(
    config_path: Optional[str] = ".", config_name: Optional[str] = None, schema: Optional[Any] = None
) -> Callable[[TaskFunction], Any]:
    """
    Decorator used for passing the Config paths to main function.
    Optionally registers a schema used for validation/providing default values.

    The decorated function can be called in two ways:

    * With a ``DictConfig`` (``cfg_passthrough``): the task function is called directly
      with that config and its return value is returned. Hydra is not involved.
    * With no argument: command-line arguments are parsed with Hydra's argument parser.
      The overrides ``hydra.output_subdir=null``, ``hydra/job_logging=stdout`` and
      ``hydra.run.dir=.`` are appended. ``schema`` (if any) is stored in Hydra's
      ``ConfigStore``, and Hydra is then run on the task function. In this mode the
      wrapper returns ``None``.

    Args:
        config_path: Optional path that will be added to config search directory.
            NOTE: The default value of `config_path` has changed between Hydra 1.0 and Hydra 1.1+.
            Please refer to https://hydra.cc/docs/next/upgrades/1.0_to_1.1/changes_to_hydra_main_config_path/
            for details.
        config_name: Pathname of the config file.
        schema: Structured config type representing the schema used for validation/providing default values.
            It is stored under ``--config-name`` if given on the command line, otherwise under
            ``config_name``. The process exits with status 1 if neither is set, or if
            ``--config-name`` contains a directory component.

    Returns:
        A decorator that wraps a task function as described above.
    """

    def decorator(task_function: TaskFunction) -> Callable[[Optional[DictConfig]], Any]:
        @functools.wraps(task_function)
        def wrapper(cfg_passthrough: Optional[DictConfig] = None) -> Any:
            # Programmatic call: bypass Hydra and run the task on the provided config.
            if cfg_passthrough is not None:
                return task_function(cfg_passthrough)

            args = get_args_parser()
            parsed_args = args.parse_args()

            overrides = parsed_args.overrides
            # Per Hydra docs, a null output_subdir disables the `.hydra` subdirectory.
            overrides.append("hydra.output_subdir=null")
            # Send Hydra job logging to stdout.
            overrides.append("hydra/job_logging=stdout")
            # Keep the run directory at the current working directory.
            overrides.append("hydra.run.dir=.")

            if schema is not None:
                _register_schema(schema, parsed_args.config_name, config_name)

            # Mimic the ArgumentParser.parse_args() API on the parsed namespace.
            # This is assigned to the instance (not the class), so it is NOT bound and
            # must not take `self`.
            def parse_args(args=None, namespace=None):
                return parsed_args

            parsed_args.parse_args = parse_args

            # The result of _run_hydra is not returned: under --multirun the task
            # function may be invoked several times.
            _run_hydra(
                args=parsed_args,
                args_parser=args,
                task_function=task_function,
                config_path=config_path,
                config_name=config_name,
            )

        return wrapper

    return decorator
|
|