# File size: 12,528 Bytes (revision 7934b29)
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import inspect
from dataclasses import is_dataclass
from typing import Dict, List, Optional
from nemo.utils import logging
# TODO @blisc: Perhaps refactor instead of import guarding
_HAS_HYDRA = True
try:
from omegaconf import DictConfig, OmegaConf, open_dict
except ModuleNotFoundError:
_HAS_HYDRA = False
def update_model_config(
    model_cls: 'nemo.core.config.modelPT.NemoConfig', update_cfg: 'DictConfig', drop_missing_subconfigs: bool = True
):
    """
    Helper function that updates the default values of a ModelPT config class with the values
    in a DictConfig that mirrors the structure of the config class.

    `update_cfg` is expected to be a DictConfig (either generated manually, via hydra or instantiated
    via yaml/model.cfg); anything else is first passed through `OmegaConf.create`.
    This update_cfg is then used to override the default values preset inside the ModelPT config class.

    If `drop_missing_subconfigs` is set, certain sub-configs of the ModelPT config class will be removed, if
    they are not found in the mirrored `update_cfg`. The following sub-configs are subject to potential removal:
        -   `train_ds`
        -   `validation_ds`
        -   `test_ds`
        -   `optim` + nested `sched`.

    Note:
        The ModelPT artifacts `target` and `nemo_version` are popped from `update_cfg.model` itself when the
        model config does not declare them. If the caller passed a DictConfig, that is the caller's object.

    Args:
        model_cls: A subclass of NemoConfig (or an already structured DictConfig), that details in entirety
            all of the parameters that constitute the NeMo Model.
        update_cfg: A DictConfig that mirrors the structure of the NemoConfig data class. Used to update the
            default values of the config class.
        drop_missing_subconfigs: Bool which determines whether to drop certain sub-configs from the NemoConfig
            class, if the corresponding sub-config is missing from `update_cfg`.

    Returns:
        A DictConfig with updated values that can be used to instantiate the NeMo Model along with supporting
        infrastructure.

    Raises:
        ValueError: If `model_cls` is neither a dataclass nor a DictConfig.
        SystemExit: If OmegaConf could not be imported.
    """
    if not _HAS_HYDRA:
        logging.error("This function requires Hydra/Omegaconf and it was not installed.")
        # `exit()` is an interactive helper installed by the `site` module and is not guaranteed to
        # exist (e.g. `python -S`); raising SystemExit directly terminates with the same exit status.
        raise SystemExit(1)

    if not (is_dataclass(model_cls) or isinstance(model_cls, DictConfig)):
        raise ValueError("`model_cls` must be a dataclass or a structured OmegaConf object")

    if not isinstance(update_cfg, DictConfig):
        update_cfg = OmegaConf.create(update_cfg)

    if is_dataclass(model_cls):
        model_cls = OmegaConf.structured(model_cls)

    # Update optional configs (same order as they are documented above)
    for subconfig_key in ('train_ds', 'validation_ds', 'test_ds', 'optim'):
        model_cls = _update_subconfig(
            model_cls, update_cfg, subconfig_key=subconfig_key, drop_missing_subconfigs=drop_missing_subconfigs
        )

    # Add optim and sched additional keys to model cls
    model_cls = _add_subconfig_keys(model_cls, update_cfg, subconfig_key='optim')

    # Remove ModelPT artifacts (`target`, `nemo_version`) from the update config, but only when the
    # model config does not declare them itself; otherwise they are treated as real fields.
    for artifact_key in ('target', 'nemo_version'):
        if artifact_key in update_cfg.model and artifact_key not in model_cls.model:
            with open_dict(update_cfg.model):
                update_cfg.model.pop(artifact_key)

    # Perform full merge of model config class and update config
    model_cfg = OmegaConf.merge(model_cls, update_cfg)

    return model_cfg
def _update_subconfig(
    model_cfg: 'DictConfig', update_cfg: 'DictConfig', subconfig_key: str, drop_missing_subconfigs: bool
):
    """
    Updates the NemoConfig DictConfig such that:
        1)  If the sub-config key exists in the `update_cfg`, but does not exist in ModelPT config:
            - Add the sub-config from update_cfg to ModelPT config

        2)  If the sub-config key does not exist in `update_cfg`, but exists in ModelPT config:
            - Remove the sub-config from the ModelPT config; iff the `drop_missing_subconfigs` flag is set.

    In every other case `model_cfg` is left untouched.

    Args:
        model_cfg: A DictConfig instantiated from the NemoConfig subclass.
        update_cfg: A DictConfig that mirrors the structure of `model_cfg`, used to update its default values.
        subconfig_key: A str key used to check and update the sub-config.
        drop_missing_subconfigs: A bool flag, whether to allow deletion of the NemoConfig sub-config,
            if its mirror sub-config does not exist in the `update_cfg`.

    Returns:
        The updated DictConfig for the NemoConfig (the same `model_cfg` object, modified in place).

    Raises:
        SystemExit: If OmegaConf could not be imported.
    """
    if not _HAS_HYDRA:
        logging.error("This function requires Hydra/Omegaconf and it was not installed.")
        # `exit()` is an interactive helper installed by the `site` module and is not guaranteed to
        # exist (e.g. `python -S`); raising SystemExit directly terminates with the same exit status.
        raise SystemExit(1)

    with open_dict(model_cfg.model):
        in_update_cfg = subconfig_key in update_cfg.model
        in_model_cfg = subconfig_key in model_cfg.model

        if in_update_cfg and not in_model_cfg:
            # Update config has the key, but model cfg does not have the key:
            # add the update cfg subconfig to the model cfg
            model_cfg.model[subconfig_key] = update_cfg.model[subconfig_key]
        elif in_model_cfg and not in_update_cfg and drop_missing_subconfigs:
            # Update config does not have the key, but model cfg has the key:
            # remove the model cfg subconfig in order to match layout of update cfg
            model_cfg.model.pop(subconfig_key)

    return model_cfg
def _add_subconfig_keys(model_cfg: 'DictConfig', update_cfg: 'DictConfig', subconfig_key: str):
    """
    For certain sub-configs, the default values specified by the NemoConfig class is insufficient.
    In order to support every potential value in the merge between the `update_cfg`, it would require
    explicit definition of all possible cases.

    An example of such a case is Optimizers, and their equivalent Schedulers. All optimizers share a few basic
    details - such as name and lr, but almost all require additional parameters - such as weight decay.
    It is impractical to create a config for every single optimizer + every single scheduler combination.

    In such a case, we perform a dual merge. The Optim and Sched Dataclass contain the bare minimum essential
    components. The extra values are provided via update_cfg.

    In order to enable the merge, we first need to update the update sub-config to incorporate the keys,
    with dummy temporary values (merge update config with model config). This is done on a copy of the
    update sub-config, as the actual override values might be overridden by the NemoConfig defaults.

    Then we perform a merge of this temporary sub-config with the actual override config in a later step
    (merge model_cfg with original update_cfg, done outside this function).

    Args:
        model_cfg: A DictConfig instantiated from the NemoConfig subclass.
        update_cfg: A DictConfig that mirrors the structure of `model_cfg`, used to update its default values.
        subconfig_key: A str key used to check and update the sub-config.

    Returns:
        A ModelPT DictConfig with additional keys added to the sub-config. If `subconfig_key` is absent
        from `update_cfg.model`, `model_cfg` is returned unchanged.

    Raises:
        SystemExit: If OmegaConf could not be imported.
    """
    if not _HAS_HYDRA:
        logging.error("This function requires Hydra/Omegaconf and it was not installed.")
        # `exit()` is an interactive helper installed by the `site` module and is not guaranteed to
        # exist (e.g. `python -S`); raising SystemExit directly terminates with the same exit status.
        raise SystemExit(1)

    with open_dict(model_cfg.model):
        if subconfig_key in update_cfg.model:
            if subconfig_key not in model_cfg.model:
                # create the key as a placeholder
                # NOTE(review): on this path the merge below receives a None placeholder as its second
                # argument - confirm OmegaConf.merge tolerates that before relying on this branch.
                model_cfg.model[subconfig_key] = None

            # Work on copies so neither the model defaults nor the caller's overrides are aliased
            model_subconfig = copy.deepcopy(model_cfg.model[subconfig_key])
            update_subconfig = copy.deepcopy(update_cfg.model[subconfig_key])

            # Add the keys and update temporary values, will be updated during full merge.
            # The update sub-config is the merge base, so its extra keys survive while the model
            # defaults take precedence for now.
            merged_subconfig = OmegaConf.merge(update_subconfig, model_subconfig)

            # Update sub config
            model_cfg.model[subconfig_key] = merged_subconfig

    return model_cfg
def assert_dataclass_signature_match(
    cls: 'class_type',
    datacls: 'dataclass',
    ignore_args: Optional[List[str]] = None,
    remap_args: Optional[Dict[str, str]] = None,
):
    """
    Analyses the signature of a provided class and its respective data class,
    asserting that the dataclass signature matches the class __init__ signature.

    Note:
        This is not a value based check. This function only checks if all argument
        names exist on both class and dataclass and logs mismatches.

    Args:
        cls: Any class type - but not an instance of a class. Pass type(x) where x is an instance
            if class type is not easily available.
        datacls: A corresponding dataclass for the above class.
        ignore_args: (Optional) A list of string argument names which are forcibly ignored,
            even if mismatched in the signature. Useful when a dataclass is a superset of the
            arguments of a class.
        remap_args: (Optional) A dictionary, mapping an argument name that exists (in either the
            class or its dataclass), to another name. Useful when argument names are mismatched between
            a class and its dataclass due to indirect instantiation via a helper method.

    Returns:
        A tuple containing information about the analysis:
        1) A bool value which is True if the signatures matched exactly / after ignoring values.
            False otherwise.
        2) A set of arguments names that exist in the class, but *do not* exist in the dataclass.
            If exact signature match occurs, this will be None instead.
        3) A set of argument names that exist in the data class, but *do not* exist in the class itself.
            If exact signature match occurs, this will be None instead.
    """
    init_params = list(inspect.signature(cls.__init__).parameters.values())

    # Drop the instance parameter by position rather than by the name 'self': the name is only a
    # convention, and popping it by name raises KeyError for an `__init__(this, ...)`.
    # A leading *args / **kwargs is kept, since it is not the instance parameter itself.
    positional_kinds = (inspect.Parameter.POSITIONAL_ONLY, inspect.Parameter.POSITIONAL_OR_KEYWORD)
    if init_params and init_params[0].kind in positional_kinds:
        init_params = init_params[1:]

    class_params = {param.name for param in init_params}

    dataclass_params = set(inspect.signature(datacls).parameters)
    # `_target_` is dataclass-side bookkeeping with no counterpart in the class signature
    dataclass_params.discard("_target_")

    if remap_args is not None:
        for original_arg, new_arg in remap_args.items():
            if original_arg in class_params:
                class_params.remove(original_arg)
                class_params.add(new_arg)
                logging.info(f"Remapped {original_arg} -> {new_arg} in {cls.__name__}")

            if original_arg in dataclass_params:
                dataclass_params.remove(original_arg)
                dataclass_params.add(new_arg)
                logging.info(f"Remapped {original_arg} -> {new_arg} in {datacls.__name__}")

    if ignore_args is not None:
        ignored = set(ignore_args)
        class_params = class_params - ignored
        dataclass_params = dataclass_params - ignored
        logging.info(f"Removing ignored arguments - {ignored}")

    subset_cls = class_params - dataclass_params
    subset_datacls = dataclass_params - class_params

    # Two sets are equal exactly when both differences are empty, so no separate size check is needed.
    if subset_cls or subset_datacls:
        logging.error(f"Class {cls.__name__} arguments do not match Dataclass {datacls.__name__}!")

        if subset_cls:
            logging.error(f"Class {cls.__name__} has additional arguments :\n{subset_cls}")

        if subset_datacls:
            logging.error(f"Dataclass {datacls.__name__} has additional arguments :\n{subset_datacls}")

        return False, subset_cls, subset_datacls

    return True, None, None
|