GavinHuang committed on
Commit
f334b99
·
1 Parent(s): 4efbce4

refactor model loading and reintroduce GPU decorator for transcription function

Browse files
Files changed (1) hide show
  1. app.py +10 -14
app.py CHANGED
@@ -4,27 +4,18 @@ import torch
4
  import nemo.collections.asr as nemo_asr
5
  from omegaconf import OmegaConf
6
  import time
 
7
 
8
  # Check if CUDA is available
9
  print(f"CUDA available: {torch.cuda.is_available()}")
10
  if torch.cuda.is_available():
11
  print(f"CUDA device: {torch.cuda.get_device_name(0)}")
12
 
13
- # Initialize the ASR model - removed spaces.GPU decorator due to pickling issues
14
- def load_model():
15
- print("Loading ASR model...")
16
- # Load the NVIDIA NeMo ASR model
17
- model = nemo_asr.models.EncDecRNNTBPEModel.from_pretrained("nvidia/parakeet-tdt-0.6b-v2")
18
- # Move model to GPU if available
19
- if torch.cuda.is_available():
20
- print(f"CUDA device: {torch.cuda.get_device_name(0)}")
21
- model = model.cuda()
22
- print(f"Model loaded on device: {model.device}")
23
- return model
24
 
25
- # Global variable to store the model
26
- model = load_model()
27
 
 
28
  def transcribe(audio, state=""):
29
  """
30
  Transcribe audio in real-time
@@ -33,6 +24,11 @@ def transcribe(audio, state=""):
33
  if audio is None:
34
  return state, state
35
 
 
 
 
 
 
36
  # Get the sample rate from the audio
37
  sample_rate = 16000 # Default to 16kHz if not specified
38
 
@@ -45,7 +41,7 @@ def transcribe(audio, state=""):
45
  new_state = transcription
46
  else:
47
  new_state = state + " " + transcription
48
-
49
  return new_state, new_state
50
 
51
  # Define the Gradio interface
 
4
  import nemo.collections.asr as nemo_asr
5
  from omegaconf import OmegaConf
6
  import time
7
+ import spaces
8
 
9
  # Check if CUDA is available
10
  print(f"CUDA available: {torch.cuda.is_available()}")
11
  if torch.cuda.is_available():
12
  print(f"CUDA device: {torch.cuda.get_device_name(0)}")
13
 
14
+ model = nemo_asr.models.EncDecRNNTBPEModel.from_pretrained("nvidia/parakeet-tdt-0.6b-v2")
 
 
 
 
 
 
 
 
 
 
15
 
16
+ print(f"Model loaded on device: {model.device}")
 
17
 
18
+ @spaces.GPU(duration=120) # Increase duration if inference takes >60s
19
  def transcribe(audio, state=""):
20
  """
21
  Transcribe audio in real-time
 
24
  if audio is None:
25
  return state, state
26
 
27
+ # Move model to GPU if available
28
+ if torch.cuda.is_available():
29
+ print(f"CUDA device: {torch.cuda.get_device_name(0)}")
30
+ model = model.cuda()
31
+
32
  # Get the sample rate from the audio
33
  sample_rate = 16000 # Default to 16kHz if not specified
34
 
 
41
  new_state = transcription
42
  else:
43
  new_state = state + " " + transcription
44
+ model.cpu()
45
  return new_state, new_state
46
 
47
  # Define the Gradio interface