# AudioMorphix / src/utils/config.py
# JinhuaL1ANG's picture
# v1
# 9a6dac6
import os
import argparse
from copy import deepcopy
from typing import Union, Optional
from omegaconf import OmegaConf
from omegaconf.dictconfig import DictConfig
def _parse_bool_arg(text):
    """Convert a command-line token such as ``true`` or ``0`` into a bool.

    ``type=bool`` must not be handed to argparse because ``bool("False")``
    is ``True``; this converter is used for boolean list elements instead.
    """
    lowered = text.strip().lower()
    if lowered in ("true", "t", "yes", "y", "1"):
        return True
    if lowered in ("false", "f", "no", "n", "0"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {text!r}")


def add_args_from_config(config, parser, prefix=""):
    r"""Register one command-line option per leaf of ``config`` on ``parser``.

    Nested keys are joined with ``-``, e.g. for
        config = {'a': {'b': 123}}
    the value ``config['a']['b']`` is exposed as
        `python custom.py --a-b 234`.

    Every option is registered with ``default=None`` so that options which
    were not given on the command line come back as ``None``.

    ARGS:
        config: mapping (typically a DictConfig) to walk recursively.
        parser: ``argparse.ArgumentParser`` that receives the new options.
        prefix: dash-joined path of the parent keys; empty at the root.

    Option kinds by leaf value:
        * bool  -> flag that flips the configured value when present.
        * None  -> plain string option (no type is known).
        * list  -> ``nargs="+"`` option typed after the first element;
                   empty lists and non-scalar/None elements are read as
                   strings, bool elements via an explicit bool parser.
        * other -> option converted with ``type(value)``.
    """
    for key, value in config.items():
        nested_prefix = f"{prefix}-{key}"

        # Recurse into nested sections; only leaves become options.
        if OmegaConf.is_dict(value):
            add_args_from_config(value, parser, prefix=nested_prefix)
            continue

        # One more leading dash turns "-a-b" into the long option "--a-b".
        arg_name = f"-{nested_prefix}"

        if OmegaConf.is_list(value):
            items = OmegaConf.to_container(value)
            first = items[0] if items else None
            if isinstance(first, bool):
                # Checked before int because bool is a subclass of int.
                elem_type = _parse_bool_arg
            elif isinstance(first, (int, float, str)):
                elem_type = type(first)
            else:
                # Empty list, None or non-scalar element: no usable
                # converter can be inferred, so keep the raw strings.
                elem_type = str
            parser.add_argument(arg_name, type=elem_type, nargs="+", default=None)
            continue

        if isinstance(value, bool):
            # Passing the flag flips the configured value.
            parser.add_argument(
                arg_name,
                action="store_false" if value else "store_true",
                default=None,
            )
        elif value is None:
            parser.add_argument(arg_name, default=None)
        else:
            parser.add_argument(arg_name, type=type(value), default=None)
def update_config_from_args(config, args):
    r"""Return a copy of ``config`` with command-line overrides applied.

    The arguments should be created by `add_args_from_config`: the leaf
    ``config['a']['b']`` is overridden by the attribute ``args.a_b`` when
    that attribute exists and is not ``None``.

    If ``config`` has a ``__subgroup__`` section, each entry whose matching
    argument (``args.__subgroup___<name>``) holds a path is reloaded with
    `load_config` and merged under ``<name>``. The ``__subgroup__`` section
    is removed from the returned config. The ``config`` object passed in is
    not modified.

    ARGS:
        config: DictConfig to update.
        args: namespace-like object holding the parsed arguments.
    RETURNS:
        The updated config.
    """

    def _apply_overrides(node, prefix=""):
        # Work on a copy so the node being iterated is never modified.
        updated = deepcopy(node)
        for key in node:
            child_prefix = f"{prefix}-{key}"
            if OmegaConf.is_dict(node[key]):
                child = _apply_overrides(node[key], prefix=child_prefix)
                updated = OmegaConf.merge(updated, {key: child})
                continue
            # argparse stores "--a-b" under the attribute name "a_b".
            dest = child_prefix.lstrip("-").replace("-", "_")
            if hasattr(args, dest):
                override = getattr(args, dest)
                updated[key] = override if override is not None else node[key]
        return updated

    # Subgroups are optional; configs without them skip this step entirely.
    if "__subgroup__" in config:
        # Copy first so deleting the section below cannot touch the caller's
        # object when no subgroup path was overridden.
        config = deepcopy(config)
        for name in list(config["__subgroup__"].keys()):
            dest = f"__subgroup__-{name}".replace("-", "_")
            sg_cfgs_path = getattr(args, dest, None)
            if sg_cfgs_path is not None:
                updated_sg_cfgs = load_config(sg_cfgs_path)
                config = OmegaConf.merge(config, {name: updated_sg_cfgs})
        del config.__subgroup__

    # Update config from each leaf node
    return _apply_overrides(config, prefix="")
def load_config(
    config_path: Union[dict, str, DictConfig], dump_path: Optional[str] = None
) -> DictConfig:
    r"""Load a config from a yaml file, a dict, or an existing config.

    This function will also read the yaml files
    if they are specified in '__subgroup__'. e.g.,
    [within `config_path`]
    __subgroup__:
        a: path_to_yaml_a
        b: path_to_yaml_b
        ...
    attribute 1:
        ...
    Each sub-config is merged under its subgroup name, and the entry in
    '__subgroup__' is replaced by the joined path that was loaded.

    Relative subgroup paths are resolved against the directory of
    `config_path` when it is a file path, and against the current working
    directory when the config is given as a dict / DictConfig.

    ARGS:
        config_path: path to a yaml file, a plain dict, or an OmegaConf config.
        dump_path: not read by this function; kept in the signature so
            existing callers that pass it keep working.
    ------
    RETURNS: OmegaConf.DictConfig
    """
    if isinstance(config_path, str):
        # Same loader as for the sub-configs below.
        config = OmegaConf.load(config_path)
        base_dir = os.path.dirname(os.path.abspath(config_path))
    else:
        if isinstance(config_path, dict):
            config = OmegaConf.create(config_path)
        else:
            assert OmegaConf.is_config(
                config_path
            ), "config_path must be config path, dict, or DictConfig"
            config = config_path
        # No file location is available for an in-memory config.
        base_dir = os.getcwd()

    if "__subgroup__" in config:
        subgroups = config.get("__subgroup__")
        for sg_name, sg_config_path in subgroups.items():
            sg_abs_pth = os.path.join(base_dir, sg_config_path)
            sg_config = OmegaConf.load(sg_abs_pth)
            config = OmegaConf.merge(config, {sg_name: sg_config})
            config.__subgroup__[sg_name] = sg_abs_pth  # update sub cfg path
    return config
def dynamic_config(description: Optional[str] = None, verbose: bool = True):
    r"""Load configuration from both a yaml file and the command line.

    The yaml file named by ``--config_path`` is loaded first; every leaf in
    it then becomes a command-line option, and values passed on the command
    line override the values from the yaml.
    e.g.,
    [Command line] python3 custom.py --config_path /path/to/config.yaml --a-b-c=123
    [Python file] cfgs = dynamic_config('A demo for dynamic configuration.')
                  dump_config(cfgs, 'path/to/output/config.yaml')  # log the config of this trial

    ARGS:
        description: text shown at the top of the ``--help`` output.
        verbose: when True, log the final config at INFO level.
    ------
    RETURNS:
        DictConfig.
    """
    # Built-in help is disabled and re-added after the config-derived options
    # exist; otherwise `-h` would be consumed by the pre-parse below and
    # could only ever list `--config_path`.
    parser = argparse.ArgumentParser(description=description, add_help=False)
    parser.add_argument("--config_path", type=str, help="Path to the yaml file.")

    # Get predefined configs and add new args dynamically
    args, remaining_args = parser.parse_known_args()
    if args.config_path is None:
        if any(token in ("-h", "--help") for token in remaining_args):
            parser.print_help()
            parser.exit()
        # Report the missing option as a usage error.
        parser.error("the following arguments are required: --config_path")

    cfgs = load_config(args.config_path)
    add_args_from_config(cfgs, parser)
    parser.add_argument(
        "-h",
        "--help",
        action="help",
        default=argparse.SUPPRESS,
        help="show this help message and exit",
    )

    # Override values in `cfgs` if applicable
    args = parser.parse_args(remaining_args)
    cfgs = update_config_from_args(cfgs, args)

    if verbose:
        import logging

        log = logging.getLogger(__name__)
        log.info("Successfully setup the configuration:\n%s", OmegaConf.to_yaml(cfgs))
    return cfgs
def dump_config(cfgs, dump_path):
    """Save ``cfgs`` to ``dump_path`` with ``OmegaConf.save``.

    The parent directory of ``dump_path`` is created first if it does not
    exist yet.
    """
    parent_dir, _ = os.path.split(os.path.abspath(dump_path))
    os.makedirs(parent_dir, exist_ok=True)
    with open(dump_path, "w") as stream:
        OmegaConf.save(cfgs, f=stream)
if __name__ == "__main__":
    # Demo: build the config from `--config_path` plus any command-line
    # overrides and print the merged result as yaml.
    cfgs = dynamic_config()
    print("Updated Configuration:")
    print(OmegaConf.to_yaml(cfgs))