GavinHuang committed
Commit a2ed037 · 1 Parent(s): 40ab795

fix: refactor model loading and enhance error handling in transcribe function

Files changed (1): app.py (+40 -10)
app.py CHANGED
@@ -6,14 +6,22 @@ from omegaconf import OmegaConf
 import time
 import spaces
 import librosa
-# Check if CUDA is available
-print(f"CUDA available: {torch.cuda.is_available()}")
-if torch.cuda.is_available():
-    print(f"CUDA device: {torch.cuda.get_device_name(0)}")
 
-model = nemo_asr.models.EncDecRNNTBPEModel.from_pretrained("nvidia/parakeet-tdt-0.6b-v2")
+# Important: Don't initialize CUDA in the main process for Spaces
+# The model will be loaded in the worker process through the GPU decorator
+model = None
 
-print(f"Model loaded on device: {model.device}")
+def load_model():
+    # This function will be called in the GPU worker process
+    global model
+    if model is None:
+        print(f"Loading model in worker process")
+        print(f"CUDA available: {torch.cuda.is_available()}")
+        if torch.cuda.is_available():
+            print(f"CUDA device: {torch.cuda.get_device_name(0)}")
+        model = nemo_asr.models.EncDecRNNTBPEModel.from_pretrained("nvidia/parakeet-tdt-0.6b-v2")
+        print(f"Model loaded on device: {model.device}")
+    return model
 
 import numpy as np
 import soundfile as sf
@@ -21,7 +29,9 @@ audio_buffer = []
 
 @spaces.GPU(duration=120)
 def transcribe(audio, state=""):
-    global model, audio_buffer
+    # Load the model inside the GPU worker process
+    model = load_model()
+
     if audio is None or isinstance(audio, int):
         print(f"Skipping invalid audio input: {type(audio)}")
         return state, state
@@ -38,8 +48,7 @@ def transcribe(audio, state=""):
         # Handle tuple of (sample_rate, audio_array)
         print(f"Tuple contents: {audio}")
         sample_rate, audio_data = audio
-        try:
-            # Resample to 16kHz for NeMo
+        try:  # Resample to 16kHz for NeMo
             if sample_rate != 16000:
                 print(f"Resampling from {sample_rate}Hz to 16000Hz")
                 audio_data = librosa.resample(audio_data.astype(float), orig_sr=sample_rate, target_sr=16000)
@@ -47,7 +56,28 @@ def transcribe(audio, state=""):
             temp_file = "temp_audio.wav"
             sf.write(temp_file, audio_data, samplerate=16000)
             print(f"Processing temporary audio file: {temp_file}")
-            transcription = model.transcribe([temp_file])[0]
+
+            # Handling NumPy 2.0 compatibility issue
+            try:
+                transcription = model.transcribe([temp_file])[0]
+            except AttributeError as e:
+                if "np.sctypes" in str(e):
+                    print("Handling NumPy 2.0 compatibility issue")
+                    # Using a workaround to handle the np.sctypes removal
+                    import numpy as np
+                    # Create a temporary sctypes attribute if needed by older code
+                    if not hasattr(np, 'sctypes'):
+                        np.sctypes = {
+                            'int': [np.int8, np.int16, np.int32, np.int64],
+                            'uint': [np.uint8, np.uint16, np.uint32, np.uint64],
+                            'float': [np.float16, np.float32, np.float64],
+                            'complex': [np.complex64, np.complex128]
+                        }
+                    # Try again
+                    transcription = model.transcribe([temp_file])[0]
+                else:
+                    raise
+
             os.remove(temp_file)  # Clean up
             print("Temporary file removed.")
         except Exception as e:
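
The heart of the refactor is the lazy-loading pattern: the module scope stays CUDA-free, and the model is created on first use inside the @spaces.GPU-decorated function so ZeroGPU can attach the device in the worker process. A minimal sketch of that pattern in isolation (the transcribe_file wrapper and its signature are illustrative, not part of this commit):

import spaces
import nemo.collections.asr as nemo_asr

model = None  # no CUDA work at import time in the main process

def load_model():
    # Runs inside the GPU worker; loads the checkpoint at most once per process
    global model
    if model is None:
        model = nemo_asr.models.EncDecRNNTBPEModel.from_pretrained(
            "nvidia/parakeet-tdt-0.6b-v2"
        )
    return model

@spaces.GPU(duration=120)
def transcribe_file(path):
    # Illustrative wrapper: the first call pays the load cost,
    # later calls in the same worker reuse the cached model
    asr = load_model()
    return asr.transcribe([path])[0]

The np.sctypes fallback in the last hunk could equally be installed once at import time; catching the AttributeError lazily, as the commit does, keeps the shim inert on environments that still ship NumPy < 2.0.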