File size: 11,409 Bytes
f918d7f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ae12692
f918d7f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
import torch
import os
import numpy as np
import tempfile
import base64
import gc
import sys
import traceback
import gradio as gr
import librosa
from scipy.io.wavfile import write
from gtts import gTTS
import soundfile as sf
import whisper  # Official OpenAI Whisper package

# Define device for processing
# Pick CUDA when available; DEVICE is read by WhisperTranscriber below.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {DEVICE}")

# Free up memory
# Collect garbage and clear the CUDA cache before loading large models,
# then report current GPU memory usage for debugging.
gc.collect()
if DEVICE == "cuda":
    torch.cuda.empty_cache()
    print(f"CUDA memory allocated: {torch.cuda.memory_allocated()/1024**2:.2f} MB")
    print(f"CUDA memory reserved: {torch.cuda.memory_reserved()/1024**2:.2f} MB")

# Try importing transformers, with fallback
# TRANSFORMERS_AVAILABLE gates the Hugging Face code paths below; when the
# import fails the app falls back to the official openai-whisper package.
try:
    from transformers import WhisperProcessor, WhisperForConditionalGeneration
    from transformers import BertForSequenceClassification, BertTokenizer, pipeline
    TRANSFORMERS_AVAILABLE = True
    print("Transformers package loaded successfully")
except Exception as e:
    TRANSFORMERS_AVAILABLE = False
    print(f"Warning: Could not import from transformers: {e}")

class WhisperTranscriber:
    """Speech-to-text wrapper around Whisper.

    Prefers the Hugging Face transformers implementation and falls back to
    the official openai-whisper package when transformers is unavailable or
    fails to load. `transcribe` mirrors that preference at inference time.
    """

    def __init__(self, model_size="tiny"):
        print(f"Initializing Whisper transcriber with model size: {model_size}")
        self.model_size = model_size
        # HF backend (processor + model) — both stay None if loading fails.
        self.processor = None
        self.model = None
        # Official openai-whisper backend, used as the fallback.
        self.official_model = None

        if TRANSFORMERS_AVAILABLE:
            self._load_transformers_backend(model_size)
        if self.processor is None or self.model is None:
            self._load_official_backend(model_size)

        hf_ready = self.processor is not None and self.model is not None
        if not hf_ready and self.official_model is None:
            print("WARNING: All Whisper initialization attempts failed!")
        else:
            print("Whisper initialized successfully with at least one implementation")

    def _load_transformers_backend(self, model_size):
        # Load the HF processor/model pair; clear both on any failure so the
        # caller falls through to the official implementation.
        try:
            print(f"Loading Whisper processor: openai/whisper-{model_size}")
            self.processor = WhisperProcessor.from_pretrained(f"openai/whisper-{model_size}")

            print(f"Loading Whisper model: openai/whisper-{model_size}")
            self.model = WhisperForConditionalGeneration.from_pretrained(f"openai/whisper-{model_size}")

            if DEVICE == "cuda":
                print("Moving model to CUDA")
                self.model = self.model.to(DEVICE)

            print("Transformers Whisper initialization complete")
        except Exception as e:
            print(f"Error initializing Whisper with transformers: {e}")
            traceback.print_exc()
            self.processor = None
            self.model = None

    def _load_official_backend(self, model_size):
        # Load the official openai-whisper model; leave it None on failure.
        try:
            print(f"Falling back to official OpenAI Whisper implementation with model size: {model_size}")
            self.official_model = whisper.load_model(model_size)
            print("Official Whisper model loaded successfully")
        except Exception as e:
            print(f"Error initializing official Whisper model: {e}")
            traceback.print_exc()
            self.official_model = None

    def transcribe(self, audio_path):
        """Transcribe the audio file at *audio_path*.

        Returns the transcription text, or a string starting with "Error:"
        (callers test for that prefix) when every backend fails.
        """
        # First choice: the transformers backend, when it loaded.
        if self.processor is not None and self.model is not None:
            try:
                print("Transcribing with transformers implementation...")

                # Whisper expects 16 kHz mono input.
                samples, _rate = librosa.load(audio_path, sr=16000)

                features = self.processor(samples, sampling_rate=16000, return_tensors="pt").input_features
                if DEVICE == "cuda":
                    features = features.to(DEVICE)

                # NOTE(review): max_length=100 caps the token count, so long
                # recordings may be truncated here — confirm this limit.
                with torch.no_grad():
                    token_ids = self.model.generate(features, max_length=100)

                text = self.processor.batch_decode(token_ids, skip_special_tokens=True)[0]
                print("Transcription successful with transformers implementation")
                return text

            except Exception as e:
                print(f"Error in transformers transcription: {e}")
                traceback.print_exc()

        # Second choice: the official implementation.
        if self.official_model is not None:
            try:
                print("Falling back to official Whisper implementation...")
                text = self.official_model.transcribe(audio_path)["text"]
                print("Transcription successful with official implementation")
                return text
            except Exception as e:
                print(f"Error in official Whisper transcription: {e}")
                traceback.print_exc()

        print("All transcription attempts failed")
        return "Error: Transcription failed. Please check the logs for details."

class GrammarCorrector:
    """Grammar/punctuation fixer backed by a text2text-generation pipeline.

    Degrades gracefully: when the model cannot be loaded (or inference
    fails), `correct` returns the input text unchanged.
    """

    def __init__(self):
        print("Initializing grammar corrector...")
        try:
            # Load the seq2seq grammar model once, up front.
            self.corrector = pipeline("text2text-generation", model="pszemraj/grammar-synthesis-small")
            print("Grammar corrector initialized successfully")
        except Exception as e:
            print(f"Error initializing grammar corrector: {e}")
            traceback.print_exc()
            self.corrector = None

    def correct(self, text):
        """Return *text* with corrected grammar, or *text* itself on any failure."""
        # Guard clauses: nothing to do for empty/whitespace-only input.
        if not text or not text.strip():
            return text

        if self.corrector is None:
            print("No valid grammar correction model available. Returning original text.")
            return text

        try:
            outputs = self.corrector(f"grammar correction: {text}")
            return outputs[0]['generated_text']
        except Exception as e:
            print(f"Error in grammar correction: {e}")
            return text

class TextToSpeech:
    """Text-to-speech engine backed by Google TTS (gTTS); writes MP3 files."""

    def __init__(self):
        print("Initializing text-to-speech engine...")

    def speak(self, text, output_file="output_speech.mp3"):
        """Synthesize *text* as English speech into *output_file*.

        Returns the output file path on success, or None on failure.
        (Callers feed this value straight into a Gradio Audio output;
        None is the valid "no audio" value there, whereas the previous
        False return was not. None is equally falsy, so truthiness
        checks by callers are unaffected.)
        """
        try:
            # gTTS performs a network request to Google's TTS service.
            tts = gTTS(text=text, lang='en', slow=False)
            tts.save(output_file)
            print(f"Speech saved to {output_file}")
            return output_file
        except Exception as e:
            print(f"Error with gTTS: {e}")
            traceback.print_exc()
            # Bug fix: return None (not False) so downstream components
            # receive a well-typed "no audio" result.
            return None

class SpeechProcessor:
    """Facade wiring together transcription, grammar correction, and TTS."""

    def __init__(self, whisper_model_size="tiny"):
        print(f"Initializing Speech Processor with Whisper model size: {whisper_model_size}")
        self.transcriber = WhisperTranscriber(model_size=whisper_model_size)
        self.grammar_corrector = GrammarCorrector()
        self.tts = TextToSpeech()

    def process_text(self, text):
        """Process text input: correct grammar and generate speech"""
        print("Processing text input...")

        fixed = self.grammar_corrector.correct(text)
        audio_file = self.tts.speak(fixed, "output_speech.mp3")

        return fixed, audio_file

    def process_audio(self, audio_path):
        """Process audio input: transcribe, correct grammar, and generate speech"""
        print(f"Processing audio input from: {audio_path}")

        # Guard clause: no usable path means nothing to transcribe.
        if not audio_path:
            return "Failed to get audio", None, None

        raw_text = self.transcriber.transcribe(audio_path)
        if raw_text.startswith("Error:"):
            # The transcriber signals failure via the "Error:" prefix.
            return raw_text, None, None

        fixed = self.grammar_corrector.correct(raw_text)
        audio_file = self.tts.speak(fixed, "output_speech.mp3")

        return raw_text, fixed, audio_file

# Initialize the processor
# Module-level singleton shared by the Gradio callbacks below; note this
# loads the Whisper and grammar models at import time.
processor = SpeechProcessor(whisper_model_size="tiny")

# Define Gradio functions for the interface
def process_text_input(text):
    """Gradio callback for the text tab.

    Returns (corrected text, speech file path) from the shared processor.
    """
    corrected, audio_file = processor.process_text(text)
    return corrected, audio_file

def process_audio_input(audio_file):
    """Gradio callback for the audio tab.

    Returns (transcription, corrected text, speech file path); placeholder
    strings and a None audio value when no input was given or when
    transcription failed.
    """
    # Guard clause: nothing was uploaded or recorded.
    if audio_file is None:
        return "No audio provided", "No audio provided", None

    transcription, corrected, speech = processor.process_audio(audio_file)

    # The transcriber signals failure via the "Error:" prefix.
    if transcription.startswith("Error:"):
        return transcription, "", None

    return transcription, corrected, speech

# Create the Gradio interface
def create_gradio_interface():
    """Build the two-tab Gradio Blocks UI (text input and audio input).

    Returns the Blocks app; the caller is expected to call .launch() on it.
    """
    with gr.Blocks(title="Speech Processing System") as demo:
        gr.Markdown("# Speech Processing System")
        gr.Markdown("Transcribe, correct grammar, and generate speech.")

        # Tab 1: plain text in -> corrected text + synthesized speech out.
        with gr.Tab("Text Input"):
            with gr.Row():
                text_input = gr.Textbox(placeholder="Enter text to process", label="Input Text", lines=5)

            text_button = gr.Button("Process Text")

            with gr.Row():
                corrected_text_output = gr.Textbox(label="Corrected Text", lines=5)
                speech_output = gr.Audio(label="Speech Output")

            text_button.click(
                fn=process_text_input,
                inputs=[text_input],
                outputs=[corrected_text_output, speech_output]
            )

        # Tab 2: uploaded/recorded audio in -> transcription, corrected
        # text, and synthesized speech out.
        with gr.Tab("Audio Input"):
            with gr.Row():
                # type="filepath" makes Gradio hand the callback a path
                # string, which is what the Whisper transcriber expects.
                audio_input = gr.Audio(
                    sources=["microphone", "upload"],
                    type="filepath",
                    label="Upload or Record Audio"
                )

            audio_button = gr.Button("Process Audio")

            with gr.Row():
                transcription_output = gr.Textbox(label="Transcription", lines=3)
                audio_corrected_text = gr.Textbox(label="Corrected Text", lines=3)

            with gr.Row():
                audio_speech_output = gr.Audio(label="Speech Output")

            audio_button.click(
                fn=process_audio_input,
                inputs=[audio_input],
                outputs=[transcription_output, audio_corrected_text, audio_speech_output]
            )

        gr.Markdown("## How to use")
        gr.Markdown("""
        1. **Text Input Tab**: Enter text, click 'Process Text'. The system will correct grammar and generate speech.
        2. **Audio Input Tab**: Upload an audio file or record using your microphone, then click 'Process Audio'. 
           The system will transcribe your speech, correct grammar, and generate improved speech.
        """)

    return demo

# Launch the interface
# Built at import time so a module-level `demo` exists — presumably for
# hosting platforms that import the module and look for `demo`; verify.
demo = create_gradio_interface()

if __name__ == "__main__":
    # Start the Gradio server only when run as a script.
    demo.launch()