# Source: asr-demo / app.py (4.69 kB)
# Last commit 31b10c8 by GavinHuang:
#   "fix: streamline transcription process by removing NumPy 2.0 compatibility handling"
import os
import gradio as gr
import torch
import nemo.collections.asr as nemo_asr
from omegaconf import OmegaConf
import time
import spaces
import librosa
# Important: Don't initialize CUDA in the main process for Spaces
# The model will be loaded in the worker process through the GPU decorator
# Process-wide cache for the ASR model; stays None until load_model() runs.
model = None


def load_model():
    """Return the shared ASR model, loading it on the first call.

    Intended to be called from the GPU worker process so that CUDA is not
    initialised in the main process. Later calls in the same process return
    the cached instance without reloading.

    Returns:
        The cached ``EncDecRNNTBPEModel`` for "nvidia/parakeet-tdt-0.6b-v2".
    """
    global model

    # Fast path: the model was already loaded in this process.
    if model is not None:
        return model

    print("Loading model in worker process")
    cuda_ready = torch.cuda.is_available()
    print(f"CUDA available: {cuda_ready}")
    if cuda_ready:
        print(f"CUDA device: {torch.cuda.get_device_name(0)}")

    model = nemo_asr.models.EncDecRNNTBPEModel.from_pretrained(
        "nvidia/parakeet-tdt-0.6b-v2"
    )
    print(f"Model loaded on device: {model.device}")
    return model
def _result_text(result):
    """Return the transcript string held by one model transcription result."""
    # The result may be a plain string or an object carrying the transcript in
    # a ``text`` attribute (presumably a NeMo Hypothesis — confirm against the
    # installed NeMo version). Accept both.
    text = getattr(result, "text", result)
    return text.strip() if isinstance(text, str) else ""


def _as_mono_float32(audio_data):
    """Return *audio_data* as a 1-D float32 array suitable for WAV writing."""
    import numpy as np

    if np.issubdtype(audio_data.dtype, np.signedinteger):
        # Integer PCM must be scaled into [-1.0, 1.0); writing unscaled floats
        # to a PCM WAV file would clip almost every sample.
        full_scale = float(np.iinfo(audio_data.dtype).max) + 1.0
        samples = audio_data.astype(np.float32) / full_scale
    else:
        samples = audio_data.astype(np.float32)

    if samples.ndim > 1:
        # assumes a (samples, channels) layout — TODO confirm against the
        # Gradio version in use. Down-mix so resampling runs along time.
        samples = samples.mean(axis=1)
    return samples


@spaces.GPU(duration=120)
def transcribe(audio, state=""):
    """Transcribe one streamed audio chunk and append it to the transcript.

    Args:
        audio: The value delivered by the audio component. Only a
            ``(sample_rate, numpy.ndarray)`` tuple is processed; ``None``,
            ints and any other shape are ignored.
        state: Transcript accumulated so far.

    Returns:
        A ``(state, text)`` pair holding the updated transcript twice: once
        for the stored state and once for display. When the chunk is skipped,
        fails to process, or yields no text, the incoming ``state`` is
        returned unchanged.
    """
    # Local imports, matching the original block's function-scope imports.
    import tempfile

    import numpy as np
    import soundfile as sf

    # Load the model inside the GPU worker process.
    model = load_model()

    if audio is None or isinstance(audio, int):
        print(f"Skipping invalid audio input: {type(audio)}")
        return state, state

    print(f"Received audio input of type: {type(audio)}")
    print(f"Audio shape: {audio.shape if isinstance(audio, np.ndarray) else 'N/A'}")

    is_chunk = (
        isinstance(audio, tuple)
        and len(audio) == 2
        and isinstance(audio[1], np.ndarray)
    )
    if not is_chunk:
        return state, state

    # Handle tuple of (sample_rate, audio_array).
    print(f"Tuple contents: {audio}")
    sample_rate, audio_data = audio
    if audio_data.size == 0:
        # Nothing to transcribe in an empty chunk.
        return state, state

    temp_file = None
    try:
        audio_data = _as_mono_float32(audio_data)

        # Resample to 16kHz for NeMo.
        if sample_rate != 16000:
            print(f"Resampling from {sample_rate}Hz to 16000Hz")
            audio_data = librosa.resample(
                audio_data, orig_sr=sample_rate, target_sr=16000
            )

        # A unique temporary file per call: a fixed file name would be
        # overwritten by concurrent sessions.
        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as handle:
            temp_file = handle.name
        sf.write(temp_file, audio_data, samplerate=16000)

        print(f"Processing temporary audio file: {temp_file}")
        transcription = _result_text(model.transcribe([temp_file])[0])
    except Exception as e:
        # Best-effort streaming: a bad chunk must not break the session, so
        # log the error and keep the transcript as it was.
        print(f"Error processing audio: {e}")
        return state, state
    finally:
        # Clean up on both the success and the failure path.
        if temp_file is not None and os.path.exists(temp_file):
            os.remove(temp_file)
            print("Temporary file removed.")

    if not transcription:
        # Silence or noise: avoid appending stray spaces to the transcript.
        return state, state

    new_state = f"{state} {transcription}" if state else transcription
    print(new_state)
    return new_state, new_state
# Define the Gradio interface: a streaming microphone input on the left and
# two transcript text boxes on the right.
with gr.Blocks(title="Real-time Speech-to-Text with NeMo") as demo:
    gr.Markdown("# 🎙️ Real-time Speech-to-Text Transcription")
    gr.Markdown("Powered by NVIDIA NeMo and the parakeet-tdt-0.6b-v2 model")

    with gr.Row():
        with gr.Column(scale=2):
            audio_input = gr.Audio(
                sources=["microphone"],
                type="numpy",
                streaming=True,
                label="Speak into your microphone"
            )
            clear_btn = gr.Button("Clear Transcript")

        with gr.Column(scale=3):
            text_output = gr.Textbox(
                label="Transcription",
                placeholder="Your speech will appear here...",
                lines=10
            )
            streaming_text = gr.Textbox(
                label="Real-time Transcription",
                placeholder="Real-time results will appear here...",
                lines=2
            )

    # State to store the ongoing transcription
    state = gr.State("")

    # Handle the audio stream: each chunk goes through transcribe(), which
    # returns the updated transcript for both the state and the live text box.
    audio_input.stream(
        fn=transcribe,
        inputs=[audio_input, state],
        outputs=[state, streaming_text],
    )

    # Clear the transcription
    def clear_transcription():
        """Return empty values for both text boxes and the stored state."""
        return "", "", ""

    clear_btn.click(
        fn=clear_transcription,
        inputs=[],
        outputs=[text_output, streaming_text, state]
    )

    # Update the main text output when the state changes
    state.change(
        fn=lambda s: s,
        inputs=[state],
        outputs=[text_output]
    )

    gr.Markdown("## 📝 Instructions")
    gr.Markdown(
        "\n"
        "1. Click the microphone button to start recording\n"
        "2. Speak clearly into your microphone\n"
        "3. The transcription will appear in real-time\n"
        "4. Click 'Clear Transcript' to start a new transcription\n"
    )

# Launch the app
if __name__ == "__main__":
    demo.launch()