Spaces:
Running
on
Zero
Running
on
Zero
File size: 5,818 Bytes
9a6dac6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 |
import os
import argparse
from copy import deepcopy
from typing import Union, Optional
from omegaconf import OmegaConf
from omegaconf.dictconfig import DictConfig
def add_args_from_config(config, parser, prefix=""):
    r"""Add one command-line option to `parser` for every leaf of `config`.

    Nested keys are joined with '-'. e.g.,
        config = {'a': {'b': 123}}, access the config['a']['b'] by
        `python custom.py --a-b 234`.

    Every option is registered with `default=None`, so after parsing a value
    of None means "not given on the command line".

    ARGS:
        config: mapping to walk; values for which `OmegaConf.is_dict` is true
            are recursed into, everything else is treated as a leaf.
        parser: the `argparse.ArgumentParser` that receives the new options.
        prefix: '-'-joined path of the parent keys (each with a leading '-');
            empty at the root. Used by the recursion.
    """
    for key, value in config.items():
        # `prefix` is "" at the root and "-parent[-...]" below it, so this
        # always produces "--key" or "--parent-...-key".
        arg_name = f"-{prefix}-{key}"

        # Add args recursively if the value is a nested DictConfig
        if OmegaConf.is_dict(value):
            add_args_from_config(value, parser, prefix=f"{prefix}-{key}")
            continue

        if OmegaConf.is_list(value):
            items = OmegaConf.to_container(value)
            # The element type is inferred from the first item. An empty list
            # (or a leading None) gives nothing usable to infer from, so fall
            # back to plain strings instead of failing.
            if items and items[0] is not None:
                item_type = type(items[0])
            else:
                item_type = str
            parser.add_argument(arg_name, type=item_type, nargs="+", default=None)
            continue

        if isinstance(value, bool):
            # A boolean becomes a flag that flips the configured value.
            parser.add_argument(
                arg_name,
                action="store_false" if value else "store_true",
                default=None,
            )
        elif value is None:
            # No type information available: keep the raw string.
            parser.add_argument(arg_name, default=None)
        else:
            parser.add_argument(arg_name, type=type(value), default=None)
def update_config_from_args(config, args):
    r"""Update an existing config by using a set of arguments.

    The arguments should be created by `add_args_from_config`.

    Two passes are applied:
      1. If `config` has a `__subgroup__` section, every entry whose matching
         argument is not None is reloaded from that path with `load_config`
         and merged under the entry's key; the `__subgroup__` section is then
         deleted.
      2. Every leaf whose matching argument is not None is overridden.

    NOTE: when a `__subgroup__` section is present and none of its entries is
    overridden, the section is deleted from the passed-in `config` object
    itself.

    ARGS:
        config: the config to update (DictConfig).
        args: parsed arguments, e.g. an `argparse.Namespace`.
    ------
    RETURNS: the updated config.
    """

    def _recur_update_cfgs_from_args(config, args, prefix=""):
        cur_config = deepcopy(config)
        for key in config:
            if OmegaConf.is_dict(config[key]):
                updated_cfgs = _recur_update_cfgs_from_args(
                    config[key], args, prefix=f"{prefix}-{key}"
                )
                cur_config = OmegaConf.merge(cur_config, {key: updated_cfgs})
            else:
                # Rebuild the attribute name argparse derives from the option
                # "--a-b": leading dashes dropped, inner dashes -> underscores.
                arg_name = f"{prefix}-{key}".lstrip("-").replace("-", "_")
                if hasattr(args, arg_name):
                    override_v = getattr(args, arg_name)
                    cur_config[key] = (
                        override_v if override_v is not None else config[key]
                    )
        return cur_config

    # Update config from each subgroup. The section is optional, so a config
    # without it goes straight to the leaf pass.
    if "__subgroup__" in config:
        for k in config["__subgroup__"]:
            sg_arg_name = f"__subgroup__-{k}".replace("-", "_")
            sg_cfgs_path = getattr(args, sg_arg_name, None)
            if sg_cfgs_path is not None:
                updated_sg_cfgs = load_config(sg_cfgs_path)
                config = OmegaConf.merge(config, {k: updated_sg_cfgs})
        del config["__subgroup__"]

    # Update config from each leaf node
    config = _recur_update_cfgs_from_args(config, args, prefix="")
    return config
def load_config(
    config_path: Union[dict, str, DictConfig], dump_path: Optional[str] = None
) -> DictConfig:
    r"""Load config from yaml file.

    This function will also read the yaml files
    if they are specified in '__subgroup__'. e.g.,
    [within `config_path`]
        __subgroup__:
            a: path_to_yaml_a
            b: path_to_yaml_b
            ...
        attribute 1:
            ...

    Each subgroup yaml is merged under its own name (`a`, `b`, ...) and the
    entry in '__subgroup__' is rewritten to the path the file was loaded from.
    Relative subgroup paths are resolved against the directory of
    `config_path` when it is a file path, and against the current working
    directory when `config_path` is a dict or an OmegaConf config.

    ARGS:
        config_path: path to a yaml file, a plain dict, or an OmegaConf config.
        dump_path: not used by this function; kept in the signature for
            backward compatibility.
    ------
    RETURNS: OmegaConf.DictConfig
    """
    if isinstance(config_path, (str, os.PathLike)):
        with open(config_path, "r") as file:
            config = OmegaConf.load(file)
        cur_cfg_dir = os.path.dirname(os.path.abspath(config_path))
    else:
        if isinstance(config_path, dict):
            config = OmegaConf.create(config_path)
        else:
            assert OmegaConf.is_config(
                config_path
            ), "config_path must be config path, dict, or DictConfig"
            config = config_path
        # An in-memory config has no file location to anchor relative
        # subgroup paths to; absolute paths are unaffected by this base.
        cur_cfg_dir = os.getcwd()

    if "__subgroup__" in config:
        subgroups = config.get("__subgroup__")
        for sg_name, sg_config_path in subgroups.items():
            # os.path.join keeps `sg_config_path` as-is when it is absolute.
            sg_abs_pth = os.path.join(cur_cfg_dir, sg_config_path)
            sg_config = OmegaConf.load(sg_abs_pth)
            config = OmegaConf.merge(config, {sg_name: sg_config})
            config["__subgroup__"][sg_name] = sg_abs_pth  # update sub cfg path
    return config
def dynamic_config(description: Optional[str] = None, verbose: bool = True):
    r"""Load configuration from both yaml file and command line.

    The config in the yaml will be overridden by the args passed from the
    command line. e.g.,
        [Command line] python3 custom.py --config_path /path/to/config.yaml --a-b-c=123
        [Python file]  cfgs = dynamic_config('A demo for dynamic configuration.')
                       dump_config(cfgs, 'path/to/output/config.yaml')  # log the config of this trial

    ARGS:
        description: text shown at the top of the `--help` output.
        verbose: when True, log the final config (as yaml) at INFO level.
    ------
    RETURNS:
        DictConfig.
    """
    import logging

    # Stage 1: only discover `--config_path`. Help is disabled on this parser
    # so that `-h/--help` is answered by the full parser below, after the
    # options derived from the yaml file have been registered.
    bootstrap = argparse.ArgumentParser(add_help=False)
    bootstrap.add_argument("--config_path", type=str)
    known_args, remaining_args = bootstrap.parse_known_args()

    parser = argparse.ArgumentParser(description=description)
    parser.add_argument("--config_path", type=str, help="Path to the yaml file.")

    if known_args.config_path is None:
        # Without a yaml file there is nothing to build options from. Still
        # let `-h/--help` print the basic help, otherwise report the missing
        # option as a normal usage error.
        parser.parse_known_args(remaining_args)
        parser.error("the following arguments are required: --config_path")

    # Stage 2: get predefined configs and add new args dynamically
    cfgs = load_config(known_args.config_path)
    add_args_from_config(cfgs, parser)

    # Override values in `cfgs` if applicable
    args = parser.parse_args(remaining_args)
    cfgs = update_config_from_args(cfgs, args)

    if verbose:
        log = logging.getLogger(__name__)
        log.info("Successfully setup the configuration:\n%s", OmegaConf.to_yaml(cfgs))
    return cfgs
def dump_config(cfgs, dump_path):
    r"""Write `cfgs` to `dump_path` with `OmegaConf.save`.

    The directory that will contain the file is created first if it does not
    exist yet.
    """
    target_dir, _ = os.path.split(os.path.abspath(dump_path))
    os.makedirs(target_dir, exist_ok=True)
    with open(dump_path, "w") as out_file:
        OmegaConf.save(cfgs, f=out_file)
if __name__ == "__main__":
    # Demo entry point: build the config from `--config_path` plus any
    # command-line overrides and print the result as yaml.
    cfgs = dynamic_config()
    print("Updated Configuration:")
    print(OmegaConf.to_yaml(cfgs))
|