# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
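
"""
Example script for training and evaluating question answering models with NeMo.

The model class is selected from cfg.model.language_model.pretrained_model_name:
BERT-like encoders use BERTQAModel (extractive QA), GPT models use GPTQAModel
(generative QA), and BART/T5 models use S2SQAModel (sequence-to-sequence QA).
Default settings come from conf/qa_conf.yaml via the hydra_runner decorator.

Hypothetical invocation (the script name, file paths, and overrides below are
assumptions; adapt them to your setup):

    python question_answering.py \
        model.language_model.pretrained_model_name=bert-base-uncased \
        model.train_ds.file=<TRAIN_FILE> \
        model.validation_ds.file=<VAL_FILE> \
        do_training=true
"""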
import os

import pytorch_lightning as pl
from omegaconf import DictConfig, OmegaConf

from nemo.collections.nlp.models.question_answering.qa_bert_model import BERTQAModel
from nemo.collections.nlp.models.question_answering.qa_gpt_model import GPTQAModel
from nemo.collections.nlp.models.question_answering.qa_s2s_model import S2SQAModel
from nemo.core.config import hydra_runner
from nemo.utils import logging
from nemo.utils.exp_manager import exp_manager


@hydra_runner(config_path="conf", config_name="qa_conf")
def main(cfg: DictConfig) -> None:
    pl.seed_everything(42)
    logging.info(f'Config: {OmegaConf.to_yaml(cfg)}')

    trainer = pl.Trainer(**cfg.trainer)
    exp_dir = exp_manager(trainer, cfg.get("exp_manager", None))
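
    # Choose the QA model class from the pretrained model name:
    # BERT-like encoders -> extractive QA, GPT -> generative QA,
    # BART/T5 -> sequence-to-sequence QA.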
if "bert" in cfg.model.language_model.pretrained_model_name.lower():
model_class = BERTQAModel
elif "gpt" in cfg.model.language_model.pretrained_model_name.lower():
model_class = GPTQAModel
elif (
"bart" in cfg.model.language_model.pretrained_model_name.lower()
or "t5" in cfg.model.language_model.pretrained_model_name.lower()
):
model_class = S2SQAModel
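
    # Load an existing checkpoint when one is given or already on disk;
    # otherwise build a fresh model from the config.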
    if cfg.pretrained_model or (cfg.model.nemo_path and os.path.exists(cfg.model.nemo_path)):
        if cfg.pretrained_model:
            logging.info(f'Loading pretrained model {cfg.pretrained_model}')
            model = model_class.from_pretrained(cfg.pretrained_model)
        else:
            logging.info(f'Restoring model from {cfg.model.nemo_path}')
            model = model_class.restore_from(cfg.model.nemo_path)

        if cfg.do_training:
            model.setup_training_data(train_data_config=cfg.model.train_ds)
            model.setup_multiple_validation_data(val_data_config=cfg.model.validation_ds)
    else:
        logging.info(f'Config: {OmegaConf.to_yaml(cfg)}')
        model = model_class(cfg.model, trainer=trainer)
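
    # Training: fit the model and optionally persist it as a .nemo file.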
    if cfg.do_training:
        trainer.fit(model)
        if cfg.model.nemo_path:
            model.save_to(cfg.model.nemo_path)
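
    # Evaluation and inference on the test set, if one is configured.
    # A new single-device, fp16 trainer is built just for testing.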
    if hasattr(cfg.model, 'test_ds') and cfg.model.test_ds.file is not None:
        eval_device = [cfg.trainer.devices[0]] if isinstance(cfg.trainer.devices, list) else 1
        trainer = pl.Trainer(devices=eval_device, accelerator=cfg.trainer.accelerator, precision=16)
        model.setup_test_data(test_data_config=cfg.model.test_ds)
        trainer.test(model)

        # specify a .json file to dump n-best predictions, e.g. os.path.join(exp_dir, "output_nbest_file.json")
        output_nbest_file = None
        # specify a .json file to dump predictions, e.g. os.path.join(exp_dir, "output_prediction_file.json")
        output_prediction_file = None
        inference_samples = 5  # for test purposes; set to -1 to use the entire inference dataset
        all_preds, all_nbest = model.inference(
            cfg.model.test_ds.file,
            output_prediction_file=output_prediction_file,
            output_nbest_file=output_nbest_file,
            num_samples=inference_samples,
        )
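
        # Print the top prediction for each question id.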
        for question_id in all_preds:
            print(all_preds[question_id])
if __name__ == "__main__":
main()