Spaces:
Running
on
Zero
Running
on
Zero
import os | |
import argparse | |
from copy import deepcopy | |
from typing import Union, Optional | |
from omegaconf import OmegaConf | |
from omegaconf.dictconfig import DictConfig | |
def add_args_from_config(config, parser, prefix=""):
    r"""Register one command-line option on `parser` for every leaf of `config`.

    Nested keys are joined with '-', e.g. for
    config = {'a': {'b': 123}}, config['a']['b'] is exposed as
    `python custom.py --a-b 234`.

    Every option is registered with `default=None`, so an option that is not
    given on the command line shows up as `None` in the parsed namespace.

    Args:
        config: Mapping (e.g. a DictConfig) whose leaves define the options.
        parser: The argparse parser to extend in place.
        prefix: Dash-joined path of the parent keys; empty at the root.

    Leaf handling:
        * list  -> `nargs="+"`, converted with the type of the first element;
                   an empty list (or a `None` first element) gives no type to
                   infer, so the items are kept as raw strings.
        * bool  -> a flag that flips the configured value when present
                   (`store_false` if the config says True, else `store_true`).
        * None  -> plain string option (no type information available).
        * other -> converted with `type(value)`.
    """
    for key, value in config.items():
        # At the root `prefix` is "", giving "--key"; below the root `prefix`
        # already starts with "-", giving e.g. "--a-key".
        arg_name = f"-{prefix}-{key}"

        # Recurse into nested mappings instead of registering them directly.
        if OmegaConf.is_dict(value):
            add_args_from_config(value, parser, prefix=f"{prefix}-{key}")
            continue

        if OmegaConf.is_list(value):
            items = OmegaConf.to_container(value)
            # `type(items[0])` would raise IndexError on an empty list and
            # yield an unusable NoneType converter for a None element;
            # `type=None` makes argparse keep the raw strings.
            if items and items[0] is not None:
                item_type = type(items[0])
            else:
                item_type = None
            parser.add_argument(arg_name, type=item_type, nargs="+", default=None)
            continue

        if value is None:
            parser.add_argument(arg_name, default=None)
            continue

        arg_type = type(value)
        if arg_type is bool:
            parser.add_argument(
                arg_name,
                action="store_false" if value else "store_true",
                default=None,
            )
        else:
            parser.add_argument(arg_name, type=arg_type, default=None)
def update_config_from_args(config, args):
    r"""Update an existing config by using a set of arguments.

    The arguments should be created by `add_args_from_config`.

    Two passes are applied:
      1. For every entry `k` of `config['__subgroup__']` (if that section
         exists), a non-None `args.__subgroup___k` is treated as a path to
         a replacement sub-config, which is loaded and merged under `k`.
         The `__subgroup__` section is then removed.
      2. Every leaf whose matching argument is not None is overridden.

    NOTE: when no subgroup override is given, `__subgroup__` is deleted from
    the object that was passed in, not from a copy.

    Args:
        config: The DictConfig to update.
        args: Parsed namespace (e.g. from `parser.parse_args`).

    Returns:
        The updated config, without the `__subgroup__` section.
    """

    def _recur_update_cfgs_from_args(config, args, prefix=""):
        # Work on a copy so the node being iterated is never modified.
        cur_config = deepcopy(config)
        for key in config:
            if OmegaConf.is_dict(config[key]):
                updated_cfgs = _recur_update_cfgs_from_args(
                    config[key], args, prefix=f"{prefix}-{key}"
                )
                cur_config = OmegaConf.merge(cur_config, {key: updated_cfgs})
            else:
                # Rebuild the attribute name for "--a-b": strip the leading
                # dashes and turn the remaining ones into underscores.
                arg_name = f"{prefix}-{key}".lstrip("-").replace("-", "_")
                if hasattr(args, arg_name):
                    override_v = getattr(args, arg_name)
                    cur_config[key] = (
                        override_v if override_v is not None else config[key]
                    )
        return cur_config

    # Update config from each subgroup. The section is optional, so only
    # touch it when present instead of failing on configs that lack it.
    if "__subgroup__" in config:
        for k, _ in config["__subgroup__"].items():
            sg_cfgs_path = getattr(args, f"__subgroup__-{k}".replace("-", "_"))
            if sg_cfgs_path is not None:
                updated_sg_cfgs = load_config(sg_cfgs_path)
                config = OmegaConf.merge(config, {k: updated_sg_cfgs})
        del config.__subgroup__

    # Update config from each leaf node
    config = _recur_update_cfgs_from_args(config, args, prefix="")
    return config
def load_config(
    config_path: Union[dict, str, DictConfig], dump_path: Optional[str] = None
) -> DictConfig:
    r"""Load config from yaml file.

    This function will also read the yaml files
    if they are specified in '__subgroup__'. e.g.,
    [within `config_path`]
    __subgroup__:
        a: path_to_yaml_a
        b: path_to_yaml_b
        ...
    attribute 1:
        ...

    Each sub-config is merged under its own name (`config.a`, `config.b`, ...)
    and the corresponding `__subgroup__` entry is rewritten to the resolved
    path of the file that was loaded.

    Args:
        config_path: Path to a yaml file, a plain dict, or an OmegaConf config.
        dump_path: Currently unused by this function.

    Relative sub-config paths are resolved against the directory of
    `config_path` when it is a file path, and against the current working
    directory when the config is given as a dict / DictConfig.
    ------
    RETURNS: OmegaConf.DictConfig
    """
    if isinstance(config_path, str):
        with open(config_path, "r") as file:
            config = OmegaConf.load(file)
    elif isinstance(config_path, dict):
        config = OmegaConf.create(config_path)
    else:
        assert OmegaConf.is_config(
            config_path
        ), "config_path must be config path, dict, or DictConfig"
        config = config_path

    if "__subgroup__" in config:
        subgroups = config.get("__subgroup__")
        # Only a string input has a location on disk; calling
        # os.path.abspath on a dict / DictConfig would raise TypeError.
        if isinstance(config_path, str):
            cur_cfg_dir = os.path.dirname(os.path.abspath(config_path))
        else:
            cur_cfg_dir = os.getcwd()
        for sg_name, sg_config_path in subgroups.items():
            # os.path.join keeps `sg_config_path` as is when it is absolute.
            sg_abs_pth = os.path.join(cur_cfg_dir, sg_config_path)
            sg_config = OmegaConf.load(sg_abs_pth)
            config = OmegaConf.merge(config, {sg_name: sg_config})
            config.__subgroup__[sg_name] = sg_abs_pth  # update sub cfg path
    return config
def dynamic_config(description: Optional[str] = None, verbose: bool = True):
    r"""Load configuration from both yaml file and command line.

    The config in the yaml will be overridden by the args passed from the
    command line. e.g.,
    [Command line] python3 custom.py --config_path /path/to/config.yaml --a-b-c=123
    [Python file] cfgs = dynamic_config('A demo for dynamic configuration.')
                  dump_config(cfgs, 'path/to/output/config.yaml')  # log the config of this trial

    Args:
        description: Text passed to `argparse.ArgumentParser(description=...)`.
        verbose: When True, log the final configuration at INFO level.

    Exits with a usage error when `--config_path` is not supplied.
    ------
    RETURNS:
        DictConfig.
    """
    parser = argparse.ArgumentParser(description=description)
    parser.add_argument("--config_path", type=str, help="Path to the yaml file.")

    # First pass: only `--config_path` is known, everything else is kept
    # aside until the options derived from the yaml have been registered.
    args, remaining_args = parser.parse_known_args()
    if args.config_path is None:
        # Without a path, load_config(None) would fail with an unrelated
        # assertion message; report the real problem instead.
        # `required=True` cannot be used here: the second parse below only
        # sees `remaining_args`, which no longer contains `--config_path`.
        parser.error("the following arguments are required: --config_path")

    # Get predefined configs and add new args dynamically
    cfgs = load_config(args.config_path)
    add_args_from_config(cfgs, parser)

    # Override values in `cfgs` if applicable
    args = parser.parse_args(remaining_args)
    cfgs = update_config_from_args(cfgs, args)

    if verbose:
        import logging

        log = logging.getLogger(__name__)
        log.info("Successfully setup the configuration:\n%s", OmegaConf.to_yaml(cfgs))
    return cfgs
def dump_config(cfgs, dump_path):
    r"""Write `cfgs` to `dump_path`, creating the parent directory if needed."""
    # Resolve to an absolute path first so a bare filename still has a parent.
    os.makedirs(os.path.dirname(os.path.abspath(dump_path)), exist_ok=True)
    with open(dump_path, "w") as out_file:
        OmegaConf.save(cfgs, f=out_file)
if __name__ == "__main__":
    # Demo entry point: build the config from `--config_path` plus any
    # command-line overrides and print the result.
    cfgs = dynamic_config()
    print("Updated Configuration:")
    print(OmegaConf.to_yaml(cfgs))