KavyaBansal commited on
Commit
f918d7f
·
verified ·
1 Parent(s): ea64729

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +286 -0
app.py ADDED
@@ -0,0 +1,286 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import os
3
+ import numpy as np
4
+ import tempfile
5
+ import base64
6
+ import gc
7
+ import sys
8
+ import traceback
9
+ import gradio as gr
10
+ import librosa
11
+ from scipy.io.wavfile import write
12
+ from gtts import gTTS
13
+ import soundfile as sf
14
+ import whisper # Official OpenAI Whisper package
15
+
16
+ # Define device for processing
17
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
18
+ print(f"Using device: {DEVICE}")
19
+
20
+ # Free up memory
21
+ gc.collect()
22
+ if DEVICE == "cuda":
23
+ torch.cuda.empty_cache()
24
+ print(f"CUDA memory allocated: {torch.cuda.memory_allocated()/1024**2:.2f} MB")
25
+ print(f"CUDA memory reserved: {torch.cuda.memory_reserved()/1024**2:.2f} MB")
26
+
27
+ # Try importing transformers, with fallback
28
+ try:
29
+ from transformers import WhisperProcessor, WhisperForConditionalGeneration
30
+ from transformers import BertForSequenceClassification, BertTokenizer, pipeline
31
+ TRANSFORMERS_AVAILABLE = True
32
+ print("Transformers package loaded successfully")
33
+ except Exception as e:
34
+ TRANSFORMERS_AVAILABLE = False
35
+ print(f"Warning: Could not import from transformers: {e}")
36
+
37
+ class WhisperTranscriber:
38
+ def __init__(self, model_size="tiny"):
39
+ print(f"Initializing Whisper transcriber with model size: {model_size}")
40
+ self.model_size = model_size
41
+ self.processor = None
42
+ self.model = None
43
+ self.official_model = None
44
+
45
+ # Try to initialize using transformers first
46
+ if TRANSFORMERS_AVAILABLE:
47
+ try:
48
+ print(f"Loading Whisper processor: openai/whisper-{model_size}")
49
+ self.processor = WhisperProcessor.from_pretrained(f"openai/whisper-{model_size}")
50
+
51
+ print(f"Loading Whisper model: openai/whisper-{model_size}")
52
+ self.model = WhisperForConditionalGeneration.from_pretrained(f"openai/whisper-{model_size}")
53
+
54
+ if DEVICE == "cuda":
55
+ print("Moving model to CUDA")
56
+ self.model = self.model.to(DEVICE)
57
+
58
+ print("Transformers Whisper initialization complete")
59
+ except Exception as e:
60
+ print(f"Error initializing Whisper with transformers: {e}")
61
+ traceback.print_exc()
62
+ self.processor = None
63
+ self.model = None
64
+
65
+ # If transformers failed or not available, try official OpenAI implementation
66
+ if self.processor is None or self.model is None:
67
+ try:
68
+ print(f"Falling back to official OpenAI Whisper implementation with model size: {model_size}")
69
+ self.official_model = whisper.load_model(model_size)
70
+ print("Official Whisper model loaded successfully")
71
+ except Exception as e:
72
+ print(f"Error initializing official Whisper model: {e}")
73
+ traceback.print_exc()
74
+ self.official_model = None
75
+
76
+ # Check if any model was loaded
77
+ if (self.processor is None or self.model is None) and self.official_model is None:
78
+ print("WARNING: All Whisper initialization attempts failed!")
79
+ else:
80
+ print("Whisper initialized successfully with at least one implementation")
81
+
82
+ def transcribe(self, audio_path):
83
+ # Try transcribing with transformers implementation first
84
+ if self.processor is not None and self.model is not None:
85
+ try:
86
+ print("Transcribing with transformers implementation...")
87
+
88
+ # Load audio
89
+ waveform, sample_rate = librosa.load(audio_path, sr=16000)
90
+
91
+ # Process audio
92
+ input_features = self.processor(waveform, sampling_rate=16000, return_tensors="pt").input_features
93
+ if DEVICE == "cuda":
94
+ input_features = input_features.to(DEVICE)
95
+
96
+ # Generate transcription
97
+ with torch.no_grad():
98
+ predicted_ids = self.model.generate(input_features, max_length=100)
99
+
100
+ # Decode the transcription
101
+ transcription = self.processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
102
+ print("Transcription successful with transformers implementation")
103
+ return transcription
104
+
105
+ except Exception as e:
106
+ print(f"Error in transformers transcription: {e}")
107
+ traceback.print_exc()
108
+
109
+ # Fall back to official implementation if available
110
+ if self.official_model is not None:
111
+ try:
112
+ print("Falling back to official Whisper implementation...")
113
+ result = self.official_model.transcribe(audio_path)
114
+ transcription = result["text"]
115
+ print("Transcription successful with official implementation")
116
+ return transcription
117
+ except Exception as e:
118
+ print(f"Error in official Whisper transcription: {e}")
119
+ traceback.print_exc()
120
+
121
+ print("All transcription attempts failed")
122
+ return "Error: Transcription failed. Please check the logs for details."
123
+
124
+ class GrammarCorrector:
125
+ def __init__(self):
126
+ print("Initializing grammar corrector...")
127
+ try:
128
+ # Initialize grammar correction pipeline
129
+ self.corrector = pipeline("text2text-generation", model="pszemraj/flan-t5-large-grammar-synthesis")
130
+ print("Grammar corrector initialized successfully")
131
+ except Exception as e:
132
+ print(f"Error initializing grammar corrector: {e}")
133
+ traceback.print_exc()
134
+ self.corrector = None
135
+
136
+ def correct(self, text):
137
+ if not text or not text.strip():
138
+ return text
139
+
140
+ if self.corrector is not None:
141
+ try:
142
+ # Use the grammar correction pipeline
143
+ corrected_text = self.corrector(f"grammar correction: {text}")[0]['generated_text']
144
+ return corrected_text
145
+ except Exception as e:
146
+ print(f"Error in grammar correction: {e}")
147
+ return text
148
+ else:
149
+ print("No valid grammar correction model available. Returning original text.")
150
+ return text
151
+
152
+ class TextToSpeech:
153
+ def __init__(self):
154
+ print("Initializing text-to-speech engine...")
155
+
156
+ def speak(self, text, output_file="output_speech.mp3"):
157
+ try:
158
+ tts = gTTS(text=text, lang='en', slow=False)
159
+ tts.save(output_file)
160
+ print(f"Speech saved to {output_file}")
161
+ return output_file
162
+ except Exception as e:
163
+ print(f"Error with gTTS: {e}")
164
+ traceback.print_exc()
165
+ return False
166
+
167
+ class SpeechProcessor:
168
+ def __init__(self, whisper_model_size="tiny"):
169
+ print(f"Initializing Speech Processor with Whisper model size: {whisper_model_size}")
170
+ self.transcriber = WhisperTranscriber(model_size=whisper_model_size)
171
+ self.grammar_corrector = GrammarCorrector()
172
+ self.tts = TextToSpeech()
173
+
174
+ def process_text(self, text):
175
+ """Process text input: correct grammar and generate speech"""
176
+ print("Processing text input...")
177
+
178
+ # Correct grammar and punctuation
179
+ corrected_text = self.grammar_corrector.correct(text)
180
+
181
+ # Generate speech from corrected text
182
+ speech_file = self.tts.speak(corrected_text, "output_speech.mp3")
183
+
184
+ return corrected_text, speech_file
185
+
186
+ def process_audio(self, audio_path):
187
+ """Process audio input: transcribe, correct grammar, and generate speech"""
188
+ print(f"Processing audio input from: {audio_path}")
189
+
190
+ if not audio_path:
191
+ return "Failed to get audio", None, None
192
+
193
+ # Transcribe audio
194
+ transcription = self.transcriber.transcribe(audio_path)
195
+
196
+ if transcription.startswith("Error:"):
197
+ return transcription, None, None
198
+
199
+ # Correct grammar and punctuation
200
+ corrected_text = self.grammar_corrector.correct(transcription)
201
+
202
+ # Generate speech from corrected text
203
+ speech_file = self.tts.speak(corrected_text, "output_speech.mp3")
204
+
205
+ return transcription, corrected_text, speech_file
206
+
207
+ # Initialize the processor
208
+ processor = SpeechProcessor(whisper_model_size="tiny")
209
+
210
+ # Define Gradio functions for the interface
211
+ def process_text_input(text):
212
+ """Handle text input from Gradio interface"""
213
+ corrected_text, speech_file = processor.process_text(text)
214
+ return corrected_text, speech_file
215
+
216
+ def process_audio_input(audio_file):
217
+ """Handle audio upload/recording from Gradio interface"""
218
+ if audio_file is None:
219
+ return "No audio provided", "No audio provided", None
220
+
221
+ transcription, corrected_text, speech_file = processor.process_audio(audio_file)
222
+
223
+ if transcription.startswith("Error:"):
224
+ return transcription, "", None
225
+
226
+ return transcription, corrected_text, speech_file
227
+
228
+ # Create the Gradio interface
229
+ def create_gradio_interface():
230
+ with gr.Blocks(title="Speech Processing System") as demo:
231
+ gr.Markdown("# Speech Processing System")
232
+ gr.Markdown("Transcribe, correct grammar, and generate speech.")
233
+
234
+ with gr.Tab("Text Input"):
235
+ with gr.Row():
236
+ text_input = gr.Textbox(placeholder="Enter text to process", label="Input Text", lines=5)
237
+
238
+ text_button = gr.Button("Process Text")
239
+
240
+ with gr.Row():
241
+ corrected_text_output = gr.Textbox(label="Corrected Text", lines=5)
242
+ speech_output = gr.Audio(label="Speech Output")
243
+
244
+ text_button.click(
245
+ fn=process_text_input,
246
+ inputs=[text_input],
247
+ outputs=[corrected_text_output, speech_output]
248
+ )
249
+
250
+ with gr.Tab("Audio Input"):
251
+ with gr.Row():
252
+ audio_input = gr.Audio(
253
+ sources=["microphone", "upload"],
254
+ type="filepath",
255
+ label="Upload or Record Audio"
256
+ )
257
+
258
+ audio_button = gr.Button("Process Audio")
259
+
260
+ with gr.Row():
261
+ transcription_output = gr.Textbox(label="Transcription", lines=3)
262
+ audio_corrected_text = gr.Textbox(label="Corrected Text", lines=3)
263
+
264
+ with gr.Row():
265
+ audio_speech_output = gr.Audio(label="Speech Output")
266
+
267
+ audio_button.click(
268
+ fn=process_audio_input,
269
+ inputs=[audio_input],
270
+ outputs=[transcription_output, audio_corrected_text, audio_speech_output]
271
+ )
272
+
273
+ gr.Markdown("## How to use")
274
+ gr.Markdown("""
275
+ 1. **Text Input Tab**: Enter text, click 'Process Text'. The system will correct grammar and generate speech.
276
+ 2. **Audio Input Tab**: Upload an audio file or record using your microphone, then click 'Process Audio'.
277
+ The system will transcribe your speech, correct grammar, and generate improved speech.
278
+ """)
279
+
280
+ return demo
281
+
282
+ # Launch the interface
283
+ demo = create_gradio_interface()
284
+
285
+ if __name__ == "__main__":
286
+ demo.launch()