piecurus committed
Commit 17ded20 · 1 Parent(s): cab3a0b

fixed bug on max output text len + adding GPU inference

Files changed (1)
app.py +10 -9
app.py CHANGED
@@ -12,6 +12,7 @@ import torch
 import gradio as gr
 from transformers import Wav2Vec2Tokenizer, Wav2Vec2ForCTC
 nltk.download("punkt")
+torch_device = 'cuda'
 
 
 # In[ ]:
@@ -22,7 +23,7 @@ model_name = "facebook/wav2vec2-base-960h"
 
 #model_name = "facebook/wav2vec2-large-xlsr-53"
 tokenizer = Wav2Vec2Tokenizer.from_pretrained(model_name)#omdel_name
-model = Wav2Vec2ForCTC.from_pretrained(model_name)
+model = Wav2Vec2ForCTC.from_pretrained(model_name).to(torch_device)
 
 
 # In[ ]:
@@ -81,10 +82,10 @@ def asr_transcript_long(input_file,tokenizer=tokenizer, model=model ):
 # Ensure that the sample rate is 16k
 sample_rate = librosa.get_samplerate(input_file)
 
-# Stream over 30 seconds chunks rather than load the full file
+# Stream over 10 seconds chunks rather than load the full file
 stream = librosa.stream(
 input_file,
-block_length=30,
+block_length=20, #number of seconds to split the batch
 frame_length=sample_rate, #16000,
 hop_length=sample_rate, #16000
 )
@@ -95,15 +96,15 @@ def asr_transcript_long(input_file,tokenizer=tokenizer, model=model ):
 if sample_rate !=16000:
 speech = librosa.resample(speech, sample_rate,16000)
 input_values = tokenizer(speech, return_tensors="pt").input_values
-logits = model(input_values).logits
+logits = model(input_values.to(torch_device)).logits
 
 predicted_ids = torch.argmax(logits, dim=-1)
 transcription = tokenizer.decode(predicted_ids[0])
-#transcript += correct_sentence(transcription.lower())
+#transcript += transcription.lower()
 transcript += correct_casing(transcription.lower())
-transcript += " "
+#transcript += " "
 
-return transcript
+return transcript[:4300]
 
 
 # In[ ]:
@@ -112,8 +113,8 @@ def asr_transcript_long(input_file,tokenizer=tokenizer, model=model ):
 gr.Interface(asr_transcript_long,
 #inputs = gr.inputs.Audio(source="microphone", type="filepath", optional=True, label="Please record your voice"),
 inputs = gr.inputs.Audio(source="upload", type="filepath", optional=True, label="Upload your file here"),
-outputs = gr.outputs.Textbox(label="Output Text"),
-title="ASR using Wav2Vec 2.0",
+outputs = gr.outputs.Textbox(type="str",label="Output Text"),
+title="Transcript and Translate",
 description = "This application displays transcribed text for given audio input",
 examples = [["Test_File1.wav"], ["Test_File2.wav"], ["Test_File3.wav"]], theme="grass").launch()
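For context, below is a minimal sketch of the pattern this commit applies: load the Wav2Vec2 model once onto the GPU, stream the audio file in ~20-second blocks, run each block through the model on that device, and cap the returned transcript length. It is not the app's exact code: the `TORCH_DEVICE` fallback via `torch.cuda.is_available()` and the named `MAX_TRANSCRIPT_LEN` constant are assumptions added for illustration (the commit hardcodes `'cuda'` and the literal `4300`), and the app's `correct_casing` helper is replaced here by a plain lowercase of each decoded chunk.

```python
# Hedged sketch of the streamed GPU transcription flow in app.py.
# Assumptions (not from the commit): CPU fallback when CUDA is unavailable,
# a named MAX_TRANSCRIPT_LEN constant, and no correct_casing post-processing.
import librosa
import torch
from transformers import Wav2Vec2Tokenizer, Wav2Vec2ForCTC

MODEL_NAME = "facebook/wav2vec2-base-960h"
MAX_TRANSCRIPT_LEN = 4300  # mirrors the transcript[:4300] cap added by the commit
TORCH_DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

tokenizer = Wav2Vec2Tokenizer.from_pretrained(MODEL_NAME)
model = Wav2Vec2ForCTC.from_pretrained(MODEL_NAME).to(TORCH_DEVICE)

def transcribe_long(input_file: str) -> str:
    """Transcribe a long audio file by streaming it in ~20 s blocks."""
    sample_rate = librosa.get_samplerate(input_file)
    stream = librosa.stream(
        input_file,
        block_length=20,          # frames per block; ~20 s with 1 s frames
        frame_length=sample_rate,  # 1 second of samples per frame
        hop_length=sample_rate,
    )
    transcript = ""
    for speech in stream:
        if sample_rate != 16000:  # Wav2Vec2 expects 16 kHz input
            speech = librosa.resample(speech, orig_sr=sample_rate, target_sr=16000)
        input_values = tokenizer(speech, return_tensors="pt").input_values
        with torch.no_grad():     # inference only, no gradients needed
            logits = model(input_values.to(TORCH_DEVICE)).logits
        predicted_ids = torch.argmax(logits, dim=-1)
        transcript += tokenizer.decode(predicted_ids[0]).lower() + " "
    return transcript[:MAX_TRANSCRIPT_LEN]
```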