|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
""" |
|
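This script runs model-parallel text classification evaluation: it restores a `TextClassificationModel` from `model.nemo_path` and evaluates it on the test set described by `model.test_ds`.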
|
""" |
|
import pytorch_lightning as pl |
|
from omegaconf import DictConfig, OmegaConf |
|
|
|
from nemo.collections.nlp.models.text_classification import TextClassificationModel |
|
from nemo.collections.nlp.parts.nlp_overrides import NLPDDPStrategy |
|
from nemo.core.config import hydra_runner |
|
from nemo.utils import logging |
|
from nemo.utils.exp_manager import exp_manager |
|
|
|
|
|
@hydra_runner(config_path="conf", config_name="text_classification_config")
def main(cfg: DictConfig) -> None:
    """Restore a text classification model and evaluate it on its test split.

    The configuration is supplied through ``hydra_runner`` using
    ``conf/text_classification_config``. Steps, in order:

    1. Log the full resolved config as YAML.
    2. Build a Lightning ``Trainer`` with ``NLPDDPStrategy`` and the
       keyword arguments found under ``cfg.trainer``.
    3. Attach NeMo's experiment manager (``cfg.exp_manager`` may be absent,
       in which case ``None`` is passed).
    4. Restore ``TextClassificationModel`` from ``cfg.model.nemo_path`` with
       ``strict=False``. This presumably tolerates missing/unexpected keys in
       the checkpoint; confirm against ``restore_from``.
    5. Configure the test dataloader from ``cfg.model.test_ds`` and run
       ``trainer.test``.

    Args:
        cfg: Hydra/OmegaConf configuration with ``trainer``, ``model`` and an
            optional ``exp_manager`` section.
    """
    config_dump = OmegaConf.to_yaml(cfg)
    logging.info(f'\nConfig Params:\n{config_dump}')

    # The strategy is constructed before ``cfg.trainer`` is unpacked, the same
    # left-to-right order as passing it inline.
    ddp_strategy = NLPDDPStrategy()
    trainer = pl.Trainer(strategy=ddp_strategy, **cfg.trainer)
    exp_manager(trainer, cfg.get("exp_manager", None))

    model_cfg = cfg.model
    model = TextClassificationModel.restore_from(model_cfg.nemo_path, trainer=trainer, strict=False)
    model.setup_test_data(test_data_config=model_cfg.test_ds)

    # ckpt_path=None with an explicit model: evaluate the weights just
    # restored instead of loading a Lightning checkpoint.
    trainer.test(model=model, ckpt_path=None)
|
|
|
|
|
# Script entry point: ``main`` is wrapped by ``hydra_runner``, which supplies
# the ``cfg`` argument, so it is invoked here with no arguments.
if __name__ == '__main__':

    main()
|
|