Spaces:
Running
on
Zero
Commit
·
a2ed037
1
Parent(s):
40ab795
fix: refactor model loading and enhance error handling in transcribe function
Browse files
app.py
CHANGED
@@ -6,14 +6,22 @@ from omegaconf import OmegaConf
|
|
6 |
import time
|
7 |
import spaces
|
8 |
import librosa
|
9 |
-
# Check if CUDA is available
|
10 |
-
print(f"CUDA available: {torch.cuda.is_available()}")
|
11 |
-
if torch.cuda.is_available():
|
12 |
-
print(f"CUDA device: {torch.cuda.get_device_name(0)}")
|
13 |
|
14 |
-
|
|
|
|
|
15 |
|
16 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
17 |
|
18 |
import numpy as np
|
19 |
import soundfile as sf
|
@@ -21,7 +29,9 @@ audio_buffer = []
|
|
21 |
|
22 |
@spaces.GPU(duration=120)
|
23 |
def transcribe(audio, state=""):
|
24 |
-
|
|
|
|
|
25 |
if audio is None or isinstance(audio, int):
|
26 |
print(f"Skipping invalid audio input: {type(audio)}")
|
27 |
return state, state
|
@@ -38,8 +48,7 @@ def transcribe(audio, state=""):
|
|
38 |
# Handle tuple of (sample_rate, audio_array)
|
39 |
print(f"Tuple contents: {audio}")
|
40 |
sample_rate, audio_data = audio
|
41 |
-
try:
|
42 |
-
# Resample to 16kHz for NeMo
|
43 |
if sample_rate != 16000:
|
44 |
print(f"Resampling from {sample_rate}Hz to 16000Hz")
|
45 |
audio_data = librosa.resample(audio_data.astype(float), orig_sr=sample_rate, target_sr=16000)
|
@@ -47,7 +56,28 @@ def transcribe(audio, state=""):
|
|
47 |
temp_file = "temp_audio.wav"
|
48 |
sf.write(temp_file, audio_data, samplerate=16000)
|
49 |
print(f"Processing temporary audio file: {temp_file}")
|
50 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
51 |
os.remove(temp_file) # Clean up
|
52 |
print("Temporary file removed.")
|
53 |
except Exception as e:
|
|
|
6 |
import time
|
7 |
import spaces
|
8 |
import librosa
|
|
|
|
|
|
|
|
|
9 |
|
10 |
+
# Important: Don't initialize CUDA in the main process for Spaces
|
11 |
+
# The model will be loaded in the worker process through the GPU decorator
|
12 |
+
model = None
|
13 |
|
14 |
+
def load_model():
|
15 |
+
# This function will be called in the GPU worker process
|
16 |
+
global model
|
17 |
+
if model is None:
|
18 |
+
print(f"Loading model in worker process")
|
19 |
+
print(f"CUDA available: {torch.cuda.is_available()}")
|
20 |
+
if torch.cuda.is_available():
|
21 |
+
print(f"CUDA device: {torch.cuda.get_device_name(0)}")
|
22 |
+
model = nemo_asr.models.EncDecRNNTBPEModel.from_pretrained("nvidia/parakeet-tdt-0.6b-v2")
|
23 |
+
print(f"Model loaded on device: {model.device}")
|
24 |
+
return model
|
25 |
|
26 |
import numpy as np
|
27 |
import soundfile as sf
|
|
|
29 |
|
30 |
@spaces.GPU(duration=120)
|
31 |
def transcribe(audio, state=""):
|
32 |
+
# Load the model inside the GPU worker process
|
33 |
+
model = load_model()
|
34 |
+
|
35 |
if audio is None or isinstance(audio, int):
|
36 |
print(f"Skipping invalid audio input: {type(audio)}")
|
37 |
return state, state
|
|
|
48 |
# Handle tuple of (sample_rate, audio_array)
|
49 |
print(f"Tuple contents: {audio}")
|
50 |
sample_rate, audio_data = audio
|
51 |
+
try: # Resample to 16kHz for NeMo
|
|
|
52 |
if sample_rate != 16000:
|
53 |
print(f"Resampling from {sample_rate}Hz to 16000Hz")
|
54 |
audio_data = librosa.resample(audio_data.astype(float), orig_sr=sample_rate, target_sr=16000)
|
|
|
56 |
temp_file = "temp_audio.wav"
|
57 |
sf.write(temp_file, audio_data, samplerate=16000)
|
58 |
print(f"Processing temporary audio file: {temp_file}")
|
59 |
+
|
60 |
+
# Handling NumPy 2.0 compatibility issue
|
61 |
+
try:
|
62 |
+
transcription = model.transcribe([temp_file])[0]
|
63 |
+
except AttributeError as e:
|
64 |
+
if "np.sctypes" in str(e):
|
65 |
+
print("Handling NumPy 2.0 compatibility issue")
|
66 |
+
# Using a workaround to handle the np.sctypes removal
|
67 |
+
import numpy as np
|
68 |
+
# Create a temporary sctypes attribute if needed by older code
|
69 |
+
if not hasattr(np, 'sctypes'):
|
70 |
+
np.sctypes = {
|
71 |
+
'int': [np.int8, np.int16, np.int32, np.int64],
|
72 |
+
'uint': [np.uint8, np.uint16, np.uint32, np.uint64],
|
73 |
+
'float': [np.float16, np.float32, np.float64],
|
74 |
+
'complex': [np.complex64, np.complex128]
|
75 |
+
}
|
76 |
+
# Try again
|
77 |
+
transcription = model.transcribe([temp_file])[0]
|
78 |
+
else:
|
79 |
+
raise
|
80 |
+
|
81 |
os.remove(temp_file) # Clean up
|
82 |
print("Temporary file removed.")
|
83 |
except Exception as e:
|