# asr-demo / app.py
# Last commit: 692769a (GavinHuang) - fix: move numpy and soundfile imports
#   inside transcribe function and clean up audio buffer handling
# File size: 5.61 kB
import os
import gradio as gr
import torch
import nemo.collections.asr as nemo_asr
from omegaconf import OmegaConf
import time
import spaces
import librosa
# The ASR model is created lazily by load_model(), which is meant to run inside
# the @spaces.GPU worker process. Nothing at import time touches CUDA, so the
# main Spaces process never initialises the GPU.
model = None
def load_model():
    """Return the Parakeet ASR model, loading it on the first call.

    Meant to be called from inside the GPU worker process so that CUDA is only
    initialised there. The loaded model is stored in the module-level
    ``model`` global; later calls return that cached instance without
    reloading.

    Returns:
        The ``EncDecRNNTBPEModel`` loaded from ``nvidia/parakeet-tdt-0.6b-v2``.
    """
    global model
    if model is None:
        print("Loading model in worker process")
        cuda_available = torch.cuda.is_available()
        print(f"CUDA available: {cuda_available}")
        if cuda_available:
            print(f"CUDA device: {torch.cuda.get_device_name(0)}")
        model = nemo_asr.models.EncDecRNNTBPEModel.from_pretrained("nvidia/parakeet-tdt-0.6b-v2")
        print(f"Model loaded on device: {model.device}")
    return model
def _prepare_audio(audio_data, sample_rate, target_sr=16000):
    """Convert one microphone chunk to mono float32 samples at ``target_sr``.

    Signed-integer PCM (Gradio's numpy audio is presumably int16 - verify
    against the installed Gradio version) is rescaled to [-1.0, 1.0]. Floats
    far outside that range do not survive the later write as 16-bit PCM.
    Two-dimensional input is averaged to mono; the layout is assumed to be
    (samples, channels) - TODO confirm. Downmixing first also matters because
    ``librosa.resample`` operates on the last axis.
    """
    import numpy as np

    data = np.asarray(audio_data)
    if np.issubdtype(data.dtype, np.signedinteger):
        # e.g. int16 -> divide by 32768
        scale = float(np.iinfo(data.dtype).max) + 1.0
        data = data.astype(np.float32) / scale
    else:
        # Float input is assumed to already be in [-1, 1]; unsigned PCM is not rescaled.
        data = data.astype(np.float32)
    if data.ndim == 2:
        data = data.mean(axis=1)
    if sample_rate != target_sr:
        print(f"Resampling from {sample_rate}Hz to {target_sr}Hz")
        data = librosa.resample(data, orig_sr=sample_rate, target_sr=target_sr)
    return data


def _ensure_np_sctypes():
    """Re-create ``np.sctypes`` (removed in NumPy 2.0) for dependencies that still read it."""
    import numpy as np

    if not hasattr(np, "sctypes"):
        np.sctypes = {
            'int': [np.int8, np.int16, np.int32, np.int64],
            'uint': [np.uint8, np.uint16, np.uint32, np.uint64],
            'float': [np.float16, np.float32, np.float64],
            'complex': [np.complex64, np.complex128]
        }


def _hypothesis_text(result):
    """Extract the text of the first hypothesis returned by ``model.transcribe``.

    Depending on the NeMo release, ``transcribe`` may return a list of
    strings, a list of ``Hypothesis`` objects exposing ``.text``, or (older
    RNNT models) a ``(best_hypotheses, all_hypotheses)`` tuple of such lists.
    """
    if isinstance(result, tuple):
        result = result[0]
    first = result[0]
    return first.text if hasattr(first, "text") else str(first)


def _transcribe_file(asr_model, path):
    """Transcribe one WAV file, retrying once after the NumPy 2.0 ``np.sctypes`` shim."""
    try:
        result = asr_model.transcribe([path])
    except AttributeError as e:
        if "np.sctypes" not in str(e):
            raise
        print("Handling NumPy 2.0 compatibility issue")
        _ensure_np_sctypes()
        result = asr_model.transcribe([path])
    return _hypothesis_text(result)


@spaces.GPU(duration=120)
def transcribe(audio, state=""):
    """Transcribe one streamed microphone chunk and append it to the transcript.

    Args:
        audio: Value delivered by the streaming audio component. Only a
            ``(sample_rate, np.ndarray)`` tuple is transcribed; ``None``, an
            ``int``, or any other shape leaves the transcript unchanged.
        state: Transcript accumulated so far.

    Returns:
        ``(new_state, new_state)`` with the chunk's text appended. If the chunk
        is empty, yields no text, or processing fails, ``(state, state)`` is
        returned unchanged.
    """
    # Load the model inside the GPU worker process
    import tempfile

    import numpy as np
    import soundfile as sf

    model = load_model()
    if audio is None or isinstance(audio, int):
        print(f"Skipping invalid audio input: {type(audio)}")
        return state, state
    print(f"Received audio input of type: {type(audio)}")
    print(f"Audio shape: {audio.shape if isinstance(audio, np.ndarray) else 'N/A'}")

    if not (isinstance(audio, tuple) and len(audio) == 2 and isinstance(audio[1], np.ndarray)):
        return state, state

    # Handle tuple of (sample_rate, audio_array); log a summary, not the whole array.
    sample_rate, audio_data = audio
    print(f"Received chunk: sample_rate={sample_rate}, shape={audio_data.shape}")
    if audio_data.size == 0:
        return state, state

    temp_file = None
    try:
        audio_data = _prepare_audio(audio_data, sample_rate)
        # Unique name per call: a fixed filename could be overwritten by a concurrent request.
        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
            temp_file = tmp.name
        sf.write(temp_file, audio_data, samplerate=16000)
        print(f"Processing temporary audio file: {temp_file}")
        transcription = _transcribe_file(model, temp_file)
    except Exception as e:
        # Best-effort per chunk: report the failure and keep the transcript as-is.
        print(f"Error processing audio: {e}")
        return state, state
    finally:
        # Clean up even when writing or transcription failed.
        if temp_file is not None and os.path.exists(temp_file):
            os.remove(temp_file)
            print("Temporary file removed.")

    transcription = transcription.strip()
    if not transcription:
        # Silence or an empty hypothesis: avoid appending stray spaces.
        return state, state
    new_state = f"{state} {transcription}" if state else transcription
    return new_state, new_state
# Define the Gradio interface: a streaming microphone input on the left and
# the accumulated transcript on the right.
with gr.Blocks(title="Real-time Speech-to-Text with NeMo") as demo:
    gr.Markdown("# 🎙️ Real-time Speech-to-Text Transcription")
    gr.Markdown("Powered by NVIDIA NeMo and the parakeet-tdt-0.6b-v2 model")
    with gr.Row():
        with gr.Column(scale=2):
            audio_input = gr.Audio(
                sources=["microphone"],
                type="numpy",
                streaming=True,
                label="Speak into your microphone",
            )
            clear_btn = gr.Button("Clear Transcript")
        with gr.Column(scale=3):
            text_output = gr.Textbox(
                label="Transcription",
                placeholder="Your speech will appear here...",
                lines=10,
            )
            streaming_text = gr.Textbox(
                label="Real-time Transcription",
                placeholder="Real-time results will appear here...",
                lines=2,
            )
    # State to store the ongoing transcription across streamed chunks
    state = gr.State("")
    # Each streamed chunk goes to transcribe() together with the current
    # transcript; the returned transcript is written back to the state and
    # shown in the real-time textbox.
    audio_input.stream(
        fn=transcribe,
        inputs=[audio_input, state],
        outputs=[state, streaming_text],
    )

    def clear_transcription():
        """Reset both textboxes and the stored transcript to empty strings."""
        return "", "", ""

    clear_btn.click(
        fn=clear_transcription,
        inputs=[],
        outputs=[text_output, streaming_text, state],
    )
    # Mirror every state change into the main transcript textbox.
    state.change(
        fn=lambda s: s,
        inputs=[state],
        outputs=[text_output],
    )
    gr.Markdown("## 📝 Instructions")
    gr.Markdown("""
1. Click the microphone button to start recording
2. Speak clearly into your microphone
3. The transcription will appear in real-time
4. Click 'Clear Transcript' to start a new transcription
""")
def _launch():
    """Start the Gradio server for the ``demo`` interface."""
    demo.launch()


# Only serve the UI when this file is executed directly, not when imported.
if __name__ == "__main__":
    _launch()