import inspect |
|
from abc import ABC |
|
from dataclasses import dataclass, is_dataclass |
|
from typing import List, Optional, Set, Tuple, Union |
|
|
|
import torch |
|
import torch.nn as nn |
|
from hydra.utils import instantiate |
|
from omegaconf import DictConfig, OmegaConf, open_dict |
|
|
|
from nemo.utils import logging, model_utils |
|
|
|
|
|
ADAPTER_REGISTRY = {} |
|
|
|
|
|
@dataclass |
|
class AdapterRegistryInfo: |
|
base_class: type |
|
adapter_class: type |
|
|
|
|
|
base_class_path: str = "" |
|
adapter_class_path: str = "" |
|
|
|
def __post_init__(self): |
|
self.base_class_path = f'{self.base_class.__module__}.{self.base_class.__name__}' |
|
self.adapter_class_path = f'{self.adapter_class.__module__}.{self.adapter_class.__name__}' |
|
|
|
|
|
def register_adapter(base_class: type, adapter_class: type): |
|
""" |
|
Registers a pair (Base class, Adapter class) into the adapter registry, used for de-referencing. |
|
|
|
Args: |
|
base_class: A Class, which is the base class of the object. |
|
adapter_class: A Class, which is the subclass of the base class, and implements the Adapter mixin methods. |
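
    Example (a sketch; ``MyModule`` and ``MyAdapterModule`` are hypothetical classes):

    .. code::

        class MyModule(nn.Module):
            ...

        class MyAdapterModule(MyModule, AdapterModuleMixin):
            ...

        register_adapter(base_class=MyModule, adapter_class=MyAdapterModule)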
|
""" |
|
global ADAPTER_REGISTRY |
|
base_class_path = f'{base_class.__module__}.{base_class.__name__}' |
|
adapter_class_path = f'{adapter_class.__module__}.{adapter_class.__name__}' |
|
|
|
|
|
if base_class_path in ADAPTER_REGISTRY: |
|
raise ValueError(f"`{base_class_path}` has already been added to the adapter registry !") |
|
|
|
|
|
if not issubclass(adapter_class, base_class): |
|
raise ValueError(f"`{adapter_class_path}` is not a sub-class of {base_class_path} !") |
|
|
|
|
|
ADAPTER_REGISTRY[base_class_path] = AdapterRegistryInfo(base_class=base_class, adapter_class=adapter_class) |
|
|
|
|
|
base_class._meta_adapter_class = adapter_class |
|
|
|
|
|
adapter_class._meta_base_class = base_class |
|
|
|
|
|
def get_registered_adapter(cls: Union[str, type]) -> Optional[AdapterRegistryInfo]: |
|
""" |
|
Resolves a provided `cls` (whether str path to class, a registered base or an adapter class) |
|
to obtain the metadata for the adapter. |
|
|
|
Args: |
|
cls: Can be a str (absolute path to a class), a base class or an adapter class (which have already |
|
been registered). |
|
|
|
Returns: |
|
        An AdapterRegistryInfo object if it could be resolved successfully, otherwise None.
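
    Example (a sketch, assuming the hypothetical classes above were registered):

    .. code::

        info = get_registered_adapter(MyModule)  # a class, its adapter subclass, or a str path all resolve
        if info is not None:
            print(info.adapter_class_path)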
|
""" |
|
global ADAPTER_REGISTRY |
|
if isinstance(cls, str): |
|
cls = model_utils.import_class_by_path(cls) |
|
|
|
|
|
if hasattr(cls, '_meta_base_class'): |
|
cls = cls._meta_base_class |
|
|
|
class_path = f'{cls.__module__}.{cls.__name__}' |
|
|
|
|
|
if class_path in ADAPTER_REGISTRY: |
|
return ADAPTER_REGISTRY[class_path] |
|
|
|
return None |
|
|
|
|
|
def _prepare_default_adapter_config(*, global_key: str, meta_key: str, cfg: DictConfig = None) -> DictConfig: |
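    """Ensures that the given adapter config (created fresh when `cfg` is None) contains the nested
    global config / metadata / modules keys, and returns the possibly-extended config."""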
|
if cfg is None: |
|
cfg = OmegaConf.create({}) |
|
|
|
with open_dict(cfg): |
|
if global_key not in cfg: |
|
cfg[global_key] = OmegaConf.create({}) |
|
|
|
if meta_key not in cfg[global_key]: |
|
cfg[global_key][meta_key] = OmegaConf.create({}) |
|
|
|
if 'modules' not in cfg[global_key][meta_key]: |
|
cfg[global_key][meta_key]['modules'] = OmegaConf.create({}) |
|
|
|
return cfg |
|
|
|
|
|
class AdapterModuleMixin(ABC): |
|
""" Generic Adapter Mixin that can augment any torch.nn.Module with Adapter module support. |
|
|
|
This mixin class adds a hierarchical way to add any type of Adapter modules to a pre-existing module. |
|
Since Models are inherently also nn.Module, this mixin can be attached to any Model or Module. |
|
This mixin class adds several utility methods which are utilized or overridden as necessary. |
|
|
|
    An Adapter module is any PyTorch nn.Module that possesses a few properties :
|
|
|
    - Its input and output dimensions are the same, while the hidden dimension need not be the same.
|
- The final layer of the Adapter module is zero-initialized, so that the residual connection to the adapter |
|
yields the original output. |
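
    A minimal sketch of such an adapter (illustrative only, not the required implementation):

    .. code::

        class ResidualLinearAdapter(nn.Module):
            def __init__(self, dim: int, hidden_dim: int):
                super().__init__()
                self.down = nn.Linear(dim, hidden_dim)
                self.up = nn.Linear(hidden_dim, dim)
                # Zero-init the final layer, so the residual connection initially yields the input unchanged
                nn.init.zeros_(self.up.weight)
                nn.init.zeros_(self.up.bias)

            def forward(self, x):
                return self.up(torch.relu(self.down(x)))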
|
|
|
    This mixin adds the following instance variables to the class that inherits it:
|
|
|
- `adapter_layer`: A torch.nn.ModuleDict(), whose keys are the names of the adapter (globally unique), |
|
and values are the Adapter nn.Module(). |
|
- `adapter_cfg`: A OmegaConf DictConfig object that holds the config of the adapters that are initialized. |
|
    - `adapter_name`: A str of the resolved adapter name, which acts as a globally unique key,
      though more than one module may share this name.
|
- `adapter_global_cfg_key`: A str representing a key in the model config that can be provided by the user. |
|
The value resolves to `global_cfg`, and can be overridden via `model.cfg.adapters.global_cfg.*`. |
|
- `adapter_metadata_cfg_key`: A str representing a key in the model config that is used to preserve the |
|
metadata of the adapter config. |
|
|
|
    **Note**: This module is **not** responsible for maintaining its config. Subclasses must ensure the config is
    updated or preserved as needed. It is the responsibility of the subclasses to propagate the most up-to-date
    config to the lower layers.
|
""" |
|
|
|
adapter_global_cfg_key = "global_cfg" |
|
adapter_metadata_cfg_key = "adapter_meta_cfg" |
|
|
|
def add_adapter(self, name: str, cfg: DictConfig): |
|
""" |
|
Add an Adapter module to this module. |
|
|
|
Args: |
|
name: A globally unique name for the adapter. Will be used to access, enable and disable adapters. |
|
            cfg: A DictConfig or Dataclass that contains at the bare minimum `_target_` to instantiate a
|
new Adapter module. |
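
        Example (a sketch; the ``_target_`` class path is illustrative):

        .. code::

            cfg = OmegaConf.create({'_target_': 'my_pkg.adapters.LinearAdapter', 'in_features': 512})
            module.add_adapter(name='my_adapter', cfg=cfg)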
|
""" |
|
if not isinstance(cfg, DictConfig): |
|
cfg = DictConfig(cfg) |
|
|
|
adapter_types = self.get_accepted_adapter_types() |
|
_pass_types = False |
|
if len(adapter_types) > 0: |
|
test = model_utils.import_class_by_path(cfg._target_) |
|
for _type in adapter_types: |
|
|
|
if issubclass(test, _type): |
|
_pass_types = True |
|
break |
|
if not _pass_types: |
|
raise ValueError( |
|
f"Config: \n{OmegaConf.to_yaml(cfg)}\n" |
|
f"It creates adapter class {test} \n" |
|
f"that is not in the list of accepted adapter types.\n" |
|
f"Accepted adapters: {[t for t in adapter_types]}" |
|
) |
|
|
|
|
|
if is_dataclass(cfg): |
|
cfg = OmegaConf.structured(cfg) |
|
|
|
if not isinstance(cfg, DictConfig): |
|
cfg = DictConfig(cfg) |
|
|
|
|
|
if not hasattr(self, 'adapter_layer'): |
|
self.adapter_layer = nn.ModuleDict() |
|
|
|
|
|
if not hasattr(self, 'adapter_cfg'): |
|
self.adapter_cfg = OmegaConf.create({}) |
|
|
|
|
|
_, adapter_name = self.resolve_adapter_module_name_(name) |
|
|
|
|
|
self.adapter_name = adapter_name |
|
|
|
|
|
if adapter_name in self.adapter_layer: |
|
raise ValueError( |
|
f"Adapter with name `{name}` already exists ! Adapter names = {list(self.adapter_layer.keys())}" |
|
) |
|
|
|
|
|
if adapter_name == self.adapter_global_cfg_key: |
|
raise ValueError(f"Adapters cannot have the reserved name : `{self.adapter_global_cfg_key}`") |
|
|
|
|
|
with open_dict(cfg), open_dict(self.adapter_cfg): |
|
adapter_enabled = cfg.pop('enabled', True) |
|
self.adapter_layer[adapter_name] = instantiate(cfg) |
|
|
|
cfg['enabled'] = adapter_enabled |
|
self.adapter_cfg[adapter_name] = cfg |
|
|
|
def is_adapter_available(self) -> bool: |
|
""" |
|
Checks if any Adapter module has been instantiated. |
|
|
|
Returns: |
|
            bool, determining if any Adapter module has been instantiated. Returns True regardless of whether
            the adapters are enabled or disabled; False only if no adapters exist.
|
""" |
|
if hasattr(self, 'adapter_layer'): |
|
return self.adapter_layer is not None and len(self.adapter_layer) > 0 |
|
return False |
|
|
|
def set_enabled_adapters(self, name: Optional[str] = None, enabled: bool = True): |
|
""" |
|
        Updates the internal adapter config, determining if an adapter (or all adapters) are either
|
enabled or disabled. |
|
|
|
A common user pattern would be to disable all adapters (either after adding them, or restoring a model |
|
with pre-existing adapters) and then simply enable one of the adapters. |
|
|
|
.. code:: |
|
|
|
module.set_enabled_adapters(enabled=False) |
|
module.set_enabled_adapters(name=<some adapter name>, enabled=True) |
|
|
|
Args: |
|
name: Optional str. If a str name is given, the config will be updated to the value of `enabled`. |
|
If no name is given, then all adapters will be enabled/disabled. |
|
enabled: Bool, determines if the adapter(s) will be enabled/disabled. |
|
""" |
|
if not self.is_adapter_available(): |
|
raise ValueError("No adapter is available to enable/disable") |
|
|
|
|
|
if name is None: |
|
for key, config in self.adapter_cfg.items(): |
|
|
|
if key == self.adapter_global_cfg_key: |
|
continue |
|
|
|
|
|
self.adapter_cfg[key]['enabled'] = enabled |
|
else: |
|
_, adapter_name = self.resolve_adapter_module_name_(name) |
|
|
|
|
|
if adapter_name == self.adapter_global_cfg_key: |
|
raise ValueError( |
|
f'Cannot set the state of the global config of adapters, ' |
|
f'given name = `{self.adapter_global_cfg_key}`' |
|
) |
|
|
|
|
|
self.adapter_cfg[adapter_name]['enabled'] = enabled |
|
|
|
def get_enabled_adapters(self) -> List[str]: |
|
""" |
|
        Returns a list of all enabled adapter names. The names will always be the resolved names, without
|
module info. |
|
|
|
Returns: |
|
            A list of str names, one for each enabled adapter.
|
""" |
|
if not self.is_adapter_available(): |
|
return [] |
|
|
|
|
|
available_module_names = set([]) |
|
if hasattr(self, 'adapter_layer'): |
|
available_module_names.update(list(self.adapter_layer.keys())) |
|
|
|
|
|
adapter_types = self.get_accepted_adapter_types() |
|
|
|
enabled_adapters = [] |
|
for name, config in self.adapter_cfg.items(): |
|
|
|
if name == self.adapter_global_cfg_key: |
|
continue |
|
|
|
|
|
if name in available_module_names and self.adapter_cfg[name]['enabled']: |
|
|
|
if len(adapter_types) > 0: |
|
module = self.get_adapter_module(name) |
|
|
|
for adapter_type in adapter_types: |
|
if isinstance(module, adapter_type): |
|
enabled_adapters.append(name) |
|
break |
|
|
|
else: |
|
|
|
enabled_adapters.append(name) |
|
|
|
return enabled_adapters |
|
|
|
|
|
|
|
def get_adapter_module(self, name: str): |
|
""" |
|
Gets an adapter module by name if possible, otherwise returns None. |
|
|
|
Args: |
|
name: A str name (resolved or not) corresponding to an Adapter. |
|
|
|
Returns: |
|
            An nn.Module if the name could be resolved and matched, otherwise None.
|
""" |
|
_, name = self.resolve_adapter_module_name_(name) |
|
|
|
if hasattr(self, "adapter_layer"): |
|
return self.adapter_layer[name] if name in self.adapter_layer else None |
|
return None |
|
|
|
def set_accepted_adapter_types(self, adapter_types: List[Union[type, str]]) -> None: |
|
""" |
|
        The module with this mixin can define a list of adapter types that it will accept.
        This method should be called in the module's ``__init__`` to set the adapter types
        that may subsequently be added.

        Args:
            adapter_types: A list of classes (or str paths to classes). Str paths will be imported
                to ensure that the class path is correct.
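
        Example (a sketch; the class and path are illustrative):

        .. code::

            # Inside the module's __init__
            self.set_accepted_adapter_types([LinearAdapter, 'my_pkg.adapters.LinearAdapter'])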
|
""" |
|
|
|
types = [] |
|
for s in adapter_types: |
|
if inspect.isclass(s): |
|
if not issubclass(s, nn.Module): |
|
raise ValueError(f"Attempted to add class ({s}) but is not a subclass of torch.nn.Module") |
|
|
|
types.append(s) |
|
else: |
|
types.append(model_utils.import_class_by_path(s)) |
|
|
|
self._accepted_adapter_types = set(types) |
|
|
|
def get_accepted_adapter_types(self,) -> Set[type]: |
|
""" |
|
Utility function to get the set of all classes that are accepted by the module. |
|
|
|
Returns: |
|
Returns the set of accepted adapter types as classes, otherwise an empty set. |
|
""" |
|
if hasattr(self, '_accepted_adapter_types'): |
|
return self._accepted_adapter_types |
|
else: |
|
return set([]) |
|
|
|
def unfreeze_enabled_adapters(self, freeze_batchnorm: bool = True) -> None: |
|
""" |
|
Utility method to unfreeze only the enabled Adapter module(s). |
|
|
|
A common user pattern is to freeze all the modules (including all the adapters), and then |
|
unfreeze just the required adapters. |
|
|
|
.. code:: |
|
|
|
module.freeze() # only available to nemo.core.NeuralModule ! |
|
module.unfreeze_enabled_adapters() |
|
|
|
Args: |
|
freeze_batchnorm: An optional (and recommended) practice of freezing the updates to the moving average |
|
buffers of any and all BatchNorm*D layers. This is necessary to ensure that disabling all adapters |
|
will precisely yield the original (base) model's outputs. |
|
""" |
|
if freeze_batchnorm: |
|
for mname, module in self.named_modules(): |
|
if isinstance(module, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)): |
|
if hasattr(module, 'weight'): |
|
module.weight.requires_grad_(False) |
|
if hasattr(module, 'bias'): |
|
module.bias.requires_grad_(False) |
|
module.eval() |
|
module.track_running_stats = False |
|
|
|
logging.info(f"Froze module {mname}: {module}") |
|
|
|
adapter_names = set([]) |
|
for module in self.modules(): |
|
if hasattr(module, 'adapter_layer') and module.is_adapter_available(): |
|
for name, config in self.adapter_cfg.items(): |
|
|
|
if name == self.adapter_global_cfg_key: |
|
continue |
|
|
|
|
|
if self.adapter_cfg[name]['enabled'] and name in module.adapter_layer: |
|
|
|
module.adapter_layer[name].train() |
|
|
|
|
|
for pname, param in module.adapter_layer[name].named_parameters(): |
|
param.requires_grad_(True) |
|
|
|
|
|
for mname, module_ in module.adapter_layer[name].named_modules(): |
|
if isinstance(module_, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)): |
|
                                module_.track_running_stats = True
|
logging.info(f"Unfroze adapter module {mname}: {module_}") |
|
|
|
adapter_names.add(name) |
|
|
|
for name in adapter_names: |
|
logging.info(f"Unfrozen adapter : {name}") |
|
|
|
def forward_enabled_adapters(self, input: 'torch.Tensor'): |
|
""" |
|
        Forwards all active adapters one by one with the provided input, chaining the output of each
|
adapter layer to the next. |
|
|
|
Utilizes the implicit merge strategy of each adapter when computing the adapter's output, and |
|
how that output will be merged back with the original input. |
|
|
|
|
|
|
Args: |
|
input: The output tensor of the calling module is the input to the first adapter, whose output |
|
is then chained to the next adapter until all adapters are consumed. |
|
|
|
Returns: |
|
The result tensor, after all active adapters have finished their forward passes. |
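
        Example of a typical call site (a sketch):

        .. code::

            def forward(self, x):
                x = self.layers(x)  # the module's usual computation
                if self.is_adapter_available():
                    x = self.forward_enabled_adapters(x)
                return x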
|
""" |
|
enabled_adapters = self.get_enabled_adapters() |
|
for adapter_name in enabled_adapters: |
|
adapter_module = self.adapter_layer[adapter_name] |
|
|
|
if hasattr(adapter_module, 'adapter_strategy'): |
|
                strategy = adapter_module.adapter_strategy
|
else: |
|
raise AttributeError( |
|
f"Adapter module `{adapter_name}` does not set the value `adapter_strategy` ! " |
|
f"Please set the value of the adapter's strategy with the class " |
|
f"{adapter_module.__class__.__module}.{adapter_module.__class__.__name__}." |
|
) |
|
|
|
|
|
input = self.forward_single_enabled_adapter_( |
|
input, adapter_module, adapter_name=adapter_name, adapter_strategy=strategy |
|
) |
|
|
|
return input |
|
|
|
|
|
|
|
def resolve_adapter_module_name_(self, name: str) -> Tuple[str, str]: |
|
""" |
|
Utility method to resolve a given global/module adapter name to its components. |
|
Always returns a tuple representing (module_name, adapter_name). ":" is used as the |
|
delimiter for denoting the module name vs the adapter name. |
|
|
|
Will attempt to also resolve a given adapter_name alone back to (module_name, adapter_name) |
|
if the metadata config exists for access. |
|
|
|
Args: |
|
name: A global adapter, or a module adapter name (with structure module_name:adapter_name). |
|
|
|
Returns: |
|
A tuple representing (module_name, adapter_name). If a global adapter is provided, |
|
module_name is set to ''. |
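
        Example:

        .. code::

            module.resolve_adapter_module_name_('encoder:adapter_1')  # -> ('encoder', 'adapter_1')
            module.resolve_adapter_module_name_('adapter_1')  # -> ('', 'adapter_1'), unless the metadata maps it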
|
""" |
|
|
|
if ':' in name: |
|
splits = name.split(":") |
|
module_name = splits[0] |
|
adapter_name = ":".join(splits[1:]) |
|
return (module_name, adapter_name) |
|
else: |
|
|
|
module_name = '' |
|
|
|
|
|
|
|
if hasattr(self, 'adapter_cfg') and self.adapter_cfg is not None: |
|
cfg = self.adapter_cfg.get(self.adapter_global_cfg_key, {}) |
|
cfg = cfg.get(self.adapter_metadata_cfg_key, {}) |
|
cfg = cfg.get('modules', {}) |
|
|
|
|
|
module_name = cfg.get(name, '') |
|
|
|
|
|
|
|
return (module_name, name) |
|
|
|
def forward_single_enabled_adapter_( |
|
self, |
|
input: torch.Tensor, |
|
adapter_module: torch.nn.Module, |
|
*, |
|
adapter_name: str, |
|
adapter_strategy: 'nemo.core.classes.mixins.adapter_mixin_strategies.AbstractAdapterStrategy', |
|
): |
|
""" |
|
Perform the forward step of a single adapter module on some input data. |
|
|
|
        **Note**: Subclasses can override this method to accommodate more complicated adapter forward steps.
|
|
|
Args: |
|
            input: The output tensor of the calling module is the input to the first adapter, whose output
|
is then chained to the next adapter until all adapters are consumed. |
|
adapter_module: The adapter module that is currently required to perform the forward pass. |
|
adapter_name: The resolved name of the adapter that is undergoing the current forward pass. |
|
adapter_strategy: A subclass of `AbstractAdapterStrategy`, that determines how the |
|
output of the adapter should be merged with the input, or if it should be merged at all. |
|
|
|
Returns: |
|
The result tensor, after the current active adapter has finished its forward pass. |
|
""" |
|
|
|
output = adapter_strategy(input, adapter_module, module=self) |
|
return output |
|
|
|
|
|
class AdapterModelPTMixin(AdapterModuleMixin): |
|
""" Adapter Mixin that can augment a ModelPT subclass with Adapter support. |
|
|
|
This mixin class should be used only with a top level ModelPT subclass. |
|
    This mixin class adds several utility methods which should be subclassed and overridden so that
    changes are propagated to the submodules as necessary.
|
|
|
    An Adapter module is any PyTorch nn.Module that possesses a few properties :
|
|
|
    - Its input and output dimensions are the same, while the hidden dimension need not be the same.
|
- The final layer of the Adapter module is zero-initialized, so that the residual connection to the adapter |
|
yields the original output. |
|
|
|
This mixin adds the following instance variables to the class this inherits it: |
|
|
|
- `adapter_layer`: A torch.nn.ModuleDict(), whose keys are the names of the adapter (globally unique), |
|
and values are the Adapter nn.Module(). |
|
- `adapter_cfg`: A OmegaConf DictConfig object that holds the config of the adapters that are initialized. |
|
- `adapter_global_cfg_key`: A str representing a key in the model config that can be provided by the user. |
|
The value resolves to `global_cfg`, and can be overridden via `model.cfg.adapters.global_cfg.*`. |
|
|
|
.. note:: |
|
|
|
This module **is** responsible for maintaining its config. At the ModelPT level, it will access and |
|
write Adapter config information to `self.cfg.adapters`. |
|
""" |
|
|
|
def setup_adapters(self): |
|
""" |
|
        Utility method that is called in the ModelPT subclass constructor, so as to restore any
|
adapters that were previously added. |
|
|
|
        Should be overridden by the subclass for additional setup steps as required.
|
|
|
This method should be called just once at constructor time. |
|
""" |
|
|
|
if 'adapters' in self.cfg: |
|
|
|
self.update_adapter_cfg(self.cfg.adapters) |
|
|
|
|
|
for adapter_name, adapter_cfg in self.cfg.adapters.items(): |
|
|
|
if adapter_name == self.adapter_global_cfg_key: |
|
continue |
|
|
|
|
|
|
|
self._restoring_adapters = True |
|
|
|
|
|
self.add_adapter(name=adapter_name, cfg=adapter_cfg) |
|
|
|
|
|
del self._restoring_adapters |
|
|
|
|
|
module_name, adapter_name = self.resolve_adapter_module_name_(adapter_name) |
|
|
|
if module_name != '': |
|
full_adapter_name = f'{module_name}:{adapter_name}' |
|
else: |
|
full_adapter_name = adapter_name |
|
|
|
logging.info( |
|
f"Finished setup of adapter : '{full_adapter_name}'. Enabled: {adapter_cfg.get('enabled', True)}." |
|
) |
|
|
|
def add_adapter(self, name: str, cfg: DictConfig): |
|
""" |
|
Add an Adapter module to this model. |
|
|
|
        Should be overridden by subclass and super() call must be used - this will set up the config.
|
After calling super(), forward this call to modules that implement the mixin. |
|
|
|
Args: |
|
name: A globally unique name for the adapter. Will be used to access, enable and disable adapters. |
|
            cfg: A DictConfig that contains at the bare minimum `_target_` to instantiate a new Adapter module.
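
        Example (a sketch; the ``_target_`` class path and module name are illustrative):

        .. code::

            cfg = OmegaConf.create({'_target_': 'my_pkg.adapters.LinearAdapter', 'in_features': 512})
            model.add_adapter(name='encoder:adapter_1', cfg=cfg)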
|
""" |
|
|
|
if is_dataclass(cfg): |
|
cfg = OmegaConf.structured(cfg) |
|
|
|
if not isinstance(cfg, DictConfig): |
|
cfg = DictConfig(cfg) |
|
|
|
|
|
module_name, adapter_name = self.resolve_adapter_module_name_(name) |
|
|
|
|
|
with open_dict(cfg), open_dict(self.cfg): |
|
|
|
if 'adapters' not in self.cfg: |
|
self.cfg.adapters = OmegaConf.create({}) |
|
|
|
self.cfg.adapters = _prepare_default_adapter_config( |
|
global_key=self.adapter_global_cfg_key, meta_key=self.adapter_metadata_cfg_key, cfg=self.cfg.adapters, |
|
) |
|
|
|
|
|
            # Only check for duplicate names when the adapter is not being restored from an existing config
            if not getattr(self, '_restoring_adapters', False):
|
if adapter_name in self.cfg.adapters: |
|
raise ValueError(f"Attempting to add multiple adapters with the same name ({adapter_name}) !") |
|
|
|
|
|
gcfg = self.adapter_global_cfg_key |
|
mcfg = self.adapter_metadata_cfg_key |
|
self.cfg.adapters[gcfg][mcfg]['modules'][adapter_name] = module_name |
|
|
|
|
|
if 'enabled' not in cfg: |
|
cfg['enabled'] = True |
|
|
|
|
|
self.cfg.adapters[adapter_name] = OmegaConf.create(cfg) |
|
|
|
|
|
self.update_adapter_cfg(self.cfg.adapters) |
|
|
|
self.check_valid_model_with_adapter_support_() |
|
|
|
def is_adapter_available(self) -> bool: |
|
""" |
|
Checks if any Adapter module has been instantiated. |
|
|
|
Should be overridden by the subclass. |
|
|
|
Returns: |
|
            bool, determining if any Adapter module has been instantiated. The default implementation
            returns True only if the adapter config exists and at least one adapter is enabled.
|
""" |
|
self.check_valid_model_with_adapter_support_() |
|
|
|
if 'adapters' in self.cfg: |
|
self.update_adapter_cfg(self.cfg.adapters) |
|
|
|
return 'adapters' in self.cfg and len(self.get_enabled_adapters()) > 0 |
|
|
|
def set_enabled_adapters(self, name: Optional[str] = None, enabled: bool = True): |
|
""" |
|
        Updates the internal adapter config, determining if an adapter (or all adapters) are either
|
enabled or disabled. |
|
|
|
A common user pattern would be to disable all adapters (either after adding them, or restoring a model |
|
with pre-existing adapters) and then simply enable one of the adapters. |
|
|
|
        Should be overridden by subclass and super() call must be used - this will set up the config.
|
After calling super(), forward this call to modules that implement the mixin. |
|
|
|
.. code:: |
|
|
|
model.set_enabled_adapters(enabled=False) |
|
model.set_enabled_adapters(name=<some adapter name>, enabled=True) |
|
|
|
Args: |
|
name: Optional str. If a str name is given, the config will be updated to the value of `enabled`. |
|
If no name is given, then all adapters will be enabled/disabled. |
|
enabled: Bool, determines if the adapter(s) will be enabled/disabled. |
|
""" |
|
self.check_valid_model_with_adapter_support_() |
|
|
|
|
|
with open_dict(self.cfg.adapters): |
|
|
|
if name is None: |
|
for key in self.cfg.adapters.keys(): |
|
|
|
if key == self.adapter_global_cfg_key: |
|
continue |
|
|
|
self.cfg.adapters[key]['enabled'] = enabled |
|
logging.info(f"Setting adapter '{key}' status : Enabled = {enabled}") |
|
|
|
else: |
|
|
|
module_name, adapter_name = self.resolve_adapter_module_name_(name) |
|
|
|
|
|
if adapter_name == self.adapter_global_cfg_key: |
|
raise ValueError( |
|
f'Cannot set the state of the global config of adapters, ' |
|
f'given name = `{self.adapter_global_cfg_key}`' |
|
) |
|
|
|
|
|
self.cfg.adapters[adapter_name]['enabled'] = enabled |
|
logging.info(f"Setting adapter '{name}' status : Enabled = {enabled}") |
|
|
|
self.update_adapter_cfg(self.cfg.adapters) |
|
|
|
def get_enabled_adapters(self) -> List[str]: |
|
""" |
|
Returns a list of all enabled adapters. |
|
|
|
Should be implemented by the subclass. |
|
|
|
Returns: |
|
            A list of str names, one for each enabled adapter.
|
""" |
|
self.check_valid_model_with_adapter_support_() |
|
|
|
if 'adapters' in self.cfg: |
|
self.update_adapter_cfg(self.cfg.adapters) |
|
return [] |
|
|
|
def check_valid_model_with_adapter_support_(self): |
|
""" |
|
Utility method to test if the subclass of this mixin is an appropriate subclass of ModelPT itself. |
|
|
|
Should be implemented by the subclass. |
|
""" |
|
pass |
|
|
|
def save_adapters(self, filepath: str, name: str = None): |
|
""" |
|
Utility method that saves only the adapter module(s), and not the entire model itself. |
|
This allows the sharing of adapters which are often just a fraction of the size of the full model, |
|
        enabling easier delivery.
|
|
|
Note: The saved file is a pytorch compatible pickle file, containing the state dicts of the adapter(s), |
|
as well as a binary representation of the adapter config. |
|
|
|
Args: |
|
            filepath: A str filepath for the .pt file that will contain the adapter state dict(s).
|
name: Optional name of the adapter that will be saved to this file. If None is passed, |
|
all adapters will be saved to the file. The name can be either the global name (adapter_name), |
|
or the module level name (module:adapter_name). |
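
        Example (a sketch):

        .. code::

            model.save_adapters('adapters.pt')  # save all adapters
            model.save_adapters('adapter_1.pt', name='encoder:adapter_1')  # save a single adapter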
|
""" |
|
if not hasattr(self, 'cfg') or 'adapters' not in self.cfg: |
|
raise AttributeError("No adapters have been added to this model, so no adapters can be saved.") |
|
|
|
output_dict = {} |
|
|
|
|
|
if isinstance(name, str): |
|
name = [name] |
|
|
|
if name is None: |
|
name = self.cfg.adapters.keys() |
|
|
|
|
|
if not hasattr(self.cfg, 'adapters'): |
|
raise ValueError( |
|
"The model has no adapter config, therefore it cannot save any adapter. " |
|
"Please first add one or more adapters to generate the config." |
|
) |
|
|
|
|
|
for adapter_name in name: |
|
if adapter_name != self.adapter_global_cfg_key: |
|
|
|
module_name, adapter_name = self.resolve_adapter_module_name_(adapter_name) |
|
|
|
|
|
if module_name == '': |
|
key = adapter_name |
|
else: |
|
key = f'{module_name}:{adapter_name}' |
|
output_dict[key] = [] |
|
|
|
|
|
|
|
|
|
|
|
for module in self.modules(): |
|
if isinstance(module, AdapterModuleMixin): |
|
|
|
|
|
|
|
|
|
adapter_module = module.get_adapter_module(adapter_name) |
|
if adapter_module is not None: |
|
|
|
|
|
|
|
|
|
|
|
adapter_state_dict = module.adapter_layer.state_dict() |
|
state_dict = {} |
|
for k, v in adapter_state_dict.items(): |
|
if adapter_name in k: |
|
state_dict[k] = v |
|
|
|
output_dict[key].append(state_dict) |
|
|
|
|
|
output_dict['__cfg__'] = self.cfg.adapters |
|
|
|
|
|
torch.save(output_dict, filepath) |
|
|
|
def load_adapters(self, filepath: str, name: str = None, map_location: str = None, strict: bool = True): |
|
""" |
|
Utility method that restores only the adapter module(s), and not the entire model itself. |
|
This allows the sharing of adapters which are often just a fraction of the size of the full model, |
|
        enabling easier delivery.
|
|
|
Note: During restoration, assumes that the model does not currently already have an adapter with |
|
the name (if provided), or any adapter that shares a name with the state dict's modules |
|
(if name is not provided). This is to ensure that each adapter name is globally unique |
|
in a model. |
|
|
|
Args: |
|
filepath: Filepath of the .pt file. |
|
            name: Optional name of the adapter that will be restored from this file. If None is passed,
                all adapters in the file will be restored. The name must be either the global name (adapter_name),
|
or the module level name (module:adapter_name), whichever exactly matches the state dict. |
|
map_location: Pytorch flag, where to place the adapter(s) state dict(s). |
|
strict: Pytorch flag, whether to load the weights of the adapter(s) strictly or not. |
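
        Example (a sketch):

        .. code::

            model.load_adapters('adapters.pt')  # restore all adapters found in the file
            model.load_adapters('adapters.pt', name='adapter_1', map_location='cpu')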
|
""" |
|
|
|
if map_location is None: |
|
if torch.cuda.is_available(): |
|
map_location = 'cuda' |
|
else: |
|
map_location = 'cpu' |
|
|
|
|
|
state_dict = torch.load(filepath, map_location=map_location) |
|
config = state_dict.pop('__cfg__') |
|
|
|
|
|
if isinstance(name, str): |
|
name = [name] |
|
|
|
if name is None: |
|
name = list(config.keys()) |
|
|
|
|
|
for module_adapter_name in name: |
|
|
|
internal_adapter_cfg = None |
|
if hasattr(self, 'adapter_cfg') and self.adapter_cfg is not None: |
|
internal_adapter_cfg = self.adapter_cfg |
|
|
|
|
|
self.adapter_cfg = config |
|
|
|
|
|
module_name, adapter_name = self.resolve_adapter_module_name_(module_adapter_name) |
|
adapter_cfg = config[adapter_name] |
|
|
|
|
|
if module_name == '': |
|
module_adapter_name = adapter_name |
|
else: |
|
module_adapter_name = f'{module_name}:{adapter_name}' |
|
|
|
|
|
self.adapter_cfg = internal_adapter_cfg |
|
|
|
|
|
if adapter_name == self.adapter_global_cfg_key: |
|
continue |
|
|
|
|
|
try: |
|
adapter_state = state_dict[module_adapter_name] |
|
except KeyError: |
|
all_keys = list(state_dict.keys()) |
|
raise KeyError( |
|
f"Requested to load adapter with name `{module_adapter_name}`, but could not " |
|
f"the adapter in the state dict. \nAvailable adapter names in state dict are: " |
|
f"{all_keys}" |
|
) |
|
|
|
|
|
self.add_adapter(name=module_adapter_name, cfg=adapter_cfg) |
|
|
|
|
|
|
|
|
|
modules_to_load = [] |
|
for module in self.modules(): |
|
if isinstance(module, AdapterModuleMixin): |
|
adapter_module = module.get_adapter_module(adapter_name) |
|
if adapter_module is not None: |
|
modules_to_load.append(adapter_module) |
|
|
|
|
|
if len(adapter_state) != len(modules_to_load): |
|
raise ValueError( |
|
f"The number of adapters in current model ({len(modules_to_load)}) does not " |
|
f"match the number of modules in the state dict for adapter `{adapter_name}`: " |
|
f"({len(adapter_state)})" |
|
) |
|
|
|
|
|
for state, module in zip(adapter_state, modules_to_load): |
|
|
|
|
|
|
|
|
|
sub_dict = {} |
|
for k, v in state.items(): |
|
if adapter_name in k: |
|
k_ = k.replace(f"{adapter_name}.", "") |
|
sub_dict[k_] = v |
|
|
|
module.load_state_dict(sub_dict, strict=strict) |
|
del sub_dict |
|
|
|
|
|
del adapter_state, modules_to_load |
|
|
|
def update_adapter_cfg(self, cfg: DictConfig): |
|
""" |
|
Utility method to recursively update all of the Adapter module configs with the provided config. |
|
|
|
.. note:: |
|
|
|
            It is not a (deep)copy, but a reference copy. Changes made to the config will be reflected in the
            adapter submodules, but it is still encouraged to explicitly update the adapter_cfg using this method.
|
|
|
Args: |
|
cfg: DictConfig containing the value of `model.cfg.adapters`. |
|
""" |
|
for module in self.modules(): |
|
if isinstance(module, AdapterModuleMixin): |
|
module.adapter_cfg = cfg |
|
|
|
@property |
|
def adapter_module_names(self) -> List[str]: |
|
""" |
|
List of valid adapter modules that are supported by the model. |
|
|
|
**Note**: Subclasses should override this property and return a list of str names, of all the modules |
|
that they support, which will enable users to determine where to place the adapter modules. |
|
|
|
Returns: |
|
A list of str, one for each of the adapter modules that are supported. By default, the subclass |
|
should support the "global adapter" (''). |
|
""" |
|
return [''] |
|
|