{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "lJz6FDU1lRzc"
},
"outputs": [],
"source": [
"\"\"\"\n",
"You can run either this notebook locally (if you have all the dependencies and a GPU) or on Google Colab.\n",
"\n",
"Instructions for setting up Colab are as follows:\n",
"1. Open a new Python 3 notebook.\n",
"2. Import this notebook from GitHub (File -> Upload Notebook -> \"GITHUB\" tab -> copy/paste GitHub URL)\n",
"3. Connect to an instance with a GPU (Runtime -> Change runtime type -> select \"GPU\" for hardware accelerator)\n",
"4. Run this cell to set up dependencies.\n",
"5. Restart the runtime (Runtime -> Restart Runtime) for any upgraded packages to take effect\n",
"\"\"\"\n",
"# If you're using Google Colab and not running locally, run this cell.\n",
"\n",
"## Install dependencies\n",
"!pip install wget\n",
"!apt-get install sox libsndfile1 ffmpeg\n",
"!pip install text-unidecode\n",
"!pip install matplotlib>=3.3.2\n",
"\n",
"## Install NeMo\n",
"BRANCH = 'r1.17.0'\n",
"!python -m pip install git+https://github.com/NVIDIA/NeMo.git@$BRANCH#egg=nemo_toolkit[all]\n",
"\n",
"## Grab the config we'll use in this example\n",
"!mkdir configs\n",
"!wget -P configs/ https://raw.githubusercontent.com/NVIDIA/NeMo/$BRANCH/examples/asr/conf/config.yaml\n",
"\n",
"\"\"\"\n",
"Remember to restart the runtime for the kernel to pick up any upgraded packages (e.g. matplotlib)!\n",
"Alternatively, you can uncomment the exit() below to crash and restart the kernel, in the case\n",
"that you want to use the \"Run All Cells\" (or similar) option.\n",
"\"\"\"\n",
"# exit()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "v1Jk9etFlRzf"
},
"source": [
"# Telephony speech (8 kHz)\n",
"This notebook covers general recommendations for using NeMo models with 8 kHz speech. All the pretrained models currently available through NeMo are trained with audio at 16 kHz. This means that if the original audio was sampled at a different rate, it's sampling rate was converted to 16 kHz through upsampling or downsampling. One of the common applications for ASR is to recognize telephony speech which typically consists of speech sampled at 8 kHz.\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Mixed sample rate\n",
"Most of the pretrained English models distributed with NeMo are trained with mixed sample rate data, i.e. the training data typically consists of data sampled at both 8 kHz and 16 kHz. As an example pretrained Citrinet model \"stt_en_citrinet_1024\" was trained with the following datasets. \n",
"* Librispeech 960 hours of English speech\n",
"* Fisher Corpus\n",
"* Switchboard-1 Dataset\n",
"* WSJ-0 and WSJ-1\n",
"* National Speech Corpus - 1\n",
"* Mozilla Common Voice\n",
"\n",
"Among these, Fisher and Switchboard datasets are conversational telephone speech datasets with audio sampled at 8 kHz while the other datasets were originally sampled at least 16 kHz. Before training, all audio files from Fisher and Switchboard datasets were upsampled to 16 kHz. Because of this mixed sample rate training, our models can be used to recognize both narrowband (8kHz) and wideband speech (16kHz)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Inference with NeMo\n",
"NeMo ASR currently supports inference of audio in .wav format. Internally, the audio file is resampled to 16 kHz before inference is called on the model, so there is no difference running inference on 8 kHz audio compared to say 16 kHz or any other sampling rate audio with NeMo. Let's look at an example for running inference on 8 kHz audio. "
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"# This is where the an4/ directory will be placed.\n",
"# Change this if you don't want the data to be extracted in the current directory.\n",
"data_dir = '.'"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import glob\n",
"import os\n",
"import subprocess\n",
"import tarfile\n",
"import wget\n",
"\n",
"# Download the dataset. This will take a few moments...\n",
"print(\"******\")\n",
"if not os.path.exists(data_dir + '/an4_sphere.tar.gz'):\n",
" an4_url = 'https://dldata-public.s3.us-east-2.amazonaws.com/an4_sphere.tar.gz' # for the original source, please visit http://www.speech.cs.cmu.edu/databases/an4/an4_sphere.tar.gz \n",
" an4_path = wget.download(an4_url, data_dir)\n",
" print(f\"Dataset downloaded at: {an4_path}\")\n",
"else:\n",
" print(\"Tarfile already exists.\")\n",
" an4_path = data_dir + '/an4_sphere.tar.gz'\n",
"\n",
"if not os.path.exists(data_dir + '/an4/'):\n",
" # Untar and convert .sph to .wav (using sox)\n",
" tar = tarfile.open(an4_path)\n",
" tar.extractall(path=data_dir)\n",
"\n",
" print(\"Converting .sph to .wav...\")\n",
" sph_list = glob.glob(data_dir + '/an4/**/*.sph', recursive=True)\n",
" for sph_path in sph_list:\n",
" wav_path = sph_path[:-4] + '.wav'\n",
" cmd = [\"sox\", sph_path, wav_path]\n",
" subprocess.run(cmd)\n",
"print(\"Finished conversion.\\n******\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Audio in an4 dataset is sampled at 22 kHz. Let's first downsample an audio file to 16 kHz."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import librosa\n",
"import IPython.display as ipd\n",
"import librosa.display\n",
"import matplotlib.pyplot as plt\n",
"\n",
"# Load and listen to the audio file\n",
"example_file = data_dir + '/an4/wav/an4_clstk/mgah/cen2-mgah-b.wav'\n",
"audio, sample_rate = librosa.load(example_file)\n",
"print(sample_rate)\n",
"audio_16kHz = librosa.core.resample(audio, orig_sr=sample_rate, target_sr=16000)\n",
"\n",
"import numpy as np\n",
"\n",
"# Get spectrogram using Librosa's Short-Time Fourier Transform (stft)\n",
"spec = np.abs(librosa.stft(audio_16kHz))\n",
"spec_db = librosa.amplitude_to_db(spec, ref=np.max) # Decibels\n",
"\n",
"# Use log scale to view frequencies\n",
"librosa.display.specshow(spec_db, y_axis='log', x_axis='time', sr=16000)\n",
"plt.colorbar()\n",
"plt.title('Audio Spectrogram');\n",
"plt.ylim([0, 8000])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now, let's downsample the audio to 8 kHz"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"audio_8kHz = librosa.core.resample(audio, orig_sr=sample_rate, target_sr=8000)\n",
"spec = np.abs(librosa.stft(audio_8kHz))\n",
"spec_db = librosa.amplitude_to_db(spec, ref=np.max) # Decibels\n",
"\n",
"# Use log scale to view frequencies\n",
"librosa.display.specshow(spec_db, y_axis='log', x_axis='time', sr=8000)\n",
"plt.colorbar()\n",
"plt.title('Audio Spectrogram');\n",
"plt.ylim([0, 8000])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import soundfile as sf\n",
"sf.write(data_dir + '/audio_16kHz.wav', audio_16kHz, 16000)\n",
"sample, sr = librosa.core.load(data_dir + '/audio_16kHz.wav')\n",
"ipd.Audio(sample, rate=sr)\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"sf.write(data_dir + '/audio_8kHz.wav', audio_8kHz, 8000)\n",
"sample, sr = librosa.core.load(data_dir + '/audio_8kHz.wav')\n",
"ipd.Audio(sample, rate=sr)"
]
},
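{
 "cell_type": "markdown",
 "metadata": {},
 "source": [
  "As a quick sanity check, here is a minimal sketch (using the `soundfile` module and the `data_dir` variable defined in earlier cells) that confirms the sample rates of the two files we just wrote. NeMo resamples audio to 16 kHz internally during inference, so both files can be transcribed as-is."
 ]
},
{
 "cell_type": "code",
 "execution_count": null,
 "metadata": {},
 "outputs": [],
 "source": [
  "# Sanity check: confirm the sample rates of the resampled files written above.\n",
  "# NeMo resamples audio to 16 kHz internally, so both files can be transcribed directly.\n",
  "for path in [data_dir + '/audio_16kHz.wav', data_dir + '/audio_8kHz.wav']:\n",
  "    info = sf.info(path)\n",
  "    print(f'{path}: {info.samplerate} Hz, {info.duration:.2f} s')"
 ]
},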
{
"cell_type": "markdown",
"metadata": {},
"source": [
"\n",
"Let's look at inference results using one of the pretrained models on the original, 16 kHz and 8 kHz versions of the example file we chose above."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from nemo.collections.asr.models import ASRModel\n",
"import torch\n",
"if torch.cuda.is_available():\n",
" device = torch.device(f'cuda:0')\n",
"asr_model = ASRModel.from_pretrained(model_name='stt_en_citrinet_1024', map_location=device)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"As discussed above, there are no changes required for inference based on the sampling rate of audio and as we see below the pretrained Citrinet model gives accurate transcription even for audio downsampled to 8 Khz."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"print(asr_model.transcribe(paths2audio_files=[example_file]))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"print(asr_model.transcribe(paths2audio_files=[data_dir + '/audio_16kHz.wav']))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"print(asr_model.transcribe(paths2audio_files=[data_dir + '/audio_8kHz.wav']))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Training / fine-tuning with 8 kHz data\n",
"For training a model with new 8 kHz data, one could take two approaches. The first approach, **which is recommended**, is to finetune a pretrained 16 kHz model by upsampling all the data to 16 kHz. Note that upsampling offline before training is not necessary but recommended as online upsampling during training is very time consuming and may slow down training significantly. The second approach is to train an 8 kHz model from scratch. **Note**: For the second approach, in our experiments we saw that loading the weights of a 16 kHz model as initialization helps the model to converge faster with better accuracy.\n",
"\n",
"To upsample your 8 kHz data to 16 kHz command line tools like sox or ffmpeg are very useful. Here is the command to upsample and audio file using sox:\n",
"```shell\n",
"sox input_8k.wav -r 16000 -o output_16k.wav\n",
"```\n",
"Now to finetune a pre-trained model with this upsampled data, you can just restore the model weights from the pre-trained model and call trainer with the upsampled data. As an example, here is how one would fine-tune a Citrinet model:\n",
"```python\n",
"python examples/asr/script_to_bpe.py \\\n",
" --config-path=\"examples/asr/conf/citrinet\" \\\n",
" --config-name=\"citrinet_512.yaml\" \\\n",
" model.train_ds.manifest_filepath=\"<path to manifest file with upsampled 16kHz data>\" \\\n",
" model.validation_ds.manifest_filepath=\"<path to manifest file>\" \\\n",
" trainer.devices=-1 \\\n",
" trainer.accelerator='gpu' \\\n",
" trainer.max_epochs=50 \\\n",
" +init_from_pretrained_model=\"stt_en_citrinet_512\"\n",
"```\n",
"\n",
"To train an 8 kHz model, just change the sample rate in the config to 8000 as follows:\n",
"\n",
"```python\n",
"python examples/asr/script_to_bpe.py \\\n",
" --config-path=\"examples/asr/conf/citrinet\" \\\n",
" --config-name=\"citrinet_512.yaml\" \\\n",
" model.sample_rate=8000 \\\n",
" model.train_ds.manifest_filepath=\"<path to manifest file with 8kHz data>\" \\\n",
" model.validation_ds.manifest_filepath=\"<path to manifest file>\" \\\n",
" trainer.devices=-1 \\\n",
" trainer.accelerator='gpu' \\\n",
" trainer.max_epochs=50 \\\n",
" +init_from_pretrained_model=\"stt_en_citrinet_512\"\n",
"```"
]
}
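,
{
 "cell_type": "markdown",
 "metadata": {},
 "source": [
  "The sox one-liner above upsamples a single file. As an alternative, here is a minimal Python sketch, using the `librosa` and `soundfile` packages already used in this notebook, for offline upsampling of a whole directory of 8 kHz .wav files to 16 kHz before fine-tuning. The `input_dir` and `output_dir` names are placeholders for illustration; adapt them (and the `audio_filepath` entries in your manifest) to your own dataset layout."
 ]
},
{
 "cell_type": "code",
 "execution_count": null,
 "metadata": {},
 "outputs": [],
 "source": [
  "import glob\n",
  "import os\n",
  "\n",
  "import librosa\n",
  "import soundfile as sf\n",
  "\n",
  "# NOTE: placeholder directory names; point them at your own 8 kHz dataset.\n",
  "input_dir = 'my_8khz_data'    # directory containing the original 8 kHz .wav files\n",
  "output_dir = 'my_16khz_data'  # upsampled 16 kHz copies are written here\n",
  "\n",
  "for wav_path in glob.glob(os.path.join(input_dir, '**', '*.wav'), recursive=True):\n",
  "    # Load at the file's native sampling rate, then upsample to 16 kHz.\n",
  "    audio, sr = librosa.load(wav_path, sr=None)\n",
  "    audio_16k = librosa.core.resample(audio, orig_sr=sr, target_sr=16000)\n",
  "\n",
  "    # Mirror the input directory structure under output_dir.\n",
  "    out_path = os.path.join(output_dir, os.path.relpath(wav_path, input_dir))\n",
  "    os.makedirs(os.path.dirname(out_path), exist_ok=True)\n",
  "    sf.write(out_path, audio_16k, 16000)\n",
  "\n",
  "# Remember to update the audio_filepath entries in your manifest to point at the 16 kHz files."
 ]
}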
],
"metadata": {
"accelerator": "GPU",
"colab": {
"collapsed_sections": [],
"name": "ASR_with_NeMo.ipynb",
"provenance": [],
"toc_visible": true
},
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.5"
},
"pycharm": {
"stem_cell": {
"cell_type": "raw",
"metadata": {
"collapsed": false
},
"source": []
}
}
},
"nbformat": 4,
"nbformat_minor": 4
}