import json
import os
from dataclasses import dataclass, is_dataclass
from typing import Optional

import pytorch_lightning as pl
import torch
from omegaconf import OmegaConf

from nemo.collections.tts.models.g2p_heteronym_classification import HeteronymClassificationModel
from nemo.core.config import hydra_runner
from nemo.utils import logging
|
""" |
|
This script runs inference with HeteronymClassificationModel |
|
If the input manifest contains target "word_id", evaluation will be also performed. |
|
|
|
To prepare dataset, see NeMo/scripts/dataset_processing/g2p/export_wikihomograph_data_to_manifest.py |
|
|
|
Inference form manifest: |
|
|
|
python g2p_heteronym_classification_inference.py \ |
|
manifest="<Path to .json manifest>" \ |
|
pretrained_model="<Path to .nemo file or pretrained model name from list_available_models()>" \ |
|
output_manifest="<Path to .json manifest to save prediction>" \ |
|
wordid_to_phonemes_file="<Path to a file with mapping from wordid predicted by the model to phonemes>" |
|
|
|
Interactive inference: |
|
|
|
python g2p_heteronym_classification_inference.py \ |
|
pretrained_model="<Path to .nemo file or pretrained model name from list_available_models()>" \ |
|
wordid_to_phonemes_file="<Path to a file with mapping from wordid predicted by the model to phonemes>" # Optional |
|
|
|
""" |
|
|
@dataclass
class TranscriptionConfig:
    # Required: path to a .nemo checkpoint or a pretrained model name
    # from HeteronymClassificationModel.list_available_models()
    pretrained_model: str

    # Path to a .json manifest for inference; if None, interactive mode is used
    manifest: Optional[str] = None
    # Path to a .json manifest where predictions are saved
    output_manifest: Optional[str] = "predictions.json"
    # Name of the manifest field that contains input graphemes
    grapheme_field: str = "text_graphemes"

    # Optional: path to a file with mapping from wordid predicted by the model to phonemes
    wordid_to_phonemes_file: Optional[str] = None

    # Optional: path for saving evaluation errors; defaults to
    # <output_manifest> with a "_errors.txt" suffix
    errors_file: Optional[str] = None
    batch_size: int = 32
    num_workers: int = 0
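# For illustration: wordid_to_phonemes_file is expected to map each predicted
# wordid to its phonemes. A plausible tab-separated layout (assumed here, not
# the canonical NeMo format):
#   read_present<TAB>R IY1 D
#   read_past<TAB>R EH1 D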
|
|
@hydra_runner(config_name="TranscriptionConfig", schema=TranscriptionConfig)
def main(cfg: TranscriptionConfig):
    logging.info(f'Hydra config: {OmegaConf.to_yaml(cfg)}')

    if is_dataclass(cfg):
        cfg = OmegaConf.structured(cfg)

    if not cfg.pretrained_model:
        raise ValueError(
            'To run evaluation and inference, a pre-trained model or .nemo file must be provided. '
            f'Choose from {HeteronymClassificationModel.list_available_models()} or "pretrained_model"="your_model.nemo"'
        )

    logging.info(
        'During evaluation/testing, it is currently advisable to construct a new Trainer '
        'with a single GPU and no DDP to obtain accurate results'
    )

    # Run on a single GPU if available, otherwise on CPU
    if torch.cuda.is_available():
        device = [0]
        accelerator = 'gpu'
    else:
        device = 1
        accelerator = 'cpu'

    map_location = torch.device('cuda:{}'.format(device[0]) if accelerator == 'gpu' else 'cpu')
    trainer = pl.Trainer(devices=device, accelerator=accelerator, logger=False, enable_checkpointing=False)

    # Load the model from a local .nemo checkpoint or from the list of available pretrained models
    if os.path.exists(cfg.pretrained_model):
        model = HeteronymClassificationModel.restore_from(cfg.pretrained_model, map_location=map_location)
    elif cfg.pretrained_model in HeteronymClassificationModel.get_available_model_names():
        model = HeteronymClassificationModel.from_pretrained(cfg.pretrained_model, map_location=map_location)
    else:
        raise ValueError(
            f'Provide a path to a pre-trained .nemo checkpoint or choose from '
            f'{HeteronymClassificationModel.list_available_models()}'
        )
    model.set_trainer(trainer)
    model = model.eval()

    logging.info(f'Config Params: {model._cfg}')
|
    # Manifest inference
    if cfg.manifest is not None:
        if not os.path.exists(cfg.manifest):
            raise ValueError(f"{cfg.manifest} not found.")
        with torch.no_grad():
            model.disambiguate_manifest(
                manifest=cfg.manifest,
                output_manifest=cfg.output_manifest,
                grapheme_field=cfg.grapheme_field,
                batch_size=cfg.batch_size,
                num_workers=cfg.num_workers,
            )
|
        # Run evaluation if the target "word_id" is present in the input manifest
        if cfg.errors_file is None:
            cfg.errors_file = cfg.output_manifest.replace(".json", "_errors.txt")

        save_errors = True
        correct = 0
        total = 0
        with open(cfg.output_manifest, "r", encoding="utf-8") as f_preds, open(
            cfg.errors_file, "w", encoding="utf-8"
        ) as f_errors:
            for line in f_preds:
                line = json.loads(line)
                predictions = line["pred_wordid"]

                if "word_id" in line:
                    targets = line["word_id"]
                    if isinstance(targets, str):
                        targets = [targets]
                    for idx, target_ in enumerate(targets):
                        total += 1
                        if idx >= len(predictions) or target_ != predictions[idx]:
                            # Guard against missing predictions to avoid an IndexError
                            pred = predictions[idx] if idx < len(predictions) else "<MISSING>"
                            f_errors.write(f"INPUT: {line[cfg.grapheme_field]}\n")
                            f_errors.write(f"PRED : {pred} -- GT: {target_}\n")
                            f_errors.write("===========================\n")
                        else:
                            correct += 1
                else:
                    save_errors = False

        if save_errors and total > 0:
            logging.info(f"Accuracy: {round(correct / total * 100, 2)}% ({total - correct} errors out of {total})")
            logging.info(f"Errors saved at {cfg.errors_file}")
        else:
            logging.info("No 'word_id' values found, skipping evaluation.")
            if os.path.exists(cfg.errors_file):
                os.remove(cfg.errors_file)
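        # For reference, each mismatch written to the errors file above has the form:
        #   INPUT: <input graphemes>
        #   PRED : <predicted wordid> -- GT: <target wordid>
        #   ===========================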
|
    # Interactive inference
    else:
        print('Entering interactive mode.')
        done = False
        while not done:
            print('Type "STOP" to exit.')
            test_input = input('Enter an input sentence: ')
            if test_input == "STOP":
                done = True
            if not done:
                with torch.no_grad():
                    _, sentences = model.disambiguate(
                        sentences=[test_input],
                        batch_size=1,
                        num_workers=cfg.num_workers,
                        wordid_to_phonemes_file=cfg.wordid_to_phonemes_file,
                    )
                    print(sentences[0])
|
|
if __name__ == '__main__':
    main()
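# A minimal programmatic sketch of the same flow (assumes a local checkpoint at
# the hypothetical path "model.nemo"; mirrors the calls used in main() above):
#   model = HeteronymClassificationModel.restore_from("model.nemo").eval()
#   with torch.no_grad():
#       _, sentences = model.disambiguate(
#           sentences=["I read the book yesterday."], batch_size=1,
#           num_workers=0, wordid_to_phonemes_file=None,
#       )
#   print(sentences[0])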