from dataclasses import dataclass
from typing import Any, Optional

from hydra.core.config_store import ConfigStore

__all__ = ['TrainerConfig']

cs = ConfigStore.instance()

@dataclass
class TrainerConfig:
    """
    Configuration of the PyTorch Lightning Trainer.

    It is not derived from Config as it is not a NeMo object (and in particular it doesn't need a name).

    .. warning::
        Only a few parameters of the PTL Trainer are exposed here for now. This needs to be discussed.

    .. note::
        For details on the function/meaning of the arguments, please refer to:
        https://pytorch-lightning.readthedocs.io/en/latest/common/trainer.html
    """

    logger: Any = True
    callbacks: Optional[Any] = None
    default_root_dir: Optional[str] = None
    gradient_clip_val: float = 0
    num_nodes: int = 1
    gpus: Optional[Any] = None
    auto_select_gpus: bool = False
    tpu_cores: Optional[Any] = None
    enable_progress_bar: bool = True
    overfit_batches: Any = 0.0
    track_grad_norm: Any = -1
    check_val_every_n_epoch: int = 1
    fast_dev_run: bool = False
    accumulate_grad_batches: Any = 1
    max_epochs: int = 1000
    min_epochs: int = 1
    max_steps: Optional[int] = -1
    min_steps: Optional[int] = None
    limit_train_batches: Any = 1.0
    limit_val_batches: Any = 1.0
    limit_test_batches: Any = 1.0
    val_check_interval: Any = 1.0
    log_every_n_steps: int = 50
    accelerator: Optional[str] = None
    sync_batchnorm: bool = False
    precision: Any = 32
    num_sanity_val_steps: int = 2
    resume_from_checkpoint: Optional[str] = None
    profiler: Optional[Any] = None
    benchmark: bool = False
    deterministic: bool = False
    auto_lr_find: Any = False
    replace_sampler_ddp: bool = True
    detect_anomaly: bool = False
    auto_scale_batch_size: Any = False
    amp_backend: str = 'native'
    amp_level: Optional[str] = None
    plugins: Optional[Any] = None
    move_metrics_to_cpu: bool = False
    multiple_trainloader_mode: str = 'max_size_cycle'
    limit_predict_batches: float = 1.0
    gradient_clip_algorithm: str = 'norm'
    max_time: Optional[Any] = None
    reload_dataloaders_every_n_epochs: int = 0
    ipus: Optional[int] = None
    devices: Any = None
    strategy: Any = None
    enable_checkpointing: bool = False
    enable_model_summary: bool = True
    inference_mode: bool = True

cs.store(
    group="trainer", name="trainer", node=TrainerConfig,
)
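
# Illustrative usage sketch (an assumption added for documentation, not part of
# NeMo's public API): TrainerConfig mirrors the PTL 1.x Trainer signature, so it
# can be expanded into keyword arguments directly, or composed through Hydra via
# the group/name registered in cs.store() above. Fields such as ``gpus`` and
# ``amp_backend`` were removed in later PTL releases, so the direct expansion
# only works with a matching pytorch-lightning version.
if __name__ == "__main__":
    from dataclasses import asdict

    import pytorch_lightning as pl
    from hydra import compose, initialize
    from omegaconf import OmegaConf

    # Direct use: override a few defaults, then build the Trainer from the config.
    trainer_cfg = TrainerConfig(max_epochs=10, devices=1)
    trainer = pl.Trainer(**asdict(trainer_cfg))

    # Hydra composition (assumes the Hydra >= 1.2 compose API): append the
    # "trainer" group registered above and inspect the resulting config.
    with initialize(version_base=None, config_path=None):
        cfg = compose(overrides=["+trainer=trainer"])
        print(OmegaConf.to_yaml(cfg.trainer))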