File size: 3,216 Bytes
e85fecb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
"""

Copied from RT-DETR (https://github.com/lyuwenyu/RT-DETR)

Copyright(c) 2023 lyuwenyu. All Rights Reserved.

"""

import copy
import os
from typing import Any, Dict, List, Optional

import yaml

from .workspace import GLOBAL_CONFIG

# Names exported by a star-import of this module.
__all__ = [
    "load_config",
    "merge_config",
    "merge_dict",
    "parse_cli",
]


# Top-level YAML key whose value lists base config files; `load_config`
# loads and merges those files before applying the file that includes them.
INCLUDE_KEY = "__include__"


def load_config(file_path, cfg=None):
    """Load a YAML config file, resolving its ``__include__`` base files first.

    Base files listed under ``INCLUDE_KEY`` are loaded recursively, in order,
    and merged into ``cfg``; the content of ``file_path`` is merged last so
    that it overrides values coming from its bases.

    Args:
        file_path: Path to a ``.yml`` / ``.yaml`` file.
        cfg: Optional dict to merge into (mutated in place). When omitted, a
            fresh dict is created for this call.

    Returns:
        The merged config dict (``cfg`` itself), or a new empty dict when
        ``file_path`` contains no YAML document.

    Raises:
        AssertionError: If ``file_path`` does not end in ``.yml`` / ``.yaml``.
        OSError: If ``file_path`` or one of its base files cannot be opened.
    """
    # A `dict()` default would be built once at definition time and shared by
    # every call, leaking keys from one loaded config into the next.
    if cfg is None:
        cfg = {}

    _, ext = os.path.splitext(file_path)
    assert ext in [".yml", ".yaml"], "only support yaml files"

    with open(file_path) as f:
        # NOTE(security): yaml.Loader can construct arbitrary Python objects;
        # only load config files from trusted sources.
        file_cfg = yaml.load(f, Loader=yaml.Loader)

    if file_cfg is None:
        return {}

    if INCLUDE_KEY in file_cfg:
        # Relative includes are resolved against the including file's directory.
        base_dir = os.path.dirname(file_path)
        for base_yaml in list(file_cfg[INCLUDE_KEY]):
            if base_yaml.startswith("~"):
                base_yaml = os.path.expanduser(base_yaml)

            if not os.path.isabs(base_yaml):
                base_yaml = os.path.join(base_dir, base_yaml)

            # The recursive call merges the base file into `cfg` in place, so
            # no extra merge step (and no separate open of the file) is needed.
            load_config(base_yaml, cfg)

    return merge_dict(cfg, file_cfg)


def merge_dict(dct, another_dct, inplace=True) -> Dict:
    """Recursively overlay ``another_dct`` onto ``dct`` and return the result.

    When a key exists on both sides and both values are dicts, the two dicts
    are merged recursively; in every other case the value from ``another_dct``
    replaces (or adds) the entry. Values are assigned by reference, not copied.

    Args:
        dct: Destination dict.
        another_dct: Dict whose entries take precedence.
        inplace: If True (default) ``dct`` itself is mutated and returned;
            otherwise a deep copy of ``dct`` is merged into and returned.
    """
    target = dct if inplace else copy.deepcopy(dct)

    def _overlay(dst, src) -> Dict:
        for key in src:
            incoming = src[key]
            mergeable = key in dst and isinstance(dst[key], dict) and isinstance(incoming, dict)
            if not mergeable:
                dst[key] = incoming
                continue
            _overlay(dst[key], incoming)
        return dst

    return _overlay(target, another_dct)


def dictify(s: str, v: Any) -> Dict:
    """Expand a dotted key into nested single-entry dicts ending in ``v``.

    For example ``dictify("a.b", 1)`` gives ``{"a": {"b": 1}}``; a key without
    dots gives ``{s: v}``.
    """
    # Build from the innermost key outwards so each step wraps the previous one.
    nested: Any = v
    for part in reversed(s.split(".")):
        nested = {part: nested}
    return nested


def parse_cli(nargs: List[str]) -> Dict:
    """Parse ``key=value`` command-line overrides into a nested dict.

    Dotted keys are expanded into nested dicts and values are parsed as YAML,
    e.g. ``a.c=3 b=10`` becomes ``{'a': {'c': 3}, 'b': 10}``. Later arguments
    override earlier ones on conflicting keys.

    Args:
        nargs: List of ``key=value`` strings; ``None`` or an empty list yields
            an empty dict.

    Returns:
        The nested dict built from all arguments.

    Raises:
        ValueError: If an argument does not contain ``=``.
    """
    cfg = {}
    if not nargs:
        return cfg

    for s in nargs:
        s = s.strip()
        # Split on the first "=" only, so the value may itself contain "=".
        k, sep, v = s.partition("=")
        if not sep:
            raise ValueError(
                "invalid command-line override {!r}: expected `key=value`".format(s)
            )
        # NOTE(security): yaml.Loader can construct arbitrary Python objects;
        # only pass trusted command-line input here.
        d = dictify(k, yaml.load(v, Loader=yaml.Loader))
        cfg = merge_dict(cfg, d)

    return cfg


def merge_config(cfg, another_cfg=GLOBAL_CONFIG, inplace: bool = False, overwrite: bool = False):
    """Fill ``cfg`` with entries from ``another_cfg`` and return the result.

    Keys missing from ``cfg`` are taken from ``another_cfg`` (assigned by
    reference, not copied). When both sides hold a dict under the same key the
    two dicts are merged recursively. For any other conflict the value already
    in ``cfg`` is kept, unless ``overwrite`` is True.

    Args:
        cfg: Config dict to merge into.
        another_cfg: Config supplying additional entries; defaults to
            ``GLOBAL_CONFIG``.
        inplace: If True ``cfg`` itself is mutated and returned; otherwise
            (default) a deep copy of ``cfg`` is merged into and returned.
        overwrite: If True, conflicting non-dict values are replaced by the
            ones from ``another_cfg``.

    Example:
        cfg1 = load_config('./dfine_r18vd_6x_coco.yml')
        cfg1 = merge_config(cfg1, inplace=True)

        cfg2 = load_config('./dfine_r50vd_6x_coco.yml')
        cfg2 = merge_config(cfg2, inplace=True)

        model1 = create(cfg1['model'], cfg1)
        model2 = create(cfg2['model'], cfg2)
    """
    merged = cfg if inplace else copy.deepcopy(cfg)

    def _fill(dst, src):
        for key in src:
            if key in dst:
                if isinstance(dst[key], dict) and isinstance(src[key], dict):
                    _fill(dst[key], src[key])
                elif overwrite:
                    dst[key] = src[key]
            else:
                dst[key] = src[key]

    _fill(merged, another_cfg)
    return merged