# GrammarCheck / app.py
import gc
import traceback

import torch
import gradio as gr
import librosa
from gtts import gTTS
import whisper  # official OpenAI Whisper package
# Define device for processing
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {DEVICE}")

# Free up memory
gc.collect()
if DEVICE == "cuda":
    torch.cuda.empty_cache()
    print(f"CUDA memory allocated: {torch.cuda.memory_allocated() / 1024**2:.2f} MB")
    print(f"CUDA memory reserved: {torch.cuda.memory_reserved() / 1024**2:.2f} MB")
# Try importing transformers, with a fallback to the official Whisper package
try:
    from transformers import WhisperProcessor, WhisperForConditionalGeneration, pipeline

    TRANSFORMERS_AVAILABLE = True
    print("Transformers package loaded successfully")
except Exception as e:
    TRANSFORMERS_AVAILABLE = False
    print(f"Warning: Could not import from transformers: {e}")

class WhisperTranscriber:
    def __init__(self, model_size="tiny"):
        print(f"Initializing Whisper transcriber with model size: {model_size}")
        self.model_size = model_size
        self.processor = None
        self.model = None
        self.official_model = None

        # Try to initialize using transformers first
        if TRANSFORMERS_AVAILABLE:
            try:
                print(f"Loading Whisper processor: openai/whisper-{model_size}")
                self.processor = WhisperProcessor.from_pretrained(f"openai/whisper-{model_size}")
                print(f"Loading Whisper model: openai/whisper-{model_size}")
                self.model = WhisperForConditionalGeneration.from_pretrained(f"openai/whisper-{model_size}")
                if DEVICE == "cuda":
                    print("Moving model to CUDA")
                    self.model = self.model.to(DEVICE)
                print("Transformers Whisper initialization complete")
            except Exception as e:
                print(f"Error initializing Whisper with transformers: {e}")
                traceback.print_exc()
                self.processor = None
                self.model = None

        # If transformers failed or is not available, try the official OpenAI implementation
        if self.processor is None or self.model is None:
            try:
                print(f"Falling back to official OpenAI Whisper implementation with model size: {model_size}")
                self.official_model = whisper.load_model(model_size)
                print("Official Whisper model loaded successfully")
            except Exception as e:
                print(f"Error initializing official Whisper model: {e}")
                traceback.print_exc()
                self.official_model = None

        # Check whether any model was loaded
        if (self.processor is None or self.model is None) and self.official_model is None:
            print("WARNING: All Whisper initialization attempts failed!")
        else:
            print("Whisper initialized successfully with at least one implementation")
    def transcribe(self, audio_path):
        # Try transcribing with the transformers implementation first
        if self.processor is not None and self.model is not None:
            try:
                print("Transcribing with transformers implementation...")
                # Load audio at the 16 kHz sample rate Whisper expects
                waveform, sample_rate = librosa.load(audio_path, sr=16000)
                # Process audio
                input_features = self.processor(waveform, sampling_rate=16000, return_tensors="pt").input_features
                if DEVICE == "cuda":
                    input_features = input_features.to(DEVICE)
                # Generate transcription
                with torch.no_grad():
                    predicted_ids = self.model.generate(input_features, max_length=100)
                # Decode the transcription
                transcription = self.processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
                print("Transcription successful with transformers implementation")
                return transcription
            except Exception as e:
                print(f"Error in transformers transcription: {e}")
                traceback.print_exc()

        # Fall back to the official implementation if available
        if self.official_model is not None:
            try:
                print("Falling back to official Whisper implementation...")
                result = self.official_model.transcribe(audio_path)
                transcription = result["text"]
                print("Transcription successful with official implementation")
                return transcription
            except Exception as e:
                print(f"Error in official Whisper transcription: {e}")
                traceback.print_exc()

        print("All transcription attempts failed")
        return "Error: Transcription failed. Please check the logs for details."

class GrammarCorrector:
    def __init__(self):
        print("Initializing grammar corrector...")
        try:
            # Initialize the grammar correction pipeline
            self.corrector = pipeline("text2text-generation", model="pszemraj/grammar-synthesis-small")
            print("Grammar corrector initialized successfully")
        except Exception as e:
            print(f"Error initializing grammar corrector: {e}")
            traceback.print_exc()
            self.corrector = None

    def correct(self, text):
        if not text or not text.strip():
            return text
        if self.corrector is not None:
            try:
                # Use the grammar correction pipeline
                corrected_text = self.corrector(f"grammar correction: {text}")[0]['generated_text']
                return corrected_text
            except Exception as e:
                print(f"Error in grammar correction: {e}")
                return text
        else:
            print("No valid grammar correction model available. Returning original text.")
            return text
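
# A minimal usage sketch (the corrected output shown is illustrative; the
# actual text depends on the pszemraj/grammar-synthesis-small model):
#
#   corrector = GrammarCorrector()
#   corrector.correct("he go to school yesterday")  # -> e.g. "He went to school yesterday."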

class TextToSpeech:
    def __init__(self):
        print("Initializing text-to-speech engine...")

    def speak(self, text, output_file="output_speech.mp3"):
        try:
            tts = gTTS(text=text, lang='en', slow=False)
            tts.save(output_file)
            print(f"Speech saved to {output_file}")
            return output_file
        except Exception as e:
            print(f"Error with gTTS: {e}")
            traceback.print_exc()
            # Return None (not False) so Gradio's Audio component treats a
            # failed synthesis as "no audio" rather than an invalid file value
            return None
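
# A minimal usage sketch (writes an MP3 to the working directory; note that
# gTTS needs network access to Google's TTS endpoint):
#
#   tts = TextToSpeech()
#   tts.speak("Hello, world!", "hello.mp3")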

class SpeechProcessor:
    def __init__(self, whisper_model_size="tiny"):
        print(f"Initializing Speech Processor with Whisper model size: {whisper_model_size}")
        self.transcriber = WhisperTranscriber(model_size=whisper_model_size)
        self.grammar_corrector = GrammarCorrector()
        self.tts = TextToSpeech()

    def process_text(self, text):
        """Process text input: correct grammar and generate speech."""
        print("Processing text input...")
        # Correct grammar and punctuation
        corrected_text = self.grammar_corrector.correct(text)
        # Generate speech from the corrected text
        speech_file = self.tts.speak(corrected_text, "output_speech.mp3")
        return corrected_text, speech_file

    def process_audio(self, audio_path):
        """Process audio input: transcribe, correct grammar, and generate speech."""
        print(f"Processing audio input from: {audio_path}")
        if not audio_path:
            return "Failed to get audio", None, None
        # Transcribe audio
        transcription = self.transcriber.transcribe(audio_path)
        if transcription.startswith("Error:"):
            return transcription, None, None
        # Correct grammar and punctuation
        corrected_text = self.grammar_corrector.correct(transcription)
        # Generate speech from the corrected text
        speech_file = self.tts.speak(corrected_text, "output_speech.mp3")
        return transcription, corrected_text, speech_file

# Initialize the processor
processor = SpeechProcessor(whisper_model_size="tiny")


# Define Gradio callbacks for the interface
def process_text_input(text):
    """Handle text input from the Gradio interface."""
    corrected_text, speech_file = processor.process_text(text)
    return corrected_text, speech_file


def process_audio_input(audio_file):
    """Handle audio upload/recording from the Gradio interface."""
    if audio_file is None:
        return "No audio provided", "No audio provided", None
    transcription, corrected_text, speech_file = processor.process_audio(audio_file)
    if transcription.startswith("Error:"):
        return transcription, "", None
    return transcription, corrected_text, speech_file

# Create the Gradio interface
def create_gradio_interface():
    with gr.Blocks(title="Speech Processing System") as demo:
        gr.Markdown("# Speech Processing System")
        gr.Markdown("Transcribe, correct grammar, and generate speech.")

        with gr.Tab("Text Input"):
            with gr.Row():
                text_input = gr.Textbox(placeholder="Enter text to process", label="Input Text", lines=5)
                text_button = gr.Button("Process Text")
            with gr.Row():
                corrected_text_output = gr.Textbox(label="Corrected Text", lines=5)
                speech_output = gr.Audio(label="Speech Output")
            text_button.click(
                fn=process_text_input,
                inputs=[text_input],
                outputs=[corrected_text_output, speech_output],
            )

        with gr.Tab("Audio Input"):
            with gr.Row():
                audio_input = gr.Audio(
                    sources=["microphone", "upload"],
                    type="filepath",
                    label="Upload or Record Audio",
                )
                audio_button = gr.Button("Process Audio")
            with gr.Row():
                transcription_output = gr.Textbox(label="Transcription", lines=3)
                audio_corrected_text = gr.Textbox(label="Corrected Text", lines=3)
            with gr.Row():
                audio_speech_output = gr.Audio(label="Speech Output")
            audio_button.click(
                fn=process_audio_input,
                inputs=[audio_input],
                outputs=[transcription_output, audio_corrected_text, audio_speech_output],
            )

        gr.Markdown("## How to use")
        gr.Markdown(
            """
1. **Text Input Tab**: Enter text, then click 'Process Text'. The system corrects the grammar and generates speech.
2. **Audio Input Tab**: Upload an audio file or record with your microphone, then click 'Process Audio'.
   The system transcribes your speech, corrects the grammar, and generates improved speech.
"""
        )
    return demo

# Launch the interface
demo = create_gradio_interface()

if __name__ == "__main__":
    demo.launch()
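    # Common launch overrides (a sketch; both are standard Gradio options and
    # whether you need them depends on your deployment):
    #   demo.launch(server_name="0.0.0.0", server_port=7860)  # bind for Docker/remote hosts
    #   demo.launch(share=True)  # create a temporary public link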