Spaces:
Sleeping
Sleeping
File size: 8,552 Bytes
5ac1897 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 |
import sys
import json
import rich
import rich.text
import rich.tree
import rich.syntax
import hydra
from typing import List, Optional, Union, Any
from pathlib import Path
from omegaconf import OmegaConf, DictConfig, ListConfig
from pytorch_lightning.utilities import rank_zero_only
from lib.info.log import get_logger
from .proj_manager import ProjManager as PM
def get_PM_info_dict():
    ''' Build an OmegaConf object holding the ProjManager paths under the `_pm_` key. '''
    # Paths are stringified so that the resulting config only contains plain values.
    pm_paths = {
        'root'   : str(PM.root),
        'inputs' : str(PM.inputs),
        'outputs': str(PM.outputs),
    }
    return OmegaConf.create({'_pm_': pm_paths})
def get_PM_info_list():
    ''' Build the ProjManager paths as a list of `_pm_.<key>=<value>` override strings. '''
    root, inputs, outputs = str(PM.root), str(PM.inputs), str(PM.outputs)
    return [
        f'_pm_.root={root}',
        f'_pm_.inputs={inputs}',
        f'_pm_.outputs={outputs}',
    ]
def entrypoint_with_args(*args, log_cfg=True, **kwargs):
    '''
    Build a decorator that extends the `hydra.main` decorator in these parts:
    - Inject some runtime-known arguments, i.e., the `_pm_.*` paths given by `get_PM_info_list()`.
    - Enable additional arguments that needn't be specified in the command line.
        - Positional arguments are added to the command line arguments directly, so make sure they are valid.
            - e.g., 'exp=<...>', '+extra=<...>', etc.
        - Key-specified arguments have the same effect as command line arguments {k}={v}.
    - Check the validity of the experiment name.

    ### Args
    - *args: str
        - Raw override strings, inserted into `sys.argv` as they are.
    - log_cfg: bool, default True
        - If True (and the config doesn't set `show_cfg` to a falsy value), print the config before running.
    - **kwargs
        - Each pair is turned into a `{k}={v}` override string.

    ### Returns
    - A decorator taking the experiment function `func(cfg)` and returning the `hydra.main`-wrapped entry.

    ### Raises
    - ValueError: raised by the wrapped entry when `exp_name` is missing from the composed config.
    '''
    # Overrides injected by the code itself: project paths first, then the decorator arguments.
    injected = get_PM_info_list()
    injected.extend(args)
    injected.extend(f'{k}={v}' for k, v in kwargs.items())

    def entrypoint_wrapper(func):
        cli_args = sys.argv[1:]
        # The args from command line have higher priority, so put them in the back.
        # They are placed exactly once: repeating them would apply non-idempotent overrides
        # (e.g., `+key=value` or `~key`) twice, which Hydra may reject.
        sys.argv = sys.argv[:1] + injected + cli_args
        _log_exp_info(func.__name__, injected + cli_args)

        @hydra.main(version_base=None, config_path=str(PM.configs), config_name='base.yaml')
        def entrypoint_preprocess(cfg:DictConfig):
            # Resolve the references and make it editable.
            cfg = unfold_cfg(cfg)

            # Print out the configuration files.
            if log_cfg and cfg.get('show_cfg', True):
                sum_keys = ['output_dir', 'pipeline.name', 'data.name', 'exp_name', 'exp_tag']
                print_cfg(cfg, sum_keys=sum_keys)

            # Check the validity of the experiment name.
            if cfg.get('exp_name') is None:
                get_logger(brief=True).fatal(f'`exp_name` is not given! You may need to add `exp=<certain_exp>` to the command line.')
                raise ValueError('`exp_name` is not given!')

            # Bind config.
            PM.init_with_cfg(cfg)

            try:
                with PM.time_monitor('exp', f'Main part of experiment `{cfg.exp_name}`.'):
                    # Enter the main function.
                    func(cfg)
            finally:
                # Always report the timing, whether the experiment succeeded or raised.
                PM.time_monitor.report(level='global')
                # TODO: Wrap a notifier here.

        return entrypoint_preprocess

    return entrypoint_wrapper
#! This implementation can't dump the config files in default ways. In order to keep c
# def entrypoint_wrapper(func):
# def entrypoint_preprocess():
# # Initialize the configuration module.
# with hydra.initialize_config_dir(version_base=None, config_dir=str(PM.configs)):
# get_logger(brief=True).info(f'Exp entry `{func.__name__}` is called with overrides: {overrides}')
# cfg = hydra.compose(config_name='base', overrides=overrides)
# cfg4dump_raw = cfg.copy() # store the folded raw configuration files
# # Resolve the references and make it editable.
# cfg = unfold_cfg(cfg)
# # Print out the configuration files.
# if log_cfg:
# sum_keys = ['pipeline.name', 'data.name', 'exp_name']
# print_cfg(cfg, sum_keys=sum_keys)
# # Check the validation of experiment name.
# if cfg.get('exp_name') is None:
# get_logger().fatal(f'`exp_name` is not given! You may need to add `exp=<certain_exp>` to the command line.')
# raise ValueError('`exp_name` is not given!')
# # Enter the main function.
# func(cfg)
# return entrypoint_preprocess
# return entrypoint_wrapper
def entrypoint(func):
    '''
    Argument-free shortcut of `entrypoint_with_args`. It extends the `hydra.main` decorator in these parts:
    - Inject some runtime-known arguments, e.g., `proj_root`.
    - Check the validity of the experiment name.
    '''
    decorate = entrypoint_with_args()
    return decorate(func)
def unfold_cfg(
    cfg : Union[DictConfig, Any],
):
    '''
    Rebuild a configuration from its plain-container form, so that all the references
    are resolved and the returned configuration is a freshly created, editable one.

    ### Args
    - cfg: DictConfig or None

    ### Returns
    - cfg: DictConfig or None
        - None is returned as it is.
    '''
    if cfg is not None:
        # Round trip: config -> resolved python container -> new config.
        plain = OmegaConf.to_container(cfg, resolve=True)
        return OmegaConf.create(plain)
    return None
def recursively_simplify_cfg(
    node      : DictConfig,
    hide_misc : bool = True,
):
    '''
    Simplify a configuration node in place for display and return the (possibly replaced) node.

    ### Args
    - node: DictConfig, ListConfig or any leaf value
        - Leaf values are returned untouched.
    - hide_misc: bool, default True
        - If True, the keys `_hub_`, `hydra` and `job_logging` are deleted from every visited mapping.

    ### Returns
    - The simplified node. A non-empty list of basic elements is replaced by a one-line string
      (prefixed with `*`) when that string is shorter than 320 characters.
    '''
    if isinstance(node, DictConfig):
        misc_keys = ('_hub_', 'hydra', 'job_logging')
        # Snapshot the keys first since entries are deleted while walking.
        for key in list(node.keys()):
            # We delete some terms that are not commonly concerned.
            if hide_misc and key in misc_keys:
                delattr(node, key)
                continue
            node[key] = recursively_simplify_cfg(node[key], hide_misc)
        return node

    if not isinstance(node, ListConfig):
        return node

    elements = list(node) if len(node) > 0 else []
    has_nested = any(isinstance(e, (DictConfig, ListConfig)) for e in elements)
    if elements and not has_nested:
        # We fold all lists of basic elements (int, float, ...) into a single line if possible.
        one_liner = '*' + str(elements)
        return one_liner if len(one_liner) < 320 else node

    for idx in range(len(node)):
        node[idx] = recursively_simplify_cfg(node[idx], hide_misc)
    return node
@rank_zero_only
def print_cfg(
    cfg      : Optional[DictConfig],
    title    : str = 'cfg',
    sum_keys : Optional[List[str]] = None,
    show_all : bool = False,
):
    '''
    Print configuration files using rich.

    ### Args
    - cfg: DictConfig or None
        - If None, only a tree holding the text `None` is printed.
    - title: str, default 'cfg'
        - The title shown at the root of the printed tree.
    - sum_keys: List[str] or None, default None
        - Dotted keys (e.g., `pipeline.name`) to be looked up in the configuration files. The found
          ones will be printed in the summary part, the missing ones are reported with a warning.
        - None is treated as an empty list.
    - show_all: bool, default False
        - If False, the configuration is passed through `recursively_simplify_cfg()` before printing.
    '''
    theme = 'coffee'
    style = 'dim'
    tf_dict = { True: '◼', False: '◻' }
    print_setting = f'<< {tf_dict[show_all]} SHOW_ALL >>'
    tree = rich.tree.Tree(f'⌾ {title} - {print_setting}', style=style, guide_style=style)

    if cfg is None:
        tree.add('None')
        rich.print(tree)
        return

    # Clone a new one to avoid changing the original configuration files.
    cfg = unfold_cfg(cfg.copy())
    if not show_all:
        cfg = recursively_simplify_cfg(cfg)

    cfg_yaml = OmegaConf.to_yaml(cfg)
    tree.add(rich.syntax.Syntax(cfg_yaml, 'yaml', theme=theme, line_numbers=True))

    # Add a summary containing only the information that is commonly concerned.
    if sum_keys:
        concerned = {}
        for k_str in sum_keys:
            tgt = cfg
            for k in k_str.split('.'):
                # Only mappings can be descended into. Hitting a scalar or a list in the middle
                # of the path means the key doesn't exist, rather than being an error.
                tgt = tgt.get(k) if isinstance(tgt, DictConfig) else None
                if tgt is None:
                    break
            if tgt is not None:
                concerned[k_str] = tgt
            else:
                get_logger().warning(f'Key `{k_str}` is not found in the configuration files.')
        tree.add(rich.syntax.Syntax(OmegaConf.to_yaml(concerned), 'yaml', theme=theme))

    rich.print(tree)
@rank_zero_only
def _log_exp_info(
    func_name : str,
    overrides : List[str],
):
    '''
    Log which experiment entry is called and the overrides it is called with.

    ### Args
    - func_name: str
        - The name of the experiment entry function.
    - overrides: List[str]
        - The override strings to be reported.
    '''
    get_logger(brief=True).info(f'Exp entry `{func_name}` is called with overrides: {overrides}')