GavinHuang committed
Commit 2647bd6 · 1 Parent(s): 895c600

fix: enhance model loading and selection in transcribe function for improved user experience

Files changed (1)
1. app.py +62 -21
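The change boils down to one pattern: keep a single global model, remember which name it was loaded from, and reload only when a different name is requested. Below is a minimal sketch of that pattern (not part of the commit) using the same names as the diff; the NeMo call nemo_asr.models.EncDecRNNTBPEModel.from_pretrained(model_name) is stubbed with a placeholder string so the snippet runs without NeMo installed.

# Sketch of the caching-and-selection pattern introduced by this commit.
# Names (model, current_model_name, load_model) follow the diff; the loader
# body is a stand-in, not the real NeMo from_pretrained() call.
model = None
current_model_name = "nvidia/parakeet-tdt-0.6b-v2"

def load_model(model_name=None):
    """Return the cached model, loading a new one only when the name changes."""
    global model, current_model_name
    model_name = model_name or current_model_name
    if model is None or model_name != current_model_name:
        print(f"Loading model {model_name}")
        current_model_name = model_name
        model = f"<loaded: {model_name}>"  # placeholder for from_pretrained(model_name)
    return model

print(load_model())                               # first call loads the model
print(load_model("nvidia/parakeet-tdt-0.6b-v2"))  # same name: cached object is reused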
app.py CHANGED
@@ -10,28 +10,42 @@ import librosa
 # Important: Don't initialize CUDA in the main process for Spaces
 # The model will be loaded in the worker process through the GPU decorator
 model = None
+current_model_name = "nvidia/parakeet-tdt-0.6b-v2"
 
-def load_model():
+# Available models
+available_models = ["nvidia/parakeet-tdt-0.6b-v2"]
+
+def load_model(model_name=None):
     # This function will be called in the GPU worker process
-    global model
-    if model is None:
-        print(f"Loading model in worker process")
+    global model, current_model_name
+
+    # Use the specified model name or the current one
+    model_name = model_name or current_model_name
+
+    # Check if we need to load a new model
+    if model is None or model_name != current_model_name:
+        print(f"Loading model {model_name} in worker process")
         print(f"CUDA available: {torch.cuda.is_available()}")
         if torch.cuda.is_available():
             print(f"CUDA device: {torch.cuda.get_device_name(0)}")
-        model = nemo_asr.models.EncDecRNNTBPEModel.from_pretrained("nvidia/parakeet-tdt-0.6b-v2")
+
+        # Update the current model name
+        current_model_name = model_name
+
+        # Load the selected model
+        model = nemo_asr.models.EncDecRNNTBPEModel.from_pretrained(model_name)
         print(f"Model loaded on device: {model.device}")
+
     return model
 
 @spaces.GPU(duration=120)
-def transcribe(audio, state="", audio_buffer=None, last_processed_time=0):
+def transcribe(audio, model_name="nvidia/parakeet-tdt-0.6b-v2", state="", audio_buffer=None, last_processed_time=0):
     # Load the model inside the GPU worker process
     import numpy as np
     import soundfile as sf
     import librosa
     import os
-    model = load_model()
-
+    model = load_model(model_name)
     if audio_buffer is None:
         audio_buffer = []
 
@@ -129,7 +143,22 @@ def transcribe(audio, state="", audio_buffer=None, last_processed_time=0):
 # Define the Gradio interface
 with gr.Blocks(title="Real-time Speech-to-Text with NeMo") as demo:
     gr.Markdown("# 🎙️ Real-time Speech-to-Text Transcription")
-    gr.Markdown("Powered by NVIDIA NeMo and the parakeet-tdt-0.6b-v2 model")
+    gr.Markdown("Powered by NVIDIA NeMo")
+
+    # Model selection and loading
+    with gr.Row():
+        with gr.Column(scale=3):
+            model_dropdown = gr.Dropdown(
+                choices=available_models,
+                value=current_model_name,
+                label="Select ASR Model"
+            )
+        with gr.Column(scale=1):
+            load_button = gr.Button("Load Selected Model")
+
+    # Status indicator for model loading
+    model_status = gr.Textbox(value=f"Current model: {current_model_name}", label="Model Status")
+
     with gr.Row():
         with gr.Column(scale=2):
             audio_input = gr.Audio(
@@ -152,19 +181,30 @@ with gr.Blocks(title="Real-time Speech-to-Text with NeMo") as demo:
                 placeholder="Real-time results will appear here...",
                 lines=2
             )
-
-    # State to store the ongoing transcription
+    # State to store the ongoing transcription
     state = gr.State("")
    audio_buffer = gr.State(value=None)
     last_processed_time = gr.State(value=0)
+
+    # Function to handle model selection
+    def update_model(model_name):
+        global current_model_name
+        current_model_name = model_name
+        return f"Current model: {model_name}", None, 0 # Reset audio buffer and last processed time
+
+    # Load model button event
+    load_button.click(
+        fn=update_model,
+        inputs=[model_dropdown],
+        outputs=[model_status, audio_buffer, last_processed_time]
+    )
+
     # Handle the audio stream
     audio_input.stream(
         fn=transcribe,
-        inputs=[audio_input, state, audio_buffer, last_processed_time],
+        inputs=[audio_input, model_dropdown, state, audio_buffer, last_processed_time],
         outputs=[state, streaming_text, audio_buffer, last_processed_time],
-    )
-
-    # Clear the transcription
+    ) # Clear the transcription
     def clear_transcription():
         return "", "", None, 0
 
@@ -180,13 +220,14 @@ with gr.Blocks(title="Real-time Speech-to-Text with NeMo") as demo:
         inputs=[state],
         outputs=[text_output]
     )
-
-    gr.Markdown("## 📝 Instructions")
+    gr.Markdown("## 📝 Instructions")
     gr.Markdown("""
-    1. Click the microphone button to start recording
-    2. Speak clearly into your microphone
-    3. The transcription will appear in real-time
-    4. Click 'Clear Transcript' to start a new transcription
+    1. Select an ASR model from the dropdown menu
+    2. Click 'Load Selected Model' to load the model
+    3. Click the microphone button to start recording
+    4. Speak clearly into your microphone
+    5. The transcription will appear in real-time
+    6. Click 'Clear Transcript' to start a new transcription
     """)
 
 # Launch the app
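A hypothetical smoke test (not part of this commit) for the new transcribe signature is sketched below. It assumes app.py is importable as a module named app, that Gradio's streaming gr.Audio delivers (sample_rate, numpy_array) chunks, and that transcribe returns the four values wired to outputs in the diff; the silent test chunk is purely illustrative.

# Usage sketch under the assumptions above: exercise the model_name parameter
# that this commit threads from the dropdown into the GPU worker.
import numpy as np
from app import transcribe  # assumes app.py is on the import path

sr = 16000
chunk = (sr, np.zeros(sr, dtype=np.float32))  # one second of silence as a stand-in stream chunk

state, partial_text, buffer, last_t = transcribe(
    chunk,
    model_name="nvidia/parakeet-tdt-0.6b-v2",  # value the dropdown would pass
    state="",
    audio_buffer=None,
    last_processed_time=0,
)
print(partial_text)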