File size: 8,552 Bytes
5ac1897
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
import sys
import json
import rich
import rich.text
import rich.tree
import rich.syntax
import hydra
from typing import List, Optional, Union, Any
from pathlib import Path
from omegaconf import OmegaConf, DictConfig, ListConfig
from pytorch_lightning.utilities import rank_zero_only

from lib.info.log import get_logger

from .proj_manager import ProjManager as PM


def get_PM_info_dict():
    ''' Build an OmegaConf config exposing the ProjManager paths under the `_pm_` key. '''
    # Paths are stringified so that the resulting config only holds plain strings.
    pm_paths = {}
    pm_paths['root']    = str(PM.root)
    pm_paths['inputs']  = str(PM.inputs)
    pm_paths['outputs'] = str(PM.outputs)
    return OmegaConf.create({'_pm_': pm_paths})


def get_PM_info_list():
    ''' Build the ProjManager paths as a list of `_pm_.<name>=<path>` override strings. '''
    pm_paths = (
        ('root',    PM.root),
        ('inputs',  PM.inputs),
        ('outputs', PM.outputs),
    )
    # A fresh list is returned on every call, so callers are free to extend it.
    return [f'_pm_.{name}={str(path)}' for name, path in pm_paths]


def entrypoint_with_args(*args, log_cfg=True, **kwargs):
    '''
    This decorator factory extends the `hydra.main` decorator in these parts:
    - Inject the ProjManager paths (`_pm_.root`, `_pm_.inputs`, `_pm_.outputs`) as overrides.
    - Enable additional arguments that needn't be specified in command line.
        - Positional arguments are added to the command line arguments directly, so make sure they are valid.
            - e.g., `exp=<...>`, `+extra=<...>`, etc.
        - Key-specified arguments have the same effect as command line arguments {k}={v}.
    - Check that the experiment name (`exp_name`) is given.

    ### Args
    - *args: str
        - Extra override strings, placed before the real command line arguments.
    - log_cfg: bool, default True
        - If True (and the config doesn't set `show_cfg` to a falsy value), print the config before running.
    - **kwargs:
        - Extra overrides, each one is turned into `{k}={v}`.

    ### Returns
    - A decorator that wraps a function `func(cfg)` into a `hydra.main` entrypoint.

    ### Raises
    - ValueError: raised by the wrapped entrypoint if `exp_name` is missing from the composed config.
    '''

    # Overrides that do not come from the command line: ProjManager info first,
    # then the positional ones, then the key-specified ones.
    injected = get_PM_info_list()
    injected.extend(args)
    injected.extend(f'{k}={v}' for k, v in kwargs.items())

    def entrypoint_wrapper(func):
        # NOTE: `sys.argv` is rewritten here, i.e., at decoration time, since `hydra.main` reads it later.
        # The args from command line have higher priority, so put them in the back.
        # Each command line argument must show up exactly once: a repeated `+key=value`
        # override is rejected by Hydra, so the real arguments must not be appended twice.
        cli_args  = sys.argv[1:]
        overrides = injected + cli_args
        sys.argv  = sys.argv[:1] + overrides
        _log_exp_info(func.__name__, overrides)

        @hydra.main(version_base=None, config_path=str(PM.configs), config_name='base.yaml')
        def entrypoint_preprocess(cfg:DictConfig):
            # Resolve the references and make it editable.
            cfg = unfold_cfg(cfg)

            # Print out the configuration files.
            if log_cfg and cfg.get('show_cfg', True):
                sum_keys = ['output_dir', 'pipeline.name', 'data.name', 'exp_name', 'exp_tag']
                print_cfg(cfg, sum_keys=sum_keys)

            # Check that the experiment name is given.
            if cfg.get('exp_name') is None:
                get_logger(brief=True).fatal('`exp_name` is not given! You may need to add `exp=<certain_exp>` to the command line.')
                raise ValueError('`exp_name` is not given!')

            # Bind config.
            PM.init_with_cfg(cfg)
            try:
                with PM.time_monitor('exp', f'Main part of experiment `{cfg.exp_name}`.'):
                    # Enter the main function.
                    func(cfg)
            finally:
                # Always report the timing, no matter whether the experiment succeeds or not.
                # Exceptions propagate untouched.
                PM.time_monitor.report(level='global')

            # TODO: Wrap a notifier here.

        return entrypoint_preprocess


    return entrypoint_wrapper

    #! This implementation can't dump the config files in default ways. In order to keep c
    # def entrypoint_wrapper(func):
    #     def entrypoint_preprocess():
    #         # Initialize the configuration module.
    #         with hydra.initialize_config_dir(version_base=None, config_dir=str(PM.configs)):
    #             get_logger(brief=True).info(f'Exp entry `{func.__name__}` is called with overrides: {overrides}')
    #             cfg = hydra.compose(config_name='base', overrides=overrides)

    #         cfg4dump_raw = cfg.copy()  # store the folded raw configuration files
    #         # Resolve the references and make it editable.
    #         cfg = unfold_cfg(cfg)

    #         # Print out the configuration files.
    #         if log_cfg:
    #             sum_keys = ['pipeline.name', 'data.name', 'exp_name']
    #             print_cfg(cfg, sum_keys=sum_keys)
    #         # Check the validation of experiment name.
    #         if cfg.get('exp_name') is None:
    #             get_logger().fatal(f'`exp_name` is not given! You may need to add `exp=<certain_exp>` to the command line.')
    #             raise ValueError('`exp_name` is not given!')
    #         # Enter the main function.
    #         func(cfg)
    #     return entrypoint_preprocess
    # return entrypoint_wrapper

def entrypoint(func):
    '''
    Argument-free shortcut of `entrypoint_with_args`: wrap `func` as an experiment
    entrypoint without any extra pre-specified overrides.
    '''
    decorate = entrypoint_with_args()
    return decorate(func)


def unfold_cfg(
    cfg : Union[DictConfig, Any],
):
    '''
    Rebuild a configuration from its plain container form. All the references are
    resolved during the conversion, and a new config object is created from the result.

    ### Args
    - cfg: DictConfig or None

    ### Returns
    - cfg: DictConfig or None
        - None is passed through untouched.
    '''
    if cfg is not None:
        resolved = OmegaConf.to_container(cfg, resolve=True)
        return OmegaConf.create(resolved)
    return None


def recursively_simplify_cfg(
    node      : DictConfig,
    hide_misc : bool = True,
):
    '''
    Simplify a configuration tree in place for display, and return the (maybe replaced) node.

    ### Args
    - node: DictConfig, ListConfig or any leaf value
        - Leaf values are returned as they are.
    - hide_misc: bool, default True
        - If True, keys `_hub_`, `hydra` and `job_logging` are removed from every dict visited.

    ### Returns
    - The simplified node. A non-empty list without nested dict / list is replaced by
      the string `'*' + str(list)` when that string is shorter than 320 characters.
    '''
    misc_keys = ('_hub_', 'hydra', 'job_logging')

    if isinstance(node, DictConfig):
        for key in list(node.keys()):
            # We delete some terms that are not commonly concerned.
            if hide_misc and key in misc_keys:
                delattr(node, key)
            else:
                node[key] = recursively_simplify_cfg(node[key], hide_misc)
        return node

    if not isinstance(node, ListConfig):
        return node

    items = list(node) if len(node) > 0 else []
    has_nested = any(isinstance(x, (DictConfig, ListConfig)) for x in items)

    if items and not has_nested:
        # We fold all lists of basic elements (int, float, ...) into a single line if possible.
        folded = '*' + str(items)
        if len(folded) < 320:
            return folded
        return node

    # Empty lists and lists holding nested containers: simplify every element.
    for idx in range(len(node)):
        node[idx] = recursively_simplify_cfg(node[idx], hide_misc)
    return node


@rank_zero_only
def print_cfg(
    cfg     : Optional[DictConfig],
    title   : str  ='cfg',
    sum_keys: Optional[List[str]] = None,
    show_all: bool = False
):
    '''
    Print configuration files using rich.

    ### Args
    - cfg: DictConfig or None
        - If None, only a tree containing 'None' is printed.
    - title: str, default 'cfg'
        - The title shown at the root of the printed tree.
    - sum_keys: List[str] or None, default None
        - Keys to be printed again in the summary part. Dotted keys (e.g., `pipeline.name`)
          are looked up level by level. Keys that can't be found are reported with a warning.
        - The lookup is performed on the printed (maybe simplified) configuration.
        - If None or empty, no summary is printed.
    - show_all: bool, default False
        - If False, the configuration is simplified by `recursively_simplify_cfg` before
          printing: some hydra supporting configs are hidden and short lists are folded.
    '''

    theme = 'coffee'
    style = 'dim'

    tf_dict = { True: '◼', False: '◻' }
    print_setting = f'<< {tf_dict[show_all]} SHOW_ALL >>'
    tree = rich.tree.Tree(f'⌾ {title} - {print_setting}', style=style, guide_style=style)

    if cfg is None:
        tree.add('None')
        rich.print(tree)
        return

    # Clone a new one to avoid changing the original configuration files.
    cfg = cfg.copy()
    cfg = unfold_cfg(cfg)

    if not show_all:
        cfg = recursively_simplify_cfg(cfg)

    cfg_yaml = OmegaConf.to_yaml(cfg)
    cfg_yaml = rich.syntax.Syntax(cfg_yaml, 'yaml', theme=theme, line_numbers=True)
    tree.add(cfg_yaml)

    # Add a summary containing information only is commonly concerned.
    if sum_keys:
        concerned = {}
        for k_str in sum_keys:
            tgt = cfg
            for k in k_str.split('.'):
                # Only dict nodes can be descended into. A path that runs through a leaf
                # (e.g., a string or a folded list) is treated as "not found" instead of crashing.
                tgt = tgt.get(k) if isinstance(tgt, DictConfig) else None
                if tgt is None:
                    break
            if tgt is not None:
                concerned[k_str] = tgt
            else:
                get_logger().warning(f'Key `{k_str}` is not found in the configuration files.')

        tree.add(rich.syntax.Syntax(OmegaConf.to_yaml(concerned), 'yaml', theme=theme))

    rich.print(tree)


@rank_zero_only
def _log_exp_info(
    func_name : str,
    overrides : List[str],
):
    ''' Log which experiment entry is called and the overrides it is called with. '''
    logger = get_logger(brief=True)
    logger.info(f'Exp entry `{func_name}` is called with overrides: {overrides}')