""" |
|
This script contains an example of how to train and test dialogue models in NeMo. |
|
|
|
***Setting the configs*** |
|
The model and the PT trainer are defined in a config file that declares multiple important sections. |
|
The most important ones are: |
|
model: All arguments that are related to the Model - model, loss, optimizer, |
|
schedulers, and datasets/data loaders. |
|
trainer: Any argument to be passed to PyTorch Lightning including number of epochs, number of GPUs, |
|
precision level, etc. |
|
|
|
This script uses the `/examples/nlp/dialogue_state_tracking/conf/dialog_config.yaml` config file |
|
by default. You may update the config file from the file directly. The other option is to set another config file via command-line arguments by `--config-name=CONFIG_FILE_PATH'. |
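Individual config values can also be overridden on the command line using Hydra dot notation.
For example (the values below are illustrative, not defaults - see conf/dialogue_config.yaml for the
actual keys and their defaults):

python dialogue.py trainer.max_epochs=3 trainer.precision=16 model.optim.lr=5e-5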
|
***Model Training***
python dialogue.py
    do_training=True
    model.dataset.data_dir=<DATA_DIR_WITH_JSON_DATA>
    model.dataset.dialogues_example_dir=<DATA_DIR_FOR_CACHING_INTERMEDIATE_AND_SAVING_PREDICTIONS>
    model.dataset.task=<TASK - see conf/dialogue_config.yaml for full list> e.g. sgd
    model.language_model.pretrained_model_name=<MODEL_NAME - see conf/dialogue_config.yaml for full list> e.g. gpt2
    trainer.devices=[<DEVICE_IDS_TO_USE>]

***Model Evaluation***
Use the same command as above, but set do_training=False.
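For example (placeholders as above; predictions are saved under dialogues_example_dir):

python dialogue.py
    do_training=False
    model.dataset.data_dir=<DATA_DIR_WITH_JSON_DATA>
    model.dataset.dialogues_example_dir=<DATA_DIR_FOR_CACHING_INTERMEDIATE_AND_SAVING_PREDICTIONS>
    model.dataset.task=sgd
    model.language_model.pretrained_model_name=gpt2
    trainer.devices=[<DEVICE_IDS_TO_USE>]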
|
""" |
|
import os

import pytorch_lightning as pl
from omegaconf import DictConfig, OmegaConf

from nemo.collections.nlp.models.dialogue.dialogue_gpt_classification_model import DialogueGPTClassificationModel
from nemo.collections.nlp.models.dialogue.dialogue_gpt_generation_model import DialogueGPTGenerationModel
from nemo.collections.nlp.models.dialogue.dialogue_nearest_neighbour_model import DialogueNearestNeighbourModel
from nemo.collections.nlp.models.dialogue.dialogue_s2s_generation_model import DialogueS2SGenerationModel
from nemo.collections.nlp.models.dialogue.dialogue_zero_shot_intent_model import DialogueZeroShotIntentModel
from nemo.collections.nlp.models.dialogue.intent_slot_classification_model import IntentSlotClassificationModel
from nemo.collections.nlp.models.dialogue.sgdqa_model import SGDQAModel
from nemo.collections.nlp.modules.common.megatron.megatron_utils import compute_model_parallel_rank
from nemo.collections.nlp.parts.nlp_overrides import NLPDDPStrategy
from nemo.core.config import hydra_runner
from nemo.utils import logging
from nemo.utils.app_state import AppState
from nemo.utils.exp_manager import exp_manager
|
|
@hydra_runner(config_path="conf", config_name="dialogue_config") |
def main(cfg: DictConfig) -> None:
    pl.seed_everything(42)
    logging.info(f'Config: {OmegaConf.to_yaml(cfg)}')

    # Prefer NeMo's DDP strategy when its optional dependencies are available;
    # otherwise fall back to PyTorch Lightning's default strategy selection.
    try:
        strategy = NLPDDPStrategy(no_ddp_communication_hook=True, find_unused_parameters=True)
    except (ImportError, ModuleNotFoundError):
        strategy = None

    trainer = pl.Trainer(**cfg.trainer, strategy=strategy)

    exp_manager(trainer, cfg.get("exp_manager", None))

    # Record the parallelism settings in the global AppState so that downstream
    # NeMo components can compute this worker's tensor-model-parallel rank.
    app_state = AppState()
    app_state.data_parallel_size = cfg.model.data_parallel_size
    if cfg.model.tensor_model_parallel_size > 1:
        app_state.model_parallel_size = cfg.model.tensor_model_parallel_size
        app_state.tensor_model_parallel_rank = compute_model_parallel_rank(
            trainer.local_rank, app_state.model_parallel_size
        )
|
    # Select the model class based on the pretrained language model and the dataset task.
    if 'bert' in cfg.model.language_model.pretrained_model_name:
        if cfg.model.dataset.task == 'sgd':
            if cfg.model.original_nemo_checkpoint is not None:
                model_class = DialogueZeroShotIntentModel
            else:
                model_class = SGDQAModel
        elif cfg.model.dataset.task in ['zero_shot', 'design']:
            model_class = DialogueZeroShotIntentModel
        else:
            model_class = IntentSlotClassificationModel
    elif 'gpt' in cfg.model.language_model.pretrained_model_name.lower():
        if cfg.model.dataset.task in ['ms_marco', 'mellon_qa']:
            model_class = DialogueGPTGenerationModel
        else:
            model_class = DialogueGPTClassificationModel
    elif (
        'bart' in cfg.model.language_model.pretrained_model_name.lower()
        or 't5' in cfg.model.language_model.pretrained_model_name.lower()
    ):
        model_class = DialogueS2SGenerationModel
    elif 'sentence-transformers' in cfg.model.language_model.pretrained_model_name.lower():
        model_class = DialogueNearestNeighbourModel
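    else:
        # Defensive guard (an addition to the branches above): without it, an
        # unsupported pretrained_model_name would only surface later as a
        # NameError on model_class.
        raise ValueError(
            f'Unsupported pretrained_model_name: {cfg.model.language_model.pretrained_model_name}'
        )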
|
    # Load an existing checkpoint when one is given; otherwise build the model from the config.
    if cfg.pretrained_model or (cfg.model.nemo_path and os.path.exists(cfg.model.nemo_path)):
        if cfg.pretrained_model:
            logging.info(f'Loading pretrained model {cfg.pretrained_model}')
            model = model_class.from_pretrained(cfg.pretrained_model)
        else:
            logging.info(f'Restoring model from {cfg.model.nemo_path}')
            model = model_class.restore_from(cfg.model.nemo_path, trainer=trainer)

        if cfg.do_training:
            model.setup_training_data(train_data_config=cfg.model.train_ds)
            model.setup_multiple_validation_data(val_data_config=cfg.model.validation_ds)
    else:
        logging.info('Initializing model from config')
        model = model_class(cfg.model, trainer=trainer)
|
    if cfg.do_training:
        trainer.fit(model)
        if cfg.model.nemo_path:
            if not os.path.exists(cfg.model.nemo_path):
                model.save_to(cfg.model.nemo_path)
            else:
                # Avoid clobbering an existing .nemo file.
                updated_nemo_path = cfg.model.nemo_path.replace(".nemo", "_new.nemo")
                logging.warning(f"nemo path exists, saving at {updated_nemo_path} instead")
                model.save_to(updated_nemo_path)
    else:
        # Evaluation only: point the model at the test data before running trainer.test().
        data_dir = cfg.model.dataset.get('data_dir', None)
        dialogues_example_dir = cfg.model.dataset.get('dialogues_example_dir', None)

        if data_dir is None or dialogues_example_dir is None:
            raise ValueError('No dataset directory provided; cannot evaluate on the test set.')
        elif not os.path.exists(data_dir):
            raise ValueError(f'{data_dir} not found; cannot evaluate on the test set.')
        else:
            if hasattr(model, "update_data_dirs"):
                model.update_data_dirs(data_dir=data_dir, dialogues_example_dir=dialogues_example_dir)
                model._cfg.dataset = cfg.model.dataset

    if hasattr(cfg.model, 'test_ds') and cfg.model.test_ds.ds_item is not None:
        model.setup_multiple_test_data(test_data_config=cfg.model.test_ds)
        trainer.test(model)
|
|
if __name__ == '__main__':
    main()