Spaces:
Running
on
Zero
Running
on
Zero
File size: 5,624 Bytes
e85fecb |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 |
"""
Copied from RT-DETR (https://github.com/lyuwenyu/RT-DETR)
Copyright(c) 2023 lyuwenyu. All Rights Reserved.
"""
import functools
import importlib
import inspect
from collections import defaultdict
from typing import Any, Dict, List, Optional
GLOBAL_CONFIG = defaultdict(dict)


def register(dct: Any = GLOBAL_CONFIG, name=None, force=False):
    """
    Return a decorator that registers a function or class into ``dct``.

    Args:
        dct: registry target. If it is a dict, the decorated object is stored
            as a key-value pair; if it is a class, a decorated function is
            attached to it as an attribute. Class registration stores the
            schema via item assignment, so it requires a dict-like ``dct``.
        name: registration key; defaults to the decorated object's ``__name__``.
        force: if True, overwrite an existing entry with the same key instead
            of failing.

    Returns:
        The decorator. Applied to a function it stores and returns a
        ``functools.wraps`` wrapper; applied to a class it stores
        ``extract_schema(cls)`` and returns the class unchanged.

    Raises:
        AssertionError: if the key is already registered and ``force`` is False.
        AttributeError: if a function is registered into something that is
            neither a dict nor a class.
        ValueError: if the decorated object is neither a function nor a class.
    """

    def decorator(foo):
        register_name = foo.__name__ if name is None else name
        if not force:
            # Check the key that will actually be written, so a custom `name`
            # is protected against duplicates as well.
            if inspect.isclass(dct):
                assert not hasattr(dct, register_name), f"module {dct.__name__} has {register_name}"
            else:
                assert register_name not in dct, f"{register_name} has been already registered"

        if inspect.isfunction(foo):

            @functools.wraps(foo)
            def wrap_func(*args, **kwargs):
                return foo(*args, **kwargs)

            if isinstance(dct, dict):
                dct[register_name] = wrap_func
            elif inspect.isclass(dct):
                setattr(dct, register_name, wrap_func)
            else:
                raise AttributeError(
                    f"Cannot register {register_name}: dct must be a dict or a class, got {type(dct)}"
                )
            return wrap_func

        elif inspect.isclass(foo):
            dct[register_name] = extract_schema(foo)

        else:
            raise ValueError(f"Do not support {type(foo)} register")

        return foo

    return decorator
def extract_schema(module: type):
    """
    Describe a class's constructor as a registry schema.

    Args:
        module (type): the class to inspect.

    Return:
        Dict holding the bookkeeping keys ``_name`` (the class ``__name__``),
        ``_pymodule`` (the module object named by ``module.__module__``),
        ``_inject`` / ``_share`` (the class attributes ``__inject__`` /
        ``__share__``, each defaulting to ``[]``) and ``_kwargs`` (parameter
        name -> default value), followed by one key per positional
        ``__init__`` parameter other than ``self`` holding that same value.
        Parameters without a default map to ``None``. Keyword-only
        parameters are not included (only ``args`` of the argspec is read).

    Raises:
        AssertionError: if a parameter listed in ``__share__`` has no default.
    """
    spec = inspect.getfullargspec(module.__init__)
    params = [p for p in spec.args if p != "self"]
    # Defaults line up with the *last* len(defaults) positional parameters.
    first_default = len(params) - len(spec.defaults or ())

    schema = {
        "_name": module.__name__,
        "_pymodule": importlib.import_module(module.__module__),
        "_inject": getattr(module, "__inject__", []),
        "_share": getattr(module, "__share__", []),
        "_kwargs": {},
    }
    defaults_map = schema["_kwargs"]

    for idx, param in enumerate(params):
        has_default = idx >= first_default
        is_shared = param in schema["_share"]
        if is_shared:
            assert has_default, "share config must have default value."
        value = spec.defaults[idx - first_default] if (has_default or is_shared) else None
        schema[param] = value
        defaults_map[param] = value
    return schema
def _reset_schema(schema: dict, overrides: dict):
    """
    Reset a schema to its defaults, apply ``overrides``, and pop ``type``.

    Mutates ``schema`` in place: every non-private key (one not starting with
    "_") is removed, the defaults stored in ``schema["_kwargs"]`` are
    restored, then ``overrides`` is merged on top. ``create`` passes the
    schema object stored in the registry, so the overrides remain there until
    the next reset.

    Returns:
        The value of the ``type`` key carried in by ``overrides``.
    """
    for key in [k for k in schema if not k.startswith("_")]:
        del schema[key]
    schema.update(schema["_kwargs"])  # restore default args
    schema.update(overrides)  # load config args
    return schema.pop("type")  # pop extra key `type` (from overrides)


def create(type_or_name, global_cfg=GLOBAL_CONFIG, **kwargs):
    """
    Instantiate a registered class from its entry in ``global_cfg``.

    Args:
        type_or_name: registry key (str) or a class whose ``__name__`` is the key.
        global_cfg: mapping of keys to schemas (as built by ``extract_schema``),
            ``{"type": <key>, **args}`` config nodes, or already-built objects.
        **kwargs: extra constructor arguments that override the stored values.
            On the direct-schema path they apply to this call only; on the
            ``type`` path they are merged into the target schema held in
            ``global_cfg`` (and therefore persist there).

    Returns:
        The stored object itself if the entry already is an object (anything
        with ``__dict__``); otherwise a new instance built from the schema after
        resolving ``_share`` entries (a top-level ``global_cfg`` value wins)
        and ``_inject`` entries (a registry key or a ``{"type": ...}`` dict,
        built recursively).

    Raises:
        AssertionError: if ``type_or_name`` is neither a class nor a str.
        ValueError: if the key, a ``type`` target, or an inject dependency is
            not registered, or an inject spec has an unsupported form.
    """
    # isinstance (not an exact type() match) so classes with a metaclass,
    # e.g. abc.ABC subclasses, are accepted.
    assert isinstance(type_or_name, (type, str)), "create should be modules or name."
    name = type_or_name if isinstance(type_or_name, str) else type_or_name.__name__

    if name not in global_cfg:
        raise ValueError("The module {} is not registered".format(name))
    cfg = global_cfg[name]
    if hasattr(cfg, "__dict__"):
        # An already-built object is stored under this key: reuse it.
        return cfg

    if isinstance(cfg, dict) and "type" in cfg:
        # Config node of the form {"type": <registered key>, **args}.
        target = cfg["type"]
        # Check before indexing: a defaultdict registry would otherwise
        # silently insert an empty entry for an unknown type.
        if target not in global_cfg:
            raise ValueError("The module {} is not registered".format(target))
        name = _reset_schema(global_cfg[target], {**cfg, **kwargs})
        return create(name, global_cfg)

    # Look the class up by the name recorded in the schema rather than the
    # registry key, so entries registered under a custom `name` still resolve.
    module = getattr(cfg["_pymodule"], cfg["_name"])
    # Caller kwargs override stored values for this call only; the registered
    # schema itself is left untouched.
    params = {**cfg, **kwargs}
    module_kwargs = dict(params)

    # shared var: a top-level value in global_cfg takes precedence
    for k in cfg["_share"]:
        module_kwargs[k] = global_cfg[k] if k in global_cfg else params[k]

    # inject
    for k in cfg["_inject"]:
        spec = params[k]
        if spec is None:
            continue
        if isinstance(spec, str):
            if spec not in global_cfg:
                raise ValueError(f"Missing inject config of {spec}.")
            dep = global_cfg[spec]
            # Build through the registry key itself so custom-named entries and
            # `type`-style config nodes resolve; non-dict entries go in as-is.
            module_kwargs[k] = create(spec, global_cfg) if isinstance(dep, dict) else dep
        elif isinstance(spec, dict):
            if "type" not in spec:
                raise ValueError("Missing inject for `type` style.")
            dep_type = str(spec["type"])
            if dep_type not in global_cfg:
                raise ValueError(f"Missing {dep_type} in inspect stage.")
            dep_name = _reset_schema(global_cfg[dep_type], spec)
            module_kwargs[k] = create(dep_name, global_cfg)
        else:
            raise ValueError(f"Inject does not support {spec}")

    # Drop bookkeeping entries (`_name`, `_pymodule`, ...) before construction.
    # Unknown keys are not filtered here; the constructor is left to reject them.
    module_kwargs = {k: v for k, v in module_kwargs.items() if not k.startswith("_")}
    return module(**module_kwargs)
|