{
    "cells": [
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {
                "id": "lJz6FDU1lRzc"
            },
            "outputs": [],
            "source": [
                "\"\"\"\n",
                "You can run either this notebook locally (if you have all the dependencies and a GPU) or on Google Colab.\n",
                "\n",
                "Instructions for setting up Colab are as follows:\n",
                "1. Open a new Python 3 notebook.\n",
                "2. Import this notebook from GitHub (File -> Upload Notebook -> \"GITHUB\" tab -> copy/paste GitHub URL)\n",
                "3. Connect to an instance with a GPU (Runtime -> Change runtime type -> select \"GPU\" for hardware accelerator)\n",
                "4. Run this cell to set up dependencies.\n",
                "5. Restart the runtime (Runtime -> Restart Runtime) for any upgraded packages to take effect\n",
                "\"\"\"\n",
                "# If you're using Google Colab and not running locally, run this cell.\n",
                "\n",
                "## Install dependencies\n",
                "!pip install wget\n",
                "!apt-get install sox libsndfile1 ffmpeg\n",
                "!pip install text-unidecode\n",
                "!pip install matplotlib>=3.3.2\n",
                "\n",
                "## Install NeMo\n",
                "BRANCH = 'r1.17.0'\n",
                "!python -m pip install git+https://github.com/NVIDIA/NeMo.git@$BRANCH#egg=nemo_toolkit[all]\n",
                "\n",
                "## Grab the config we'll use in this example\n",
                "!mkdir configs\n",
                "!wget -P configs/ https://raw.githubusercontent.com/NVIDIA/NeMo/$BRANCH/examples/asr/conf/config.yaml\n",
                "\n",
                "\"\"\"\n",
                "Remember to restart the runtime for the kernel to pick up any upgraded packages (e.g. matplotlib)!\n",
                "Alternatively, you can uncomment the exit() below to crash and restart the kernel, in the case\n",
                "that you want to use the \"Run All Cells\" (or similar) option.\n",
                "\"\"\"\n",
                "# exit()"
            ]
        },
        {
            "cell_type": "markdown",
            "metadata": {
                "id": "v1Jk9etFlRzf"
            },
            "source": [
                "# Telephony speech (8 kHz)\n",
                "This notebook covers general recommendations for using NeMo models with 8 kHz speech. All the pretrained models currently available through NeMo are trained with audio at 16 kHz. This means that if the original audio was sampled at a different rate, it's sampling rate was converted to 16 kHz through upsampling or downsampling. One of the common applications for ASR is to recognize telephony speech which typically consists of speech sampled at 8 kHz.\n"
            ]
        },
        {
            "cell_type": "markdown",
            "metadata": {},
            "source": [
                "# Mixed sample rate\n",
                "Most of the pretrained English models distributed with NeMo are trained with mixed sample rate data, i.e. the training data typically consists of data sampled at both 8 kHz and 16 kHz. As an example pretrained Citrinet model \"stt_en_citrinet_1024\" was trained with the following datasets. \n",
                "* Librispeech 960 hours of English speech\n",
                "* Fisher Corpus\n",
                "* Switchboard-1 Dataset\n",
                "* WSJ-0 and WSJ-1\n",
                "* National Speech Corpus - 1\n",
                "* Mozilla Common Voice\n",
                "\n",
                "Among these, Fisher and Switchboard datasets are conversational telephone speech datasets with audio sampled at 8 kHz while the other datasets were originally sampled at least 16 kHz. Before training, all audio files from Fisher and Switchboard datasets were upsampled to 16 kHz. Because of this mixed sample rate training, our models can be used to recognize both narrowband (8kHz) and wideband speech (16kHz)"
            ]
        },
        {
            "cell_type": "markdown",
            "metadata": {},
            "source": [
                "# Inference with NeMo\n",
                "NeMo ASR currently supports inference of audio in .wav format. Internally, the audio file is resampled to 16 kHz before inference is called on the model, so there is no difference running inference on 8 kHz audio compared to say 16 kHz or any other sampling rate audio with NeMo. Let's look at an example for running inference on 8 kHz audio. "
            ]
        },
        {
            "cell_type": "code",
            "execution_count": 3,
            "metadata": {},
            "outputs": [],
            "source": [
                "# This is where the an4/ directory will be placed.\n",
                "# Change this if you don't want the data to be extracted in the current directory.\n",
                "data_dir = '.'"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "import glob\n",
                "import os\n",
                "import subprocess\n",
                "import tarfile\n",
                "import wget\n",
                "\n",
                "# Download the dataset. This will take a few moments...\n",
                "print(\"******\")\n",
                "if not os.path.exists(data_dir + '/an4_sphere.tar.gz'):\n",
                "    an4_url = 'https://dldata-public.s3.us-east-2.amazonaws.com/an4_sphere.tar.gz'  # for the original source, please visit http://www.speech.cs.cmu.edu/databases/an4/an4_sphere.tar.gz \n",
                "    an4_path = wget.download(an4_url, data_dir)\n",
                "    print(f\"Dataset downloaded at: {an4_path}\")\n",
                "else:\n",
                "    print(\"Tarfile already exists.\")\n",
                "    an4_path = data_dir + '/an4_sphere.tar.gz'\n",
                "\n",
                "if not os.path.exists(data_dir + '/an4/'):\n",
                "    # Untar and convert .sph to .wav (using sox)\n",
                "    tar = tarfile.open(an4_path)\n",
                "    tar.extractall(path=data_dir)\n",
                "\n",
                "    print(\"Converting .sph to .wav...\")\n",
                "    sph_list = glob.glob(data_dir + '/an4/**/*.sph', recursive=True)\n",
                "    for sph_path in sph_list:\n",
                "        wav_path = sph_path[:-4] + '.wav'\n",
                "        cmd = [\"sox\", sph_path, wav_path]\n",
                "        subprocess.run(cmd)\n",
                "print(\"Finished conversion.\\n******\")"
            ]
        },
        {
            "cell_type": "markdown",
            "metadata": {},
            "source": [
                "Audio in an4 dataset is sampled at 22 kHz. Let's first downsample an audio file to 16 kHz."
            ]
        },
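        {
            "cell_type": "markdown",
            "metadata": {},
            "source": [
                "As a quick sanity check, we can read the sample rate of one of the converted files directly from disk. This is a minimal sketch using the soundfile library, which ships with the NeMo installation above:\n"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "import soundfile as sf\n",
                "\n",
                "# Inspect the native sample rate of a converted wav file on disk.\n",
                "# Note that librosa.load will resample this to 22,050 Hz by default unless sr=None is passed.\n",
                "example_file = data_dir + '/an4/wav/an4_clstk/mgah/cen2-mgah-b.wav'\n",
                "info = sf.info(example_file)\n",
                "print(f\"Native sample rate on disk: {info.samplerate} Hz\")"
            ]
        },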
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "import librosa\n",
                "import IPython.display as ipd\n",
                "import librosa.display\n",
                "import matplotlib.pyplot as plt\n",
                "\n",
                "# Load and listen to the audio file\n",
                "example_file = data_dir + '/an4/wav/an4_clstk/mgah/cen2-mgah-b.wav'\n",
                "audio, sample_rate = librosa.load(example_file)\n",
                "print(sample_rate)\n",
                "audio_16kHz = librosa.core.resample(audio, orig_sr=sample_rate, target_sr=16000)\n",
                "\n",
                "import numpy as np\n",
                "\n",
                "# Get spectrogram using Librosa's Short-Time Fourier Transform (stft)\n",
                "spec = np.abs(librosa.stft(audio_16kHz))\n",
                "spec_db = librosa.amplitude_to_db(spec, ref=np.max)  # Decibels\n",
                "\n",
                "# Use log scale to view frequencies\n",
                "librosa.display.specshow(spec_db, y_axis='log', x_axis='time', sr=16000)\n",
                "plt.colorbar()\n",
                "plt.title('Audio Spectrogram');\n",
                "plt.ylim([0, 8000])"
            ]
        },
        {
            "cell_type": "markdown",
            "metadata": {},
            "source": [
                "Now, let's downsample the audio to 8 kHz"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "audio_8kHz = librosa.core.resample(audio, orig_sr=sample_rate, target_sr=8000)\n",
                "spec = np.abs(librosa.stft(audio_8kHz))\n",
                "spec_db = librosa.amplitude_to_db(spec, ref=np.max)  # Decibels\n",
                "\n",
                "# Use log scale to view frequencies\n",
                "librosa.display.specshow(spec_db, y_axis='log', x_axis='time',  sr=8000)\n",
                "plt.colorbar()\n",
                "plt.title('Audio Spectrogram');\n",
                "plt.ylim([0, 8000])"
            ]
        },
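        {
            "cell_type": "markdown",
            "metadata": {},
            "source": [
                "Comparing the two spectrograms, the 8 kHz version contains no energy above 4 kHz, the Nyquist limit at that sampling rate. This missing high-frequency content is exactly what distinguishes narrowband telephony speech from wideband speech.\n"
            ]
        },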
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "import soundfile as sf\n",
                "sf.write(data_dir + '/audio_16kHz.wav', audio_16kHz, 16000)\n",
                "sample, sr = librosa.core.load(data_dir + '/audio_16kHz.wav')\n",
                "ipd.Audio(sample, rate=sr)\n"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "sf.write(data_dir + '/audio_8kHz.wav', audio_8kHz, 8000)\n",
                "sample, sr = librosa.core.load(data_dir + '/audio_8kHz.wav')\n",
                "ipd.Audio(sample, rate=sr)"
            ]
        },
        {
            "cell_type": "markdown",
            "metadata": {},
            "source": [
                "\n",
                "Let's look at inference results using one of the pretrained models on the original, 16 kHz and 8 kHz versions of the example file we chose above."
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "from nemo.collections.asr.models import ASRModel\n",
                "import torch\n",
                "if torch.cuda.is_available():\n",
                "    device = torch.device(f'cuda:0')\n",
                "asr_model = ASRModel.from_pretrained(model_name='stt_en_citrinet_1024', map_location=device)"
            ]
        },
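        {
            "cell_type": "markdown",
            "metadata": {},
            "source": [
                "We can confirm the sample rate the model expects by reading its config. This is a minimal sketch; the exact config layout can vary between models, so we fall back to the preprocessor config if the top-level key is absent:\n"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "# The model config stores the training sample rate; for this model it should be 16000.\n",
                "sample_rate = asr_model.cfg.get('sample_rate', None) or asr_model.cfg.preprocessor.sample_rate\n",
                "print(f\"Model expects audio at {sample_rate} Hz\")"
            ]
        },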
        {
            "cell_type": "markdown",
            "metadata": {},
            "source": [
                "As discussed above, there are no changes required for inference based on the sampling rate of audio and as we see below the pretrained Citrinet model gives accurate transcription even for audio downsampled to 8 Khz."
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "print(asr_model.transcribe(paths2audio_files=[example_file]))"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "print(asr_model.transcribe(paths2audio_files=[data_dir + '/audio_16kHz.wav']))"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "print(asr_model.transcribe(paths2audio_files=[data_dir + '/audio_8kHz.wav']))"
            ]
        },
        {
            "cell_type": "markdown",
            "metadata": {},
            "source": [
                "# Training / fine-tuning with 8 kHz data\n",
                "For training a model with new 8 kHz data, one could take two approaches. The first approach, **which is recommended**, is to finetune a pretrained 16 kHz model by upsampling all the data to 16 kHz. Note that upsampling offline before training is not necessary but recommended as online upsampling during training is very time consuming and may slow down training significantly. The second approach is to train an 8 kHz model from scratch. **Note**: For the second approach, in our experiments we saw that loading the weights of a 16 kHz model as initialization helps the model to converge faster with better accuracy.\n",
                "\n",
                "To upsample your 8 kHz data to 16 kHz command line tools like sox or ffmpeg are very useful. Here is the command to upsample and audio file using sox:\n",
                "```shell\n",
                "sox input_8k.wav -r 16000 -o output_16k.wav\n",
                "```\n",
                "Now to finetune a pre-trained model with this upsampled data, you can just restore the model weights from the pre-trained model and call trainer with the upsampled data. As an example, here is how one would fine-tune a Citrinet model:\n",
                "```python\n",
                "python examples/asr/script_to_bpe.py \\\n",
                "    --config-path=\"examples/asr/conf/citrinet\" \\\n",
                "    --config-name=\"citrinet_512.yaml\" \\\n",
                "    model.train_ds.manifest_filepath=\"<path to manifest file with upsampled 16kHz data>\" \\\n",
                "    model.validation_ds.manifest_filepath=\"<path to manifest file>\" \\\n",
                "    trainer.devices=-1 \\\n",
                "    trainer.accelerator='gpu' \\\n",
                "    trainer.max_epochs=50 \\\n",
                "    +init_from_pretrained_model=\"stt_en_citrinet_512\"\n",
                "```\n",
                "\n",
                "To train an 8 kHz model, just change the sample rate in the config to 8000 as follows:\n",
                "\n",
                "```python\n",
                "python examples/asr/script_to_bpe.py \\\n",
                "    --config-path=\"examples/asr/conf/citrinet\" \\\n",
                "    --config-name=\"citrinet_512.yaml\" \\\n",
                "    model.sample_rate=8000 \\\n",
                "    model.train_ds.manifest_filepath=\"<path to manifest file with 8kHz data>\" \\\n",
                "    model.validation_ds.manifest_filepath=\"<path to manifest file>\" \\\n",
                "    trainer.devices=-1 \\\n",
                "    trainer.accelerator='gpu' \\\n",
                "    trainer.max_epochs=50 \\\n",
                "    +init_from_pretrained_model=\"stt_en_citrinet_512\"\n",
                "```"
            ]
        }
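        ,
        {
            "cell_type": "markdown",
            "metadata": {},
            "source": [
                "Below is a minimal sketch of the recommended offline upsampling step, mirroring the sph-to-wav conversion loop earlier in this notebook. The directory names are placeholders for your own data:\n"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "import glob\n",
                "import os\n",
                "import subprocess\n",
                "\n",
                "# Hypothetical directories -- point these at your own 8 kHz dataset.\n",
                "input_dir = 'my_8khz_data'\n",
                "output_dir = 'my_16khz_data'\n",
                "os.makedirs(output_dir, exist_ok=True)\n",
                "\n",
                "for wav_path in glob.glob(os.path.join(input_dir, '*.wav')):\n",
                "    out_path = os.path.join(output_dir, os.path.basename(wav_path))\n",
                "    # 'sox <in> -r 16000 <out>' resamples the file to 16 kHz\n",
                "    subprocess.run(['sox', wav_path, '-r', '16000', out_path], check=True)\n",
                "print('Done upsampling.')"
            ]
        }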
    ],
    "metadata": {
        "accelerator": "GPU",
        "colab": {
            "collapsed_sections": [],
            "name": "ASR_with_NeMo.ipynb",
            "provenance": [],
            "toc_visible": true
        },
        "kernelspec": {
            "display_name": "Python 3",
            "language": "python",
            "name": "python3"
        },
        "language_info": {
            "codemirror_mode": {
                "name": "ipython",
                "version": 3
            },
            "file_extension": ".py",
            "mimetype": "text/x-python",
            "name": "python",
            "nbconvert_exporter": "python",
            "pygments_lexer": "ipython3",
            "version": "3.8.5"
        },
        "pycharm": {
            "stem_cell": {
                "cell_type": "raw",
                "metadata": {
                    "collapsed": false
                },
                "source": []
            }
        }
    },
    "nbformat": 4,
    "nbformat_minor": 4
}