import torch
import gc
import traceback
import gradio as gr
import librosa
from gtts import gTTS
import whisper  # Official OpenAI Whisper package
# Define device for processing
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {DEVICE}")

# Free up memory
gc.collect()
if DEVICE == "cuda":
    torch.cuda.empty_cache()
    print(f"CUDA memory allocated: {torch.cuda.memory_allocated()/1024**2:.2f} MB")
    print(f"CUDA memory reserved: {torch.cuda.memory_reserved()/1024**2:.2f} MB")
# Try importing transformers, with fallback
try:
    from transformers import WhisperProcessor, WhisperForConditionalGeneration
    from transformers import pipeline
    TRANSFORMERS_AVAILABLE = True
    print("Transformers package loaded successfully")
except Exception as e:
    TRANSFORMERS_AVAILABLE = False
    print(f"Warning: Could not import from transformers: {e}")
class WhisperTranscriber:
    def __init__(self, model_size="tiny"):
        print(f"Initializing Whisper transcriber with model size: {model_size}")
        self.model_size = model_size
        self.processor = None
        self.model = None
        self.official_model = None

        # Try to initialize using transformers first
        if TRANSFORMERS_AVAILABLE:
            try:
                print(f"Loading Whisper processor: openai/whisper-{model_size}")
                self.processor = WhisperProcessor.from_pretrained(f"openai/whisper-{model_size}")
                print(f"Loading Whisper model: openai/whisper-{model_size}")
                self.model = WhisperForConditionalGeneration.from_pretrained(f"openai/whisper-{model_size}")
                if DEVICE == "cuda":
                    print("Moving model to CUDA")
                    self.model = self.model.to(DEVICE)
                print("Transformers Whisper initialization complete")
            except Exception as e:
                print(f"Error initializing Whisper with transformers: {e}")
                traceback.print_exc()
                self.processor = None
                self.model = None

        # If transformers failed or is not available, try the official OpenAI implementation
        if self.processor is None or self.model is None:
            try:
                print(f"Falling back to official OpenAI Whisper implementation with model size: {model_size}")
                self.official_model = whisper.load_model(model_size)
                print("Official Whisper model loaded successfully")
            except Exception as e:
                print(f"Error initializing official Whisper model: {e}")
                traceback.print_exc()
                self.official_model = None

        # Check whether any model was loaded
        if (self.processor is None or self.model is None) and self.official_model is None:
            print("WARNING: All Whisper initialization attempts failed!")
        else:
            print("Whisper initialized successfully with at least one implementation")
    def transcribe(self, audio_path):
        # Try transcribing with the transformers implementation first
        if self.processor is not None and self.model is not None:
            try:
                print("Transcribing with transformers implementation...")
                # Load audio resampled to the 16 kHz rate Whisper expects
                waveform, _ = librosa.load(audio_path, sr=16000)
                # Convert the waveform to log-mel input features
                input_features = self.processor(waveform, sampling_rate=16000, return_tensors="pt").input_features
                if DEVICE == "cuda":
                    input_features = input_features.to(DEVICE)
                # Generate transcription (max_length caps output at 100 tokens,
                # so very long recordings will be truncated)
                with torch.no_grad():
                    predicted_ids = self.model.generate(input_features, max_length=100)
                # Decode the transcription
                transcription = self.processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
                print("Transcription successful with transformers implementation")
                return transcription
            except Exception as e:
                print(f"Error in transformers transcription: {e}")
                traceback.print_exc()

        # Fall back to the official implementation if available
        if self.official_model is not None:
            try:
                print("Falling back to official Whisper implementation...")
                result = self.official_model.transcribe(audio_path)
                transcription = result["text"]
                print("Transcription successful with official implementation")
                return transcription
            except Exception as e:
                print(f"Error in official Whisper transcription: {e}")
                traceback.print_exc()

        print("All transcription attempts failed")
        return "Error: Transcription failed. Please check the logs for details."
class GrammarCorrector:
    def __init__(self):
        print("Initializing grammar corrector...")
        self.corrector = None
        # Guard against a NameError: `pipeline` is only defined when the
        # transformers import above succeeded
        if not TRANSFORMERS_AVAILABLE:
            print("Transformers unavailable; grammar correction disabled")
            return
        try:
            # Initialize grammar correction pipeline
            self.corrector = pipeline("text2text-generation", model="pszemraj/grammar-synthesis-small")
            print("Grammar corrector initialized successfully")
        except Exception as e:
            print(f"Error initializing grammar corrector: {e}")
            traceback.print_exc()
            self.corrector = None

    def correct(self, text):
        if not text or not text.strip():
            return text
        if self.corrector is not None:
            try:
                # Use the grammar correction pipeline
                corrected_text = self.corrector(f"grammar correction: {text}")[0]['generated_text']
                return corrected_text
            except Exception as e:
                print(f"Error in grammar correction: {e}")
                return text
        else:
            print("No valid grammar correction model available. Returning original text.")
            return text
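
# Hedged usage sketch, correcting a single hypothetical sentence:
#
#   corrector = GrammarCorrector()
#   corrector.correct("she go to school yesterday")  # returns a corrected sentence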
class TextToSpeech:
    def __init__(self):
        print("Initializing text-to-speech engine...")

    def speak(self, text, output_file="output_speech.mp3"):
        try:
            tts = gTTS(text=text, lang='en', slow=False)
            tts.save(output_file)
            print(f"Speech saved to {output_file}")
            return output_file
        except Exception as e:
            print(f"Error with gTTS: {e}")
            traceback.print_exc()
            # Return None (not False) so downstream gr.Audio components
            # receive a valid "no audio" value
            return None
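
# Hedged usage sketch: gTTS calls Google's TTS endpoint, so this only works
# with network access. "hello.mp3" is a hypothetical output path.
#
#   tts = TextToSpeech()
#   tts.speak("Hello, world!", "hello.mp3")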
class SpeechProcessor:
    def __init__(self, whisper_model_size="tiny"):
        print(f"Initializing Speech Processor with Whisper model size: {whisper_model_size}")
        self.transcriber = WhisperTranscriber(model_size=whisper_model_size)
        self.grammar_corrector = GrammarCorrector()
        self.tts = TextToSpeech()

    def process_text(self, text):
        """Process text input: correct grammar and generate speech."""
        print("Processing text input...")
        # Correct grammar and punctuation
        corrected_text = self.grammar_corrector.correct(text)
        # Generate speech from the corrected text
        speech_file = self.tts.speak(corrected_text, "output_speech.mp3")
        return corrected_text, speech_file

    def process_audio(self, audio_path):
        """Process audio input: transcribe, correct grammar, and generate speech."""
        print(f"Processing audio input from: {audio_path}")
        if not audio_path:
            return "Failed to get audio", None, None
        # Transcribe audio
        transcription = self.transcriber.transcribe(audio_path)
        if transcription.startswith("Error:"):
            return transcription, None, None
        # Correct grammar and punctuation
        corrected_text = self.grammar_corrector.correct(transcription)
        # Generate speech from the corrected text
        speech_file = self.tts.speak(corrected_text, "output_speech.mp3")
        return transcription, corrected_text, speech_file
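
# Hedged end-to-end sketch: the full pipeline without the web UI.
# "recording.wav" is a hypothetical input file.
#
#   sp = SpeechProcessor(whisper_model_size="tiny")
#   transcription, corrected, mp3_path = sp.process_audio("recording.wav")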
# Initialize the processor
processor = SpeechProcessor(whisper_model_size="tiny")

# Define Gradio functions for the interface
def process_text_input(text):
    """Handle text input from the Gradio interface."""
    corrected_text, speech_file = processor.process_text(text)
    return corrected_text, speech_file

def process_audio_input(audio_file):
    """Handle audio upload/recording from the Gradio interface."""
    if audio_file is None:
        return "No audio provided", "No audio provided", None
    transcription, corrected_text, speech_file = processor.process_audio(audio_file)
    if transcription.startswith("Error:"):
        return transcription, "", None
    return transcription, corrected_text, speech_file
# Create the Gradio interface
def create_gradio_interface():
    with gr.Blocks(title="Speech Processing System") as demo:
        gr.Markdown("# Speech Processing System")
        gr.Markdown("Transcribe, correct grammar, and generate speech.")

        with gr.Tab("Text Input"):
            with gr.Row():
                text_input = gr.Textbox(placeholder="Enter text to process", label="Input Text", lines=5)
                text_button = gr.Button("Process Text")
            with gr.Row():
                corrected_text_output = gr.Textbox(label="Corrected Text", lines=5)
                speech_output = gr.Audio(label="Speech Output")
            text_button.click(
                fn=process_text_input,
                inputs=[text_input],
                outputs=[corrected_text_output, speech_output]
            )

        with gr.Tab("Audio Input"):
            with gr.Row():
                audio_input = gr.Audio(
                    sources=["microphone", "upload"],
                    type="filepath",
                    label="Upload or Record Audio"
                )
                audio_button = gr.Button("Process Audio")
            with gr.Row():
                transcription_output = gr.Textbox(label="Transcription", lines=3)
                audio_corrected_text = gr.Textbox(label="Corrected Text", lines=3)
            with gr.Row():
                audio_speech_output = gr.Audio(label="Speech Output")
            audio_button.click(
                fn=process_audio_input,
                inputs=[audio_input],
                outputs=[transcription_output, audio_corrected_text, audio_speech_output]
            )

        gr.Markdown("## How to use")
        gr.Markdown("""
        1. **Text Input Tab**: Enter text, then click 'Process Text'. The system will correct grammar and generate speech.
        2. **Audio Input Tab**: Upload an audio file or record with your microphone, then click 'Process Audio'.
           The system will transcribe your speech, correct grammar, and generate improved speech.
        """)
    return demo
# Launch the interface
demo = create_gradio_interface()

if __name__ == "__main__":
    demo.launch()
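
# Hedged note: on Hugging Face Spaces the plain `demo.launch()` above is
# sufficient. When testing locally, Gradio can create a temporary public link:
#
#   demo.launch(share=True)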