# Copyright (c) 2021, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import contextlib
import glob
import os

import torch

import nemo.collections.asr as nemo_asr
from nemo.utils import logging, model_utils

# setup AMP (optional)
if torch.cuda.is_available() and hasattr(torch.cuda, 'amp') and hasattr(torch.cuda.amp, 'autocast'):
    logging.info("AMP enabled!\n")
    autocast = torch.cuda.amp.autocast
else:

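    # Fallback: a no-op context manager, so call sites can use autocast() unconditionally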
    @contextlib.contextmanager
    def autocast():
        yield


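# Cache of loaded models keyed by name, so repeated requests reuse the same instance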
MODEL_CACHE = {}

# Special tags for fallbacks / user notifications
TAG_ERROR_DURING_TRANSCRIPTION = "<ERROR_DURING_TRANSCRIPTION>"


def get_model_names():
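    """Return (pretrained_model_names, local_model_names).

    Local names are .nemo files discovered under ./models; pretrained names
    come from NeMo's checkpoint listing, filtered to CTC / RNNT models.
    """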
    # Populate local copy of models
    local_model_paths = glob.glob(os.path.join('models', "**", "*.nemo"), recursive=True)
    local_model_names = sorted(os.path.basename(path) for path in local_model_paths)

    # Populate with pretrained nemo checkpoint list
    nemo_model_names = set()
    for model_info in nemo_asr.models.ASRModel.list_available_models():
        # Skip alignment checkpoints
        if 'align' in model_info.pretrained_model_name:
            continue

        # Keep only CTC / RNNT model classes
        for superclass in model_info.class_.mro():
            if 'CTC' in superclass.__name__ or 'RNNT' in superclass.__name__:
                nemo_model_names.add(model_info.pretrained_model_name)
                break
    nemo_model_names = sorted(nemo_model_names)
    return nemo_model_names, local_model_names


def initialize_model(model_name):
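    """Load an ASR model by name and cache it in MODEL_CACHE.

    Names ending in '.nemo' are restored from ./models/<name without ext>/<name>;
    any other name is treated as a pretrained NeMo checkpoint.
    """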
    # load model
    if model_name not in MODEL_CACHE:
        if model_name.endswith('.nemo'):
            # use local model
            model_name_no_ext = os.path.splitext(model_name)[0]
            model_path = os.path.join('models', model_name_no_ext, model_name)

            # Extract config
            model_cfg = nemo_asr.models.ASRModel.restore_from(restore_path=model_path, return_config=True)
            classpath = model_cfg.target  # original class path
            imported_class = model_utils.import_class_by_path(classpath)  # type: ASRModel
            logging.info(f"Restoring local model : {imported_class.__name__}")

            # load model from checkpoint
            model = imported_class.restore_from(restore_path=model_path, map_location='cpu')  # type: ASRModel

        else:
            # use pretrained model
            model = nemo_asr.models.ASRModel.from_pretrained(model_name, map_location='cpu')

        model.freeze()

        # cache model
        MODEL_CACHE[model_name] = model

    model = MODEL_CACHE[model_name]
    return model


def transcribe_all(filepaths, model_name, use_gpu_if_available=True):
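    """Transcribe a list of audio files with the given model.

    Falls back to CPU inference if the device runs out of memory; returns
    TAG_ERROR_DURING_TRANSCRIPTION if transcription fails on CPU as well.
    """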
    # instantiate model (initialize_model consults MODEL_CACHE internally)
    model = initialize_model(model_name)

    if torch.cuda.is_available() and use_gpu_if_available:
        model = model.cuda()

    # transcribe audio
    logging.info("Begin transcribing audio...")
    try:
        with autocast():
            with torch.no_grad():
                transcriptions = model.transcribe(filepaths, batch_size=32)

    except RuntimeError:
        # Likely an out-of-memory error on the device; purge the model cache to free memory
        MODEL_CACHE.clear()

        logging.info("Ran out of memory on device - performing inference on CPU for now")

        try:
            model = model.cpu()

            with torch.no_grad():
                transcriptions = model.transcribe(filepaths, batch_size=32)

        except Exception as e:
            logging.info(f"Exception {e} occured while attemting to transcribe audio. Returning error message")
            return TAG_ERROR_DURING_TRANSCRIPTION

    logging.info(f"Finished transcribing {len(filepaths)} files !")

    # When RNNT models transcribe, they return a tuple (greedy, beam_scores)
    if isinstance(transcriptions[0], list) and len(transcriptions) == 2:
        # get greedy transcriptions only
        transcriptions = transcriptions[0]

    # Move the model back to CPU before re-caching so it does not hold GPU memory
    model = model.cpu()
    MODEL_CACHE[model_name] = model

    return transcriptions
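

# Minimal usage sketch (illustrative only): the file paths below are hypothetical
# placeholders, and the first pretrained model returned by get_model_names() is
# assumed to be downloadable at the time of writing.
if __name__ == '__main__':
    pretrained_names, local_names = get_model_names()
    logging.info(f"Found {len(pretrained_names)} pretrained and {len(local_names)} local models")

    if pretrained_names:
        files = ['sample1.wav', 'sample2.wav']  # hypothetical audio files
        results = transcribe_all(files, pretrained_names[0])

        if results == TAG_ERROR_DURING_TRANSCRIPTION:
            logging.info("Transcription failed")
        else:
            for path, text in zip(files, results):
                logging.info(f"{path}: {text}")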