|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from dataclasses import dataclass, field |
|
from typing import Any, Dict, Optional |
|
|
|
from omegaconf import MISSING |
|
|
|
from nemo.core import config |
|
from nemo.core.classes.dataset import DatasetConfig |
|
from nemo.utils import exp_manager |
|
|
|
|
|
@dataclass |
|
class SchedConfig: |
|
name: str = MISSING |
|
min_lr: float = 0.0 |
|
last_epoch: int = -1 |
|
|
|
|
|
@dataclass |
|
class OptimConfig: |
|
name: str = MISSING |
|
sched: Optional[SchedConfig] = None |
|
|
|
|
|
@dataclass |
|
class ModelConfig: |
|
""" |
|
Model component inside ModelPT |
|
""" |
|
|
|
|
|
train_ds: Optional[DatasetConfig] = None |
|
validation_ds: Optional[DatasetConfig] = None |
|
test_ds: Optional[DatasetConfig] = None |
|
optim: Optional[OptimConfig] = None |
|
|
|
|
|
@dataclass |
|
class HydraConfig: |
|
run: Dict[str, Any] = field(default_factory=lambda: {"dir": "."}) |
|
job_logging: Dict[str, Any] = field(default_factory=lambda: {"root": {"handlers": None}}) |
|
|
|
|
|
@dataclass |
|
class NemoConfig: |
|
name: str = MISSING |
|
model: ModelConfig = MISSING |
|
trainer: config.TrainerConfig = config.TrainerConfig( |
|
strategy="ddp", enable_checkpointing=False, logger=False, log_every_n_steps=1, accelerator='gpu' |
|
) |
|
exp_manager: Optional[Any] = exp_manager.ExpManagerConfig() |
|
hydra: HydraConfig = HydraConfig() |
|
|
|
|
|
class ModelConfigBuilder: |
|
def __init__(self, model_cfg: ModelConfig): |
|
""" |
|
Base class for any Model Config Builder. |
|
|
|
A Model Config Builder is a utility class that accepts a ModelConfig dataclass, |
|
and via a set of utility methods (that are implemented by the subclassed ModelConfigBuilder), |
|
builds a finalized ModelConfig that can be supplied to a NemoModel dataclass as |
|
the `model` component. |
|
|
|
Subclasses *must* implement the private method `_finalize_cfg`. |
|
Inside this method, they must update `self.model_cfg` with all interdependent config |
|
options that need to be set (either updated by user explicitly or with their default value). |
|
|
|
The updated model config must then be preserved in `self.model_cfg`. |
|
|
|
Example: |
|
# Create the config builder |
|
config_builder = <subclass>ModelConfigBuilder() |
|
|
|
# Update the components of the config that are modifiable |
|
config_builder.set_X(X) |
|
config_builder.set_Y(Y) |
|
|
|
# Create a "finalized" config dataclass that will contain all the updates |
|
# that were specified by the builder |
|
model_config = config_builder.build() |
|
|
|
# Use model config as is (or further update values), then create a new Model |
|
model = nemo.<domain>.models.<ModelName>Model(cfg=model_config, trainer=Trainer()) |
|
|
|
Supported build methods: |
|
- set_train_ds: All model configs can accept a subclass of `DatasetConfig` as their |
|
training config. Subclasses can override this method to enable auto-complete |
|
by replacing `Optional[DatasetConfig]` with `Optional[<subclass of DatasetConfig>]`. |
|
|
|
- set_validation_ds: All model configs can accept a subclass of `DatasetConfig` as their |
|
validation config. Subclasses can override this method to enable auto-complete |
|
by replacing `Optional[DatasetConfig]` with `Optional[<subclass of DatasetConfig>]`. |
|
|
|
- set_test_ds: All model configs can accept a subclass of `DatasetConfig` as their |
|
test config. Subclasses can override this method to enable auto-complete |
|
by replacing `Optional[DatasetConfig]` with `Optional[<subclass of DatasetConfig>]`. |
|
|
|
- set_optim: A build method that supports changes to the Optimizer (and optionally, |
|
the Scheduler) used for training the model. The function accepts two inputs - |
|
|
|
`cfg`: A subclass of `OptimizerParams` - any OptimizerParams subclass can be used, |
|
in order to select an appropriate Optimizer. Examples: AdamParams. |
|
|
|
`sched_cfg`: A subclass of `SchedulerParams` - any SchedulerParams subclass can be used, |
|
in order to select an appropriate Scheduler. Examples: CosineAnnealingParams. |
|
Note that this argument is optional. |
|
|
|
- build(): The method which should return a "finalized" ModelConfig dataclass. |
|
Subclasses *should* always override this method, and update the signature |
|
of this method with the return type of the Dataclass, so that it enables |
|
autocomplete for the user. |
|
|
|
Example: |
|
def build(self) -> EncDecCTCConfig: |
|
return super().build() |
|
|
|
Any additional build methods must be added by subclasses of ModelConfigBuilder. |
|
|
|
Args: |
|
model_cfg: |
|
""" |
|
self.model_cfg = model_cfg |
|
self.train_ds_cfg = None |
|
self.validation_ds_cfg = None |
|
self.test_ds_cfg = None |
|
self.optim_cfg = None |
|
|
|
def set_train_ds(self, cfg: Optional[DatasetConfig] = None): |
|
self.model_cfg.train_ds = cfg |
|
|
|
def set_validation_ds(self, cfg: Optional[DatasetConfig] = None): |
|
self.model_cfg.validation_ds = cfg |
|
|
|
def set_test_ds(self, cfg: Optional[DatasetConfig] = None): |
|
self.model_cfg.test_ds = cfg |
|
|
|
def set_optim(self, cfg: config.OptimizerParams, sched_cfg: Optional[config.SchedulerParams] = None): |
|
@dataclass |
|
class WrappedOptimConfig(OptimConfig, cfg.__class__): |
|
pass |
|
|
|
|
|
optim_name = cfg.__class__.__name__.replace("Params", "").lower() |
|
wrapped_cfg = WrappedOptimConfig(name=optim_name, sched=None, **vars(cfg)) |
|
|
|
if sched_cfg is not None: |
|
|
|
@dataclass |
|
class WrappedSchedConfig(SchedConfig, sched_cfg.__class__): |
|
pass |
|
|
|
|
|
sched_name = sched_cfg.__class__.__name__.replace("Params", "") |
|
wrapped_sched_cfg = WrappedSchedConfig(name=sched_name, **vars(sched_cfg)) |
|
|
|
wrapped_cfg.sched = wrapped_sched_cfg |
|
|
|
self.model_cfg.optim = wrapped_cfg |
|
|
|
def _finalize_cfg(self): |
|
raise NotImplementedError() |
|
|
|
def build(self) -> ModelConfig: |
|
|
|
self._finalize_cfg() |
|
|
|
return self.model_cfg |
|
|