{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "lJz6FDU1lRzc"
   },
   "outputs": [],
   "source": [
    "\"\"\"\n",
    "You can either run this notebook locally (if you have all the dependencies and a GPU) or on Google Colab.\n",
    "\n",
    "Instructions for setting up Colab are as follows:\n",
    "1. Open a new Python 3 notebook.\n",
    "2. Import this notebook from GitHub (File -> Upload Notebook -> \"GITHUB\" tab -> copy/paste GitHub URL)\n",
    "3. Connect to an instance with a GPU (Runtime -> Change runtime type -> select \"GPU\" for hardware accelerator)\n",
    "4. Run this cell to set up dependencies.\n",
    "5. Restart the runtime (Runtime -> Restart Runtime) for any upgraded packages to take effect\n",
    "\"\"\"\n",
    "# If you're using Google Colab and not running locally, run this cell.\n",
    "\n",
    "## Install dependencies\n",
    "!pip install wget\n",
    "!apt-get install sox libsndfile1 ffmpeg\n",
    "!pip install text-unidecode\n",
    "!pip install \"matplotlib>=3.3.2\"\n",
    "\n",
    "## Install NeMo\n",
    "BRANCH = 'r1.17.0'\n",
    "!python -m pip install git+https://github.com/NVIDIA/NeMo.git@$BRANCH#egg=nemo_toolkit[all]\n",
    "\n",
    "## Grab the config we'll use in this example\n",
    "!mkdir configs\n",
    "!wget -P configs/ https://raw.githubusercontent.com/NVIDIA/NeMo/$BRANCH/examples/asr/conf/config.yaml\n",
    "\n",
    "\"\"\"\n",
    "Remember to restart the runtime for the kernel to pick up any upgraded packages (e.g. matplotlib)!\n",
    "Alternatively, you can uncomment the exit() below to crash and restart the kernel, in the case\n",
    "that you want to use the \"Run All Cells\" (or similar) option.\n",
    "\"\"\"\n",
    "# exit()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Streaming ASR\n",
    "In this tutorial, we will look at one way to use one of NeMo's pretrained Conformer-CTC models for streaming inference. We will first look at some use cases where we may need streaming inference and then we will work towards developing a method for transcribing a long audio file using streaming."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "v1Jk9etFlRzf"
   },
   "source": [
    "# Why Stream?\n",
    "Streaming inference may be needed in one of the following scenarios:\n",
    "* Real-time or close to real-time inference for live transcriptions\n",
    "* Offline transcriptions of very long audio\n",
    "\n",
    "In this tutorial, we will mainly focus on streaming for handling long-form audio and close to real-time inference with CTC-based models. ASR models are usually trained on short segments of audio (<20s), often obtained by aligning a long recording with its transcript and segmenting it into smaller chunks (see [tools/](https://github.com/NVIDIA/NeMo/tree/main/tools) for some great tools to do this). When running inference on long audio files, the available GPU memory dictates the maximum length of audio that can be transcribed in one inference call. We will take a look at one way to overcome this restriction using NeMo's Conformer-CTC ASR model."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Conformer-CTC\n",
    "Conformer-CTC models distributed with NeMo combine self-attention and convolution modules to get the best of both approaches: the self-attention layers learn global interactions, while the convolutions efficiently capture local correlations. The cost of self-attention is memory usage that grows quadratically with the sequence length. As a result, transcribing long audio files with Conformer-CTC models requires streaming inference that breaks the audio into smaller chunks. We will develop one method for such inference over the course of this tutorial."
   ]
  },
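  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To get a feel for the quadratic growth, the cell below gives a rough, illustrative estimate of the size of a single self-attention score matrix (one head, one layer) for different audio lengths. It is a back-of-the-envelope sketch, not a measurement of the actual model: it assumes one encoder frame per 40 ms of audio and 2 bytes per element (FP16), and it ignores all other activations and the model weights."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Rough estimate only: size of ONE self-attention score matrix (one head, one layer).\n",
    "# Assumption: 40 ms of audio per encoder frame (10 ms features x model stride of 4).\n",
    "# Assumption: 2 bytes per element (FP16 under autocast).\n",
    "FRAME_SHIFT_SECS = 0.04\n",
    "BYTES_PER_ELEMENT = 2\n",
    "\n",
    "def attention_matrix_mib(audio_len_secs):\n",
    "    # number of encoder frames grows linearly with the audio length ...\n",
    "    n_frames = round(audio_len_secs / FRAME_SHIFT_SECS)\n",
    "    # ... but the score matrix has n_frames x n_frames entries\n",
    "    return n_frames * n_frames * BYTES_PER_ELEMENT / 2**20\n",
    "\n",
    "for secs in (20, 60, 15 * 60):\n",
    "    print(f'{secs:>4d} s of audio -> {attention_matrix_mib(secs):10.1f} MiB per head per layer')"
   ]
  },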
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Data\n",
    "To demonstrate transcribing a long audio file, we will use utterances from the dev-clean set of the [Mini LibriSpeech corpus](https://www.openslr.org/31/)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# If something goes wrong during data processing, un-comment the following line to delete the cached dataset \n",
    "# !rm -rf datasets/mini-dev-clean\n",
    "!mkdir -p datasets/mini-dev-clean"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!python ../../scripts/dataset_processing/get_librispeech_data.py \\\n",
    "  --data_root \"datasets/mini-dev-clean/\" \\\n",
    "  --data_sets dev_clean_2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "manifest = \"datasets/mini-dev-clean/dev_clean_2.json\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's create a long audio file of about 15 minutes by concatenating utterances from dev-clean, and also create the corresponding concatenated transcript."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "def concat_audio(manifest_file, final_len=3600):\n",
    "    concat_len = 0\n",
    "    final_transcript = \"\"\n",
    "    with open(\"concat_file.txt\", \"w\") as cat_f:\n",
    "        while concat_len < final_len:\n",
    "            with open(manifest_file, \"r\") as mfst_f:\n",
    "                for l in mfst_f:\n",
    "                    row = json.loads(l.strip())\n",
    "                    if concat_len >= final_len:\n",
    "                        break\n",
    "                    cat_f.write(f\"file {row['audio_filepath']}\\n\")\n",
    "                    final_transcript += (\" \" + row['text'])\n",
    "                    concat_len += float(row['duration'])\n",
    "    return concat_len, final_transcript\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "new_duration, ref_transcript = concat_audio(manifest, 15*60)\n",
    "\n",
    "concat_audio_path = \"datasets/mini-dev-clean/concatenated_audio.wav\"\n",
    "\n",
    "!ffmpeg -t {new_duration} -safe 0 -f concat -i concat_file.txt -c copy -t {new_duration} {concat_audio_path} -y\n",
    "print(\"Finished concatenating audio file!\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Streaming with CTC based models\n",
    "Now let's try to transcribe the long audio file created above using a conformer-large model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import nemo.collections.asr as nemo_asr\n",
    "import contextlib\n",
    "import gc"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
    "device"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In this tutorial, we are mainly concerned with decoding on the GPU. CPU decoding may be able to handle longer files, but it is not as fast as GPU decoding. Let's check whether we can run transcribe() on the long audio file that we created above."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Clear up memory\n",
    "torch.cuda.empty_cache()\n",
    "gc.collect()\n",
    "model = nemo_asr.models.EncDecCTCModelBPE.from_pretrained(\"stt_en_conformer_ctc_large\", map_location=device)\n",
    "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
    "# device = 'cpu'  # You can transcribe even longer samples on the CPU, though it will take much longer !\n",
    "model = model.to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Helper for torch amp autocast\n",
    "if torch.cuda.is_available():\n",
    "    autocast = torch.cuda.amp.autocast\n",
    "else:\n",
    "    @contextlib.contextmanager\n",
    "    def autocast():\n",
    "        print(\"AMP was not available, using FP32!\")\n",
    "        yield"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The call to transcribe() below should fail with a \"CUDA out of memory\" error when run on a GPU with 32 GB memory."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "with autocast():\n",
    "    transcript = model.transcribe([concat_audio_path], batch_size=1)[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Clear up memory\n",
    "torch.cuda.empty_cache()\n",
    "gc.collect()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Buffer mechanism for streaming long audio files\n",
    "One way to transcribe long audio with a Conformer-CTC model is to split the audio into consecutive smaller chunks and run inference on each chunk. Care should be taken to have enough context at either edge of the audio for accurate transcription. Let's introduce some terminology here to help us navigate the rest of this tutorial.\n",
    "\n",
    "* Buffer size is the length of audio on which inference is run.\n",
    "* Chunk size is the length of new audio that is added to the buffer.\n",
    "\n",
    "An audio buffer is made up of a chunk of audio with some surrounding audio from adjacent chunks as context. To make the best predictions, with enough context on both sides, we only collect tokens for the middle portion of the buffer, whose length equals the chunk size.\n",
    "\n",
    "Let's suppose that the maximum length of audio that can be transcribed with the conformer-large model is 20s. We can then use 20s as the buffer size and, for example, 15s as the chunk size, so one hour of audio is broken into 240 chunks of 15s each. Let's take a look at a few audio buffers that may be created for this audio."
   ]
  },
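  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Before working with real audio, here is a small sketch of the bookkeeping described above. The first part reproduces the arithmetic for the one-hour example. The second part slides a buffer over a toy signal; the toy sizes (8 samples, chunks of 2, buffer of 4) are made up purely for illustration so that the contents of each buffer are easy to read."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "import numpy as np\n",
    "\n",
    "# Values from the example above: one hour of audio, 20 s buffer, 15 s chunk.\n",
    "audio_len_secs, buffer_secs, chunk_secs = 3600, 20, 15\n",
    "n_chunks = math.ceil(audio_len_secs / chunk_secs)\n",
    "context_secs = (buffer_secs - chunk_secs) / 2\n",
    "print(f'{n_chunks} chunks, {context_secs} s of context on each side of a chunk')\n",
    "\n",
    "# Toy signal (hypothetical sizes): 8 samples, chunks of 2 samples, buffer of 4 samples.\n",
    "toy_samples = np.arange(1, 9)\n",
    "toy_chunk_len, toy_buffer_len = 2, 4\n",
    "toy_buffer = np.zeros(toy_buffer_len, dtype=toy_samples.dtype)\n",
    "for start in range(0, len(toy_samples), toy_chunk_len):\n",
    "    toy_chunk = toy_samples[start:start + toy_chunk_len]\n",
    "    # drop the oldest samples from the front and append the new chunk at the end\n",
    "    toy_buffer[:-toy_chunk_len] = toy_buffer[toy_chunk_len:]\n",
    "    toy_buffer[-toy_chunk_len:] = toy_chunk\n",
    "    print(toy_buffer)"
   ]
  },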
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "# A simple iterator class to return successive chunks of samples\n",
    "class AudioChunkIterator():\n",
    "    def __init__(self, samples, frame_len, sample_rate):\n",
    "        self._samples = samples\n",
    "        # frame_len is the chunk length in seconds\n",
    "        self._chunk_len = frame_len*sample_rate\n",
    "        self._start = 0\n",
    "        self.output=True\n",
    "   \n",
    "    def __iter__(self):\n",
    "        return self\n",
    "    \n",
    "    def __next__(self):\n",
    "        if not self.output:\n",
    "            raise StopIteration\n",
    "        last = int(self._start + self._chunk_len)\n",
    "        if last <= len(self._samples):\n",
    "            chunk = self._samples[self._start: last]\n",
    "            self._start = last\n",
    "        else:\n",
    "            chunk = np.zeros([int(self._chunk_len)], dtype='float32')\n",
    "            samp_len = len(self._samples) - self._start\n",
    "            chunk[0:samp_len] = self._samples[self._start:len(self._samples)]\n",
    "            self.output = False\n",
    "   \n",
    "        return chunk"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# a helper function for extracting samples as a numpy array from the audio file\n",
    "import librosa\n",
    "import soundfile as sf\n",
    "def get_samples(audio_file, target_sr=16000):\n",
    "    with sf.SoundFile(audio_file, 'r') as f:\n",
    "        sample_rate = f.samplerate\n",
    "        samples = f.read()\n",
    "        # put time on the last axis before resampling (no-op for mono audio)\n",
    "        samples = samples.transpose()\n",
    "        if sample_rate != target_sr:\n",
    "            samples = librosa.core.resample(samples, orig_sr=sample_rate, target_sr=target_sr)\n",
    "        return samples\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's take a look at each chunk of speech that is used for decoding."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "samples = get_samples(concat_audio_path)\n",
    "sample_rate  = model.preprocessor._cfg['sample_rate'] \n",
    "chunk_len_in_secs = 1            \n",
    "chunk_reader = AudioChunkIterator(samples, chunk_len_in_secs, sample_rate)\n",
    "count = 0\n",
    "for chunk in chunk_reader:\n",
    "    count +=1\n",
    "    plt.plot(chunk)\n",
    "    plt.show()\n",
    "    if count >= 5:\n",
    "        break\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now, let's plot the actual buffers at each stage after a new chunk is added to the buffer. The audio buffer can be thought of as a fixed-size queue, with each incoming chunk added at the end of the buffer and the oldest samples removed from the beginning."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "context_len_in_secs = 1\n",
    "\n",
    "buffer_len_in_secs = chunk_len_in_secs + 2* context_len_in_secs\n",
    "\n",
    "buffer_len = sample_rate*buffer_len_in_secs\n",
    "sampbuffer = np.zeros([buffer_len], dtype=np.float32)\n",
    "\n",
    "chunk_reader = AudioChunkIterator(samples, chunk_len_in_secs, sample_rate)\n",
    "chunk_len = sample_rate*chunk_len_in_secs\n",
    "count = 0\n",
    "for chunk in chunk_reader:\n",
    "    count +=1\n",
    "    sampbuffer[:-chunk_len] = sampbuffer[chunk_len:]\n",
    "    sampbuffer[-chunk_len:] = chunk\n",
    "    plt.plot(sampbuffer)\n",
    "    plt.show()\n",
    "    if count >= 5:\n",
    "        break"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now that we have a method to split the long audio into smaller chunks, we can work on transcribing the individual buffers and merging the outputs to get the transcription of the whole audio.\n",
    "First, we implement some helper functions to load the buffers into the data layer."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from nemo.core.classes import IterableDataset\n",
    "\n",
    "def speech_collate_fn(batch):\n",
    "    \"\"\"collate batch of audio sig, audio len\n",
    "    Args:\n",
    "        batch (FloatTensor, LongTensor):  A tuple of tuples of signal, signal lengths.\n",
    "        This collate func assumes the signals are 1d torch tensors (i.e. mono audio).\n",
    "    \"\"\"\n",
    "\n",
    "    _, audio_lengths = zip(*batch)\n",
    "\n",
    "    max_audio_len = 0\n",
    "    has_audio = audio_lengths[0] is not None\n",
    "    if has_audio:\n",
    "        max_audio_len = max(audio_lengths).item()\n",
    "   \n",
    "    \n",
    "    audio_signal= []\n",
    "    for sig, sig_len in batch:\n",
    "        if has_audio:\n",
    "            sig_len = sig_len.item()\n",
    "            if sig_len < max_audio_len:\n",
    "                pad = (0, max_audio_len - sig_len)\n",
    "                sig = torch.nn.functional.pad(sig, pad)\n",
    "            audio_signal.append(sig)\n",
    "        \n",
    "    if has_audio:\n",
    "        audio_signal = torch.stack(audio_signal)\n",
    "        audio_lengths = torch.stack(audio_lengths)\n",
    "    else:\n",
    "        audio_signal, audio_lengths = None, None\n",
    "\n",
    "    return audio_signal, audio_lengths\n",
    "\n",
    "# simple data layer to pass audio signal\n",
    "class AudioBuffersDataLayer(IterableDataset):\n",
    "    \n",
    "\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "\n",
    "        \n",
    "    def __iter__(self):\n",
    "        return self\n",
    "    \n",
    "    def __next__(self):\n",
    "        if self._buf_count == len(self.signal) :\n",
    "            raise StopIteration\n",
    "        self._buf_count +=1\n",
    "        return torch.as_tensor(self.signal[self._buf_count-1], dtype=torch.float32), \\\n",
    "               torch.as_tensor(self.signal_shape[0], dtype=torch.int64)\n",
    "        \n",
    "    def set_signal(self, signals):\n",
    "        self.signal = signals\n",
    "        self.signal_shape = self.signal[0].shape\n",
    "        self._buf_count = 0\n",
    "\n",
    "    def __len__(self):\n",
    "        return 1\n",
    "    "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Next, we implement a class that transcribes audio buffers and merges the tokens corresponding to a chunk of audio within each buffer.\n",
    "\n",
    "For each buffer, we pick tokens corresponding to one chunk length of audio. The chunk within each buffer is chosen such that there is equal left and right context available to the audio within the chunk.\n",
    "\n",
    "For example, if the chunk size is 1s and buffer size is 3s, we collect tokens corresponding to audio starting from 1s to 2s within each buffer. Conformer-CTC models have a model stride of 4, i.e., a token is produced for every 4 feature vectors in the time domain. MelSpectrogram features are generated once every 10 ms, so a token is produced for every 40 ms of audio.\n",
    "\n",
    "**Note:** The inherent assumption here is that the output tokens from the model are well aligned with corresponding audio segments. This may not always be true for models trained with CTC loss, so this method of streaming inference may not always work with CTC based models. "
   ]
  },
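  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The arithmetic in the example above can be checked with a few lines of Python. This is only a sketch using the values quoted in the text (10 ms feature stride, model stride of 4, 1s chunk, 3s buffer). The number of tokens per buffer is an idealized assumption: the real model may emit a slightly different number of frames because of padding at the edges."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "\n",
    "# Values from the text: 10 ms feature stride, model stride of 4, 1 s chunk, 3 s buffer.\n",
    "feature_stride_secs = 0.01\n",
    "model_stride = 4\n",
    "chunk_secs, buffer_secs = 1, 3\n",
    "\n",
    "token_secs = feature_stride_secs * model_stride\n",
    "# Idealized assumption: exactly one token per 40 ms, no extra frames from padding.\n",
    "tokens_per_buffer = round(buffer_secs / token_secs)\n",
    "tokens_per_chunk = math.ceil(chunk_secs / token_secs)\n",
    "# Number of tokens between the start of the middle chunk and the end of the buffer.\n",
    "delay = math.ceil((chunk_secs + (buffer_secs - chunk_secs) / 2) / token_secs)\n",
    "\n",
    "first = tokens_per_buffer - 1 - delay\n",
    "print(f'{token_secs * 1000:.0f} ms per token, about {tokens_per_buffer} tokens per buffer')\n",
    "print(f'keep token indices [{first}, {first + tokens_per_chunk}) from each buffer')"
   ]
  },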
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch.utils.data import DataLoader\n",
    "import math\n",
    "class ChunkBufferDecoder:\n",
    "\n",
    "    def __init__(self,asr_model, stride, chunk_len_in_secs=1, buffer_len_in_secs=3):\n",
    "        self.asr_model = asr_model\n",
    "        self.asr_model.eval()\n",
    "        self.data_layer = AudioBuffersDataLayer()\n",
    "        self.data_loader = DataLoader(self.data_layer, batch_size=1, collate_fn=speech_collate_fn)\n",
    "        self.buffers = []\n",
    "        self.all_preds = []\n",
    "        self.chunk_len = chunk_len_in_secs\n",
    "        self.buffer_len = buffer_len_in_secs\n",
    "        assert(chunk_len_in_secs<=buffer_len_in_secs)\n",
    "        \n",
    "        feature_stride = asr_model._cfg.preprocessor['window_stride']\n",
    "        self.model_stride_in_secs = feature_stride * stride\n",
    "        self.n_tokens_per_chunk = math.ceil(self.chunk_len / self.model_stride_in_secs)\n",
    "        self.blank_id = len(asr_model.decoder.vocabulary)\n",
    "        self.plot=False\n",
    "        \n",
    "    @torch.no_grad()    \n",
    "    def transcribe_buffers(self, buffers, merge=True, plot=False):\n",
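    "        self.plot = plot\n",
    "        self.buffers = buffers\n",
    "        # reset predictions so that repeated calls do not accumulate results\n",
    "        self.all_preds = []\n",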
    "        self.data_layer.set_signal(buffers[:])\n",
    "        self._get_batch_preds()      \n",
    "        return self.decode_final(merge)\n",
    "    \n",
    "    def _get_batch_preds(self):\n",
    "\n",
    "        device = self.asr_model.device\n",
    "        for batch in iter(self.data_loader):\n",
    "\n",
    "            audio_signal, audio_signal_len = batch\n",
    "\n",
    "            audio_signal, audio_signal_len = audio_signal.to(device), audio_signal_len.to(device)\n",
    "            log_probs, encoded_len, predictions = self.asr_model(input_signal=audio_signal, input_signal_length=audio_signal_len)\n",
    "            preds = torch.unbind(predictions)\n",
    "            for pred in preds:\n",
    "                self.all_preds.append(pred.cpu().numpy())\n",
    "    \n",
    "    def decode_final(self, merge=True, extra=0):\n",
    "        self.unmerged = []\n",
    "        self.toks_unmerged = []\n",
    "        # index for the first token corresponding to a chunk of audio would be len(decoded) - 1 - delay\n",
    "        delay = math.ceil((self.chunk_len + (self.buffer_len - self.chunk_len) / 2) / self.model_stride_in_secs)\n",
    "\n",
    "        decoded_frames = []\n",
    "        all_toks = []\n",
    "        for pred in self.all_preds:\n",
    "            ids, toks = self._greedy_decoder(pred, self.asr_model.tokenizer)\n",
    "            decoded_frames.append(ids)\n",
    "            all_toks.append(toks)\n",
    "\n",
    "        for decoded in decoded_frames:\n",
    "            self.unmerged += decoded[len(decoded) - 1 - delay:len(decoded) - 1 - delay + self.n_tokens_per_chunk]\n",
    "        if self.plot:\n",
    "            for i, tok in enumerate(all_toks):\n",
    "                plt.plot(self.buffers[i])\n",
    "                plt.show()\n",
    "                print(\"\\nGreedy labels collected from this buffer\")\n",
    "                print(tok[len(tok) - 1 - delay:len(tok) - 1 - delay + self.n_tokens_per_chunk])                \n",
    "                self.toks_unmerged += tok[len(tok) - 1 - delay:len(tok) - 1 - delay + self.n_tokens_per_chunk]\n",
    "            print(\"\\nTokens collected from successive buffers before CTC merge\")\n",
    "            print(self.toks_unmerged)\n",
    "\n",
    "\n",
    "        if not merge:\n",
    "            return self.unmerged\n",
    "        return self.greedy_merge(self.unmerged)\n",
    "    \n",
    "    \n",
    "    def _greedy_decoder(self, preds, tokenizer):\n",
    "        s = []\n",
    "        ids = []\n",
    "        for i in range(preds.shape[0]):\n",
    "            if preds[i] == self.blank_id:\n",
    "                s.append(\"_\")\n",
    "            else:\n",
    "                pred = preds[i]\n",
    "                s.append(tokenizer.ids_to_tokens([pred.item()])[0])\n",
    "            ids.append(preds[i])\n",
    "        return ids, s\n",
    "         \n",
    "    def greedy_merge(self, preds):\n",
    "        decoded_prediction = []\n",
    "        previous = self.blank_id\n",
    "        for p in preds:\n",
    "            if (p != previous or previous == self.blank_id) and p != self.blank_id:\n",
    "                decoded_prediction.append(p.item())\n",
    "            previous = p\n",
    "        hypothesis = self.asr_model.tokenizer.ids_to_text(decoded_prediction)\n",
    "        return hypothesis\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To see how this chunk-based decoder comes together, let's call the decoder with a few buffers that we create from our long audio file.\n",
    "An interesting experiment to try is to see how changing the chunk and context sizes affects transcription accuracy."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "chunk_len_in_secs = 4\n",
    "context_len_in_secs = 2\n",
    "\n",
    "buffer_len_in_secs = chunk_len_in_secs + 2* context_len_in_secs\n",
    "\n",
    "n_buffers = 5\n",
    "\n",
    "buffer_len = sample_rate*buffer_len_in_secs\n",
    "sampbuffer = np.zeros([buffer_len], dtype=np.float32)\n",
    "\n",
    "chunk_reader = AudioChunkIterator(samples, chunk_len_in_secs, sample_rate)\n",
    "chunk_len = sample_rate*chunk_len_in_secs\n",
    "count = 0\n",
    "buffer_list = []\n",
    "for chunk in chunk_reader:\n",
    "    count +=1\n",
    "    sampbuffer[:-chunk_len] = sampbuffer[chunk_len:]\n",
    "    sampbuffer[-chunk_len:] = chunk\n",
    "    buffer_list.append(np.array(sampbuffer))\n",
    "   \n",
    "    if count >= n_buffers:\n",
    "        break\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "stride = 4 # 8 for Citrinet\n",
    "asr_decoder = ChunkBufferDecoder(model, stride=stride, chunk_len_in_secs=chunk_len_in_secs, buffer_len_in_secs=buffer_len_in_secs )\n",
    "transcription = asr_decoder.transcribe_buffers(buffer_list, plot=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Final transcription after CTC merge\n",
    "print(transcription)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Time to evaluate our streaming inference on the whole long file that we created."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# WER calculation\n",
    "from nemo.collections.asr.metrics.wer import word_error_rate\n",
    "# Collect all buffers from the audio file\n",
    "sampbuffer = np.zeros([buffer_len], dtype=np.float32)\n",
    "\n",
    "chunk_reader = AudioChunkIterator(samples, chunk_len_in_secs, sample_rate)\n",
    "buffer_list = []\n",
    "for chunk in chunk_reader:\n",
    "    sampbuffer[:-chunk_len] = sampbuffer[chunk_len:]\n",
    "    sampbuffer[-chunk_len:] = chunk\n",
    "    buffer_list.append(np.array(sampbuffer))\n",
    "\n",
    "asr_decoder = ChunkBufferDecoder(model, stride=stride, chunk_len_in_secs=chunk_len_in_secs, buffer_len_in_secs=buffer_len_in_secs )\n",
    "transcription = asr_decoder.transcribe_buffers(buffer_list, plot=False)\n",
    "wer = word_error_rate(hypotheses=[transcription], references=[ref_transcript])\n",
    "\n",
    "print(f\"WER: {round(wer*100,2)}%\")"
   ]
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "collapsed_sections": [],
   "name": "ASR_with_NeMo.ipynb",
   "provenance": [],
   "toc_visible": true
  },
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.7"
  },
  "pycharm": {
   "stem_cell": {
    "cell_type": "raw",
    "metadata": {
     "collapsed": false
    },
    "source": []
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}