Spaces:
Runtime error
Runtime error
""" | |
Copyright (c) Meta Platforms, Inc. and affiliates. | |
All rights reserved. | |
This source code is licensed under the license found in the | |
LICENSE file in the root directory of this source tree. | |
""" | |
import copy | |
import importlib | |
import inspect | |
import logging | |
from dataclasses import dataclass, field | |
from typing import Any, Dict, Optional | |
from attrdict import AttrDict | |
from torch import nn | |
logger: logging.Logger = logging.getLogger(__name__) | |
def load_module(
    module_name: str, class_name: Optional[str] = None, silent: bool = False
):
    """
    Import a module, or fetch one attribute (typically a class) from it.

    The module is imported as ``visualize.<module_name>``.

    Example:

    .. code-block:: python

        eye_geo = load_module("path.to.module", "ClassName")

    Args:
        module_name: str
            Dotted path of the module below the ``visualize`` package.
            Ex: ``utils.module_loader``

        class_name: str
            Name of the attribute to fetch from the imported module. When empty
            or None, the module object itself is returned.

        silent: bool
            If True, return None instead of raising when the module or the
            attribute cannot be found.

    Returns:
        object:
            The imported module, the requested attribute, or None (silent mode).

    Raises:
        ModuleNotFoundError: the import failed and ``silent`` is False.
        AttributeError: the attribute lookup failed and ``silent`` is False.
    """
    try:
        imported = importlib.import_module(f"visualize.{module_name}")
        return getattr(imported, class_name) if class_name else imported
    except ModuleNotFoundError:
        if not silent:
            # Log with traceback, then let the original exception propagate.
            logger.error(f"Module not found: {module_name}", exc_info=True)
            raise
        return None
    except AttributeError:
        if not silent:
            logger.error(
                f"Can not locate class: {class_name} in {module_name}.", exc_info=True
            )
            raise
        return None
# pyre-ignore[3] | |
def make_module(mod_config: AttrDict, *args: Any, **kwargs: Any) -> Any:
    """
    A shortcut for making an object given the config and arguments.

    Args:
        mod_config: AttrDict
            Config. Should contain keys: module_name, class_name, and optionally args.
            Every key other than ``args`` is forwarded to ``load_module`` as a
            keyword argument.

        *args
            Positional arguments passed to the loaded class.

        **kwargs
            Default keyword arguments. Overwritten by content from mod_config.args

    Returns:
        object:
            The result of calling the loaded class with the merged arguments.
    """
    mod_config_dict = dict(mod_config)
    # Copy the configured args: ``dict(mod_config)`` is only a shallow copy, so
    # updating the popped dict in place would write the defaults back into the
    # caller's config and leak them into later calls. ``or {}`` tolerates a
    # config that spells ``args: null``.
    mod_args = dict(mod_config_dict.pop("args", None) or {})
    for key, value in kwargs.items():
        # Values from the config win; kwargs only fill in the gaps.
        mod_args.setdefault(key, value)

    mod_class = load_module(**mod_config_dict)
    return mod_class(*args, **mod_args)
def get_full_name(mod: object) -> str:
    """
    Build the dotted name ``<module>.<qualified class name>`` of an object's class.
    """
    owner = mod.__class__
    return "{}.{}".format(owner.__module__, owner.__qualname__)
# pyre-fixme[3]: Return type must be annotated. | |
def load_class(class_name: str):
    """
    Load a class given its full dotted name.

    Example:

    .. code-block:: python

        class_instance = load_class("module.path.ClassName")

    Args:
        class_name: str
            Dotted module path followed by the class name. The text after the
            last dot is used as the class name; everything before it is the
            module name handed to ``load_module``.

    Returns:
        Whatever ``load_module`` returns for the split name.
    """
    # Split once from the right: [module_path, class] or just [name] if no dot.
    pieces = class_name.rsplit(".", 1)
    # This is a false-positive, pyre doesn't understand rsplit(..., 1) can only have 1-2 elements
    # pyre-fixme[6]: In call `load_module`, for 1st positional only parameter expected `bool` but got `str`.
    return load_module(*pieces)
@dataclass
class ObjectSpec:
    """
    Specification of an object to instantiate: which class, and with what arguments.

    Attributes:
        class_name: str
            The full class name including the full path of the module relative to
            the root directory or just the name of the class within the module to
            load when module name is also provided.

        module_name: str
            The full path of the module relative to the root directory. Ex: ``utils.module_loader``

        kwargs: dict
            Keyword arguments for initializing the object.
    """

    class_name: str
    module_name: Optional[str] = None
    # default_factory gives every instance its own dict; this only takes effect
    # because the class is a dataclass.
    kwargs: Dict[str, Any] = field(default_factory=dict)
# pyre-fixme[3]: Return type must be annotated. | |
def load_object(spec: ObjectSpec, **kwargs: Any):
    """
    Build an object from a spec plus extra initialization arguments.

    Example:

    .. code-block:: python

        my_model = load_object(ObjectSpec(**my_model_config), in_channels=3)

    Args:
        spec: ObjectSpec
            Names the class to load and carries its default init arguments.

        kwargs: dict
            Extra keyword arguments; on a key clash these take precedence over
            ``spec.kwargs``.

    Returns:
        The object produced by calling the loaded class with the merged arguments.
    """
    # Without a module name, class_name is expected to hold the full dotted path.
    object_class = (
        load_class(spec.class_name)
        if spec.module_name is None
        else load_module(spec.module_name, spec.class_name)
    )

    # Debug message for overriding the object spec
    for key, value in kwargs.items():
        if key in spec.kwargs:
            logger.debug(f"Overriding {key} as {value} in {spec}.")

    init_kwargs = dict(spec.kwargs)
    init_kwargs.update(kwargs)
    return object_class(**init_kwargs)
# From DaaT merge. Fix here T145981161 | |
# pyre-fixme[2]: parameter must be annotated. | |
# pyre-fixme[3]: Return type must be annotated. | |
def load_from_config(config: AttrDict, **kwargs): | |
"""Instantiate an object given a config and arguments.""" | |
assert "class_name" in config and "module_name" not in config | |
config = copy.deepcopy(config) | |
class_name = config.pop("class_name") | |
object_class = load_class(class_name) | |
return object_class(**config, **kwargs) | |
# From DaaT merge. Fix here T145981161 | |
# pyre-fixme[2]: parameter must be annotated. | |
# pyre-fixme[3]: Return type must be annotated. | |
def forward_parameter_names(module):
    """Get the argument names of the forward pass for the module.

    Args:
        module: a class with `forward()` method

    Returns:
        List of parameter names of ``module.forward``, in declaration order,
        with the first parameter of the signature dropped.

    Raises:
        ValueError: if ``forward`` declares ``*args`` or ``**kwargs`` after the
            first parameter.
    """
    signature = inspect.signature(module.forward)
    # Drop the leading parameter (``self`` when ``forward`` is looked up on a class).
    # NOTE(review): for a bound method `inspect.signature` already omits `self`,
    # so passing an instance would drop the first real argument - confirm callers.
    params = list(signature.parameters.values())[1:]

    # Parameter names never contain asterisks, so variadics must be detected
    # through the parameter kind rather than by name.
    variadic_kinds = (
        inspect.Parameter.VAR_POSITIONAL,
        inspect.Parameter.VAR_KEYWORD,
    )
    names = []
    for p in params:
        if p.kind in variadic_kinds:
            raise ValueError("*args and **kwargs are not supported")
        names.append(p.name)
    return names
# From DaaT merge. Fix here T145981161 | |
def build_optimizer(config, model):
    """Build an optimizer given optimizer config and a model.

    Args:
        config: DictConfig
            Optimizer config; it is deep-copied, so the caller's object is not
            modified. An optional ``per_module`` entry maps a submodule name to
            the options of that submodule's parameter group.
        model: nn.Module|Dict[str,nn.Module]
            With an ``nn.Module``, ``per_module`` names are looked up as
            attributes of the model. With a dict, ``per_module`` is required
            and its names are looked up as dict keys.

    Returns:
        The object created by ``load_from_config(config, params=params)``.
    """
    config = copy.deepcopy(config)
    if isinstance(model, nn.Module):
        if "per_module" in config:
            params = []
            for name, value in config.per_module.items():
                if not hasattr(model, name):
                    logger.warning(
                        f"model {model.__class__} does not have a submodule {name}, skipping"
                    )
                    continue
                # One parameter group per configured submodule.
                params.append(
                    dict(
                        params=getattr(model, name).parameters(),
                        **value,
                    )
                )

            # Warn about direct children that own parameters but are not configured.
            defined_names = set(config.per_module.keys())
            for name, module in model.named_children():
                n_params = len(list(module.named_parameters()))
                if name not in defined_names and n_params:
                    logger.warning(
                        f"not going to optimize module {name} which has {n_params} parameters"
                    )
            # Remove the entry so it is not forwarded to the optimizer constructor.
            config.pop("per_module")
        else:
            params = model.parameters()
    else:
        assert "per_module" in config
        assert isinstance(model, dict)
        # NOTE(review): unlike the nn.Module branch, `per_module` is not popped
        # here, so load_from_config forwards it to the optimizer constructor -
        # confirm this is intended.
        # Single pass: `params` is always bound, even when `per_module` is empty.
        params = []
        for name, value in config.per_module.items():
            if name not in model:
                logger.warning(f"not aware of {name}, skipping")
                continue
            params.append(
                dict(
                    params=model[name].parameters(),
                    **value,
                )
            )

    return load_from_config(config, params=params)
# From DaaT merge. Fix here T145981161 | |
class ForwardFilter:
    """Callable wrapper that passes on only the keyword arguments `forward()` names."""

    # pyre-ignore
    def __init__(self, module, optional: bool = False) -> None:
        # `optional` is accepted but not used anywhere in this class.
        # pyre-ignore
        self.module = module
        # Names reported by forward_parameter_names for the wrapped module.
        # pyre-ignore
        self.input_names = {*forward_parameter_names(module)}

    # pyre-ignore
    def __call__(self, **kwargs):
        # Keep only the entries whose key is a known input name.
        accepted = {}
        for name, value in kwargs.items():
            if name in self.input_names:
                accepted[name] = value
        return self.module(**accepted)