|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import pytorch_lightning as pl |
|
from omegaconf import OmegaConf |
|
from pytorch_lightning import seed_everything |
|
|
|
from nemo.collections.asr.models import EncDecDiarLabelModel |
|
from nemo.core.config import hydra_runner |
|
from nemo.utils import logging |
|
from nemo.utils.exp_manager import exp_manager |
|
|
|
""" |
|
Example training session (single GPU training on telephonic datasets)

python ./multiscale_diar_decoder.py --config-path='../conf/neural_diarizer' --config-name='msdd_5scl_15_05_50Povl_256x3x32x2.yaml' \
    trainer.devices=1 \
    model.base.diarizer.speaker_embeddings.model_path="titanet_large" \
    model.train_ds.manifest_filepath="<train_manifest_path>" \
    model.validation_ds.manifest_filepath="<dev_manifest_path>" \
    model.train_ds.emb_dir="<train_temp_dir>" \
    model.validation_ds.emb_dir="<dev_temp_dir>" \
    exp_manager.name='sample_train' \
    exp_manager.exp_dir='./msdd_exp'
|
""" |
|
|
|
# Seed with the fixed value 42 at import time, i.e. before `main` builds the
# trainer or the model.
seed_everything(seed=42)
|
|
|
|
|
@hydra_runner(config_path="../conf/neural_diarizer", config_name="msdd_5scl_15_05_50Povl_256x3x32x2.yaml")
def main(cfg):
    """Fit an ``EncDecDiarLabelModel`` using the settings in ``cfg``.

    The configuration is logged as YAML, a Lightning ``Trainer`` is built from
    ``cfg.trainer``, ``exp_manager`` is called on that trainer, and the model
    built from ``cfg.model`` is passed to ``trainer.fit``.

    Args:
        cfg: Configuration object. This function reads ``cfg.trainer``,
            ``cfg.model`` and the optional ``"exp_manager"`` entry.
    """
    config_dump = OmegaConf.to_yaml(cfg)
    logging.info(f'Hydra config: {config_dump}')

    lightning_trainer = pl.Trainer(**cfg.trainer)

    # exp_manager receives None when the config has no "exp_manager" entry.
    exp_manager_cfg = cfg.get("exp_manager", None)
    exp_manager(lightning_trainer, exp_manager_cfg)

    diar_model = EncDecDiarLabelModel(cfg=cfg.model, trainer=lightning_trainer)
    lightning_trainer.fit(diar_model)
|
|
|
|
|
# Script entry point. `main` is invoked without arguments; its `cfg` parameter
# is expected to be supplied by the `hydra_runner` decorator.
if __name__ == '__main__':
    main()
|
|