File size: 5,624 Bytes
e85fecb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
"""

Copied from RT-DETR (https://github.com/lyuwenyu/RT-DETR)

Copyright(c) 2023 lyuwenyu. All Rights Reserved.

"""

import functools
import importlib
import inspect
from collections import defaultdict
from typing import Any, Dict, List, Optional

# Default registry shared by `register` and `create`.
# `register` stores classes here as their `extract_schema` dict and functions
# as their `functools.wraps` wrapper. `create` also reads `__share__` values
# from it; those are presumably populated by config loading elsewhere (verify).
# Because this is a defaultdict, indexing an unknown key inserts an empty dict,
# so prefer `in` checks before indexing.
GLOBAL_CONFIG = defaultdict(dict)


def register(dct: Any = GLOBAL_CONFIG, name=None, force=False):
    """Return a decorator that registers a function or a class.

    Args:
        dct: Registry target. If it is a dict, the decorated object is stored
            in it as a key-value pair. If it is a class, a decorated function
            is attached to it as an attribute.
        name: Key (or attribute name) to register under. Defaults to the
            decorated object's ``__name__``.
        force: If True, skip the duplicate-registration check and overwrite
            any existing entry.

    Functions are wrapped with ``functools.wraps`` and the wrapper is both
    registered and returned. Classes are stored as their ``extract_schema``
    dict and returned unchanged.

    Raises:
        AssertionError: if the key is already taken and ``force`` is False.
        AttributeError: if a function is registered into a ``dct`` that is
            neither a dict nor a class.
        ValueError: if the decorated object is neither a function nor a class.
    """

    def decorator(foo):
        register_name = foo.__name__ if name is None else name
        if not force:
            # Validate the key that is actually written below, so a ``name=``
            # alias is checked rather than the object's own ``__name__``.
            if inspect.isclass(dct):
                assert not hasattr(dct, register_name), f"module {dct.__name__} has {register_name}"
            else:
                assert register_name not in dct, f"{register_name} has been already registered"

        if inspect.isfunction(foo):

            @functools.wraps(foo)
            def wrap_func(*args, **kwargs):
                return foo(*args, **kwargs)

            if isinstance(dct, dict):
                dct[register_name] = wrap_func
            elif inspect.isclass(dct):
                setattr(dct, register_name, wrap_func)
            else:
                raise AttributeError(
                    f"Cannot register {register_name}: dct must be a dict or a class, got {type(dct)}"
                )
            return wrap_func

        if inspect.isclass(foo):
            dct[register_name] = extract_schema(foo)
            return foo

        raise ValueError(f"Do not support {type(foo)} register")

    return decorator


def extract_schema(module: type):
    """Build the registry schema for a class from its ``__init__`` signature.

    Args:

        module (type): class whose ``__init__`` positional parameters
            (``self`` excluded) describe its configurable arguments.

    Return:

        Dict with, in insertion order:
        ``_name`` (the class name), ``_pymodule`` (the imported module named by
        ``module.__module__``), ``_inject`` / ``_share`` (the class's
        ``__inject__`` / ``__share__`` attributes, or ``[]``), ``_kwargs``
        (parameter -> default value, ``None`` when a parameter has no default),
        followed by one top-level key per parameter holding that same value.

    Raises:

        AssertionError: if a parameter listed in ``__share__`` has no default.

    """
    spec = inspect.getfullargspec(module.__init__)
    params = [p for p in spec.args if p != "self"]
    defaults = spec.defaults or ()
    # Defaults align with the tail of the parameter list.
    first_default = len(params) - len(defaults)

    schema = {
        "_name": module.__name__,
        "_pymodule": importlib.import_module(module.__module__),
        "_inject": getattr(module, "__inject__", []),
        "_share": getattr(module, "__share__", []),
        "_kwargs": {},
    }
    for idx, param in enumerate(params):
        has_default = idx >= first_default
        if param in schema["_share"]:
            assert has_default, "share config must have default value."
        default = defaults[idx - first_default] if has_default else None
        schema[param] = default
        schema["_kwargs"][param] = default

    return schema


def _reset_schema(schema: dict, *overrides: dict) -> Any:
    """Reset a class schema in place to its defaults, apply overrides, and pop ``type``.

    Every non-underscore key of ``schema`` is deleted, the defaults stored in
    ``schema["_kwargs"]`` are restored, then each mapping in ``overrides`` is
    applied in order. Finally the ``type`` key is removed and returned
    (``KeyError`` if no ``type`` key is present at that point).
    """
    for key in [k for k in schema if not k.startswith("_")]:
        del schema[key]
    schema.update(schema["_kwargs"])  # restore default args
    for override in overrides:
        schema.update(override)  # load config args
    return schema.pop("type")


def create(type_or_name, global_cfg=GLOBAL_CONFIG, **kwargs):
    """Instantiate a registered module from its schema in ``global_cfg``.

    Args:
        type_or_name: registered name (str), or a class whose ``__name__`` is
            used as the registry key.
        global_cfg: registry to read from; defaults to ``GLOBAL_CONFIG``.
        **kwargs: config overrides for this call. On the ``type`` path they are
            merged into the target schema along with the config dict. On the
            direct path they override the schema's stored values for this call
            only, and the registry entry itself is not modified. Either way they
            go through the same ``__share__`` / ``__inject__`` handling as
            configured values.

    Resolution:
        * If ``global_cfg[name]`` has a ``__dict__`` (plain dict schemas do not,
          e.g. an already-built object), it is returned as-is.
        * If it is a dict with a ``type`` key, the schema registered under that
          type is reset to its defaults, updated with the dict and ``kwargs``,
          and created recursively.
        * Otherwise it is treated as a class schema:
          - ``__share__`` parameters take their value from ``global_cfg`` when
            present there.
          - ``__inject__`` parameters that name a registry entry (str) or give a
            ``{"type": ...}`` dict are created recursively.
          - Underscore-prefixed keys are dropped before calling the class.

    Raises:
        ValueError: if a name/type is not registered, or an inject entry is
            missing or malformed.
    """
    assert isinstance(type_or_name, (type, str)), "create should be modules or name."

    name = type_or_name if isinstance(type_or_name, str) else type_or_name.__name__

    if name not in global_cfg:
        raise ValueError("The module {} is not registered".format(name))

    cfg = global_cfg[name]
    if hasattr(cfg, "__dict__"):
        return cfg

    if isinstance(cfg, dict) and "type" in cfg:
        type_name = cfg["type"]
        # Check explicitly: indexing a defaultdict registry would silently
        # insert an empty entry for an unknown type and then fail with KeyError.
        if type_name not in global_cfg:
            raise ValueError("The module {} is not registered".format(type_name))
        resolved_name = _reset_schema(global_cfg[type_name], cfg, kwargs)
        return create(resolved_name, global_cfg)

    if kwargs:
        # Per-call overrides; copy so the shared registry entry stays untouched.
        cfg = {**cfg, **kwargs}

    # Look the class up by its real name, since ``name`` may be a registration alias.
    module = getattr(cfg["_pymodule"], cfg["_name"])
    module_kwargs = dict(cfg)

    # Shared parameters: a value present in the global registry wins.
    for k in cfg["_share"]:
        module_kwargs[k] = global_cfg[k] if k in global_cfg else cfg[k]

    # Injected parameters: build the referenced dependency.
    for k in cfg["_inject"]:
        spec = cfg[k]

        if spec is None:
            continue

        if isinstance(spec, str):
            if spec not in global_cfg:
                raise ValueError(f"Missing inject config of {spec}.")
            target = global_cfg[spec]
            # Create by registry key (not the schema's ``_name``) so aliases and
            # ``type``-style config dicts resolve as well.
            module_kwargs[k] = create(spec, global_cfg) if isinstance(target, dict) else target

        elif isinstance(spec, dict):
            if "type" not in spec:
                raise ValueError("Missing inject for `type` style.")
            inject_type = str(spec["type"])
            if inject_type not in global_cfg:
                raise ValueError(f"Missing {inject_type} in inspect stage.")
            inject_name = _reset_schema(global_cfg[inject_type], spec)
            module_kwargs[k] = create(inject_name, global_cfg)

        else:
            raise ValueError(f"Inject does not support {spec}")

    # Underscore keys are schema metadata, not constructor arguments.
    module_kwargs = {k: v for k, v in module_kwargs.items() if not k.startswith("_")}

    return module(**module_kwargs)