Jerich committed
Commit f53ba4b · verified · 1 Parent(s): 716acc0

Create app.py

Files changed (1)
  1. app.py +418 -0
app.py ADDED
@@ -0,0 +1,418 @@
import os
import torch
import gradio as gr
import numpy as np
import soundfile as sf
from transformers import (
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    VitsModel,
    AutoProcessor,
    AutoModelForCTC,
    WhisperProcessor,
    WhisperForConditionalGeneration
)
from typing import Optional, Tuple, Dict, List

class TalklasTranslator:
    """
    Speech-to-Speech translation pipeline for Philippine languages.
    Uses MMS/Whisper for STT, NLLB for MT, and MMS for TTS.
    """

    LANGUAGE_MAPPING = {
        "English": "eng",
        "Tagalog": "tgl",
        "Cebuano": "ceb",
        "Ilocano": "ilo",
        "Waray": "war",
        "Pangasinan": "pag"
    }

    NLLB_LANGUAGE_CODES = {
        "eng": "eng_Latn",
        "tgl": "tgl_Latn",
        "ceb": "ceb_Latn",
        "ilo": "ilo_Latn",
        "war": "war_Latn",
        "pag": "pag_Latn"
    }

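    # Note: Gradio passes display names (e.g. "Tagalog"); LANGUAGE_MAPPING turns them
    # into ISO 639-3 codes, and NLLB_LANGUAGE_CODES adds the script tag NLLB expects.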
    def __init__(
        self,
        source_lang: str = "eng",
        target_lang: str = "tgl",
        device: Optional[str] = None
    ):
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        self.source_lang = source_lang
        self.target_lang = target_lang
        self.sample_rate = 16000

        print(f"Initializing Talklas Translator on {self.device}")

        # Initialize models
        self._initialize_stt_model()
        self._initialize_mt_model()
        self._initialize_tts_model()

    def _initialize_stt_model(self):
        """Initialize speech-to-text model with fallback to Whisper"""
        try:
            print("Loading STT model...")
            try:
                # Try loading MMS model first
                self.stt_processor = AutoProcessor.from_pretrained("facebook/mms-1b-all")
                self.stt_model = AutoModelForCTC.from_pretrained("facebook/mms-1b-all")

                # Set language if available
                if self.source_lang in self.stt_processor.tokenizer.vocab.keys():
                    self.stt_processor.tokenizer.set_target_lang(self.source_lang)
                    self.stt_model.load_adapter(self.source_lang)
                    print(f"Loaded MMS STT model for {self.source_lang}")
                else:
                    print(f"Language {self.source_lang} not in MMS, using default")

            except Exception as mms_error:
                print(f"MMS loading failed: {mms_error}")
                # Fallback to Whisper
                print("Loading Whisper as fallback...")
                self.stt_processor = WhisperProcessor.from_pretrained("openai/whisper-small")
                self.stt_model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-small")
                print("Loaded Whisper STT model")

            self.stt_model.to(self.device)

        except Exception as e:
            print(f"STT model initialization failed: {e}")
            raise RuntimeError("Could not initialize STT model")

    def _initialize_mt_model(self):
        """Initialize machine translation model"""
        try:
            print("Loading NLLB Translation model...")
            self.mt_model = AutoModelForSeq2SeqLM.from_pretrained("facebook/nllb-200-distilled-600M")
            self.mt_tokenizer = AutoTokenizer.from_pretrained("facebook/nllb-200-distilled-600M")
            self.mt_model.to(self.device)
            print("NLLB Translation model loaded")
        except Exception as e:
            print(f"MT model initialization failed: {e}")
            raise

    def _initialize_tts_model(self):
        """Initialize text-to-speech model"""
        try:
            print("Loading TTS model...")
            try:
                self.tts_model = VitsModel.from_pretrained(f"facebook/mms-tts-{self.target_lang}")
                self.tts_tokenizer = AutoTokenizer.from_pretrained(f"facebook/mms-tts-{self.target_lang}")
                print(f"Loaded TTS model for {self.target_lang}")
            except Exception as tts_error:
                print(f"Target language TTS failed: {tts_error}")
                print("Falling back to English TTS")
                self.tts_model = VitsModel.from_pretrained("facebook/mms-tts-eng")
                self.tts_tokenizer = AutoTokenizer.from_pretrained("facebook/mms-tts-eng")

            self.tts_model.to(self.device)
        except Exception as e:
            print(f"TTS model initialization failed: {e}")
            raise

    def update_languages(self, source_lang: str, target_lang: str) -> str:
        """Update languages and reinitialize models if needed"""
        if source_lang == self.source_lang and target_lang == self.target_lang:
            return "Languages already set"

        self.source_lang = source_lang
        self.target_lang = target_lang

        # Only reinitialize models that depend on language
        self._initialize_stt_model()
        self._initialize_tts_model()

        return f"Languages updated to {source_lang} → {target_lang}"

    def speech_to_text(self, audio_path: str) -> str:
        """Convert speech to text using the loaded STT model"""
        try:
            waveform, sample_rate = sf.read(audio_path)

            # soundfile returns (frames, channels) for stereo input; mix down to mono
            if waveform.ndim > 1:
                waveform = waveform.mean(axis=1)

            if sample_rate != 16000:
                import librosa
                waveform = librosa.resample(waveform, orig_sr=sample_rate, target_sr=16000)

            inputs = self.stt_processor(
                waveform,
                sampling_rate=16000,
                return_tensors="pt"
            ).to(self.device)

            with torch.no_grad():
                if isinstance(self.stt_model, WhisperForConditionalGeneration):  # Whisper model
                    generated_ids = self.stt_model.generate(**inputs)
                    transcription = self.stt_processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
                else:  # MMS model (Wav2Vec2ForCTC)
                    logits = self.stt_model(**inputs).logits
                    predicted_ids = torch.argmax(logits, dim=-1)
                    transcription = self.stt_processor.batch_decode(predicted_ids)[0]

            return transcription

        except Exception as e:
            print(f"Speech recognition failed: {e}")
            raise RuntimeError("Speech recognition failed")

    def translate_text(self, text: str) -> str:
        """Translate text using NLLB model"""
        try:
            source_code = self.NLLB_LANGUAGE_CODES[self.source_lang]
            target_code = self.NLLB_LANGUAGE_CODES[self.target_lang]

            self.mt_tokenizer.src_lang = source_code
            inputs = self.mt_tokenizer(text, return_tensors="pt").to(self.device)

            with torch.no_grad():
                generated_tokens = self.mt_model.generate(
                    **inputs,
                    forced_bos_token_id=self.mt_tokenizer.convert_tokens_to_ids(target_code),
                    max_length=448
                )

            return self.mt_tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0]

        except Exception as e:
            print(f"Translation failed: {e}")
            raise RuntimeError("Text translation failed")

    def text_to_speech(self, text: str) -> Tuple[int, np.ndarray]:
        """Convert text to speech"""
        try:
            inputs = self.tts_tokenizer(text, return_tensors="pt").to(self.device)

            with torch.no_grad():
                output = self.tts_model(**inputs)

            # Scale the float waveform to 16-bit PCM for Gradio playback
            speech = output.waveform.cpu().numpy().squeeze()
            speech = (speech * 32767).astype(np.int16)

            return self.tts_model.config.sampling_rate, speech

        except Exception as e:
            print(f"Speech synthesis failed: {e}")
            raise RuntimeError("Speech synthesis failed")

    def translate_speech(self, audio_path: str) -> Dict:
        """Full speech-to-speech translation"""
        try:
            source_text = self.speech_to_text(audio_path)
            translated_text = self.translate_text(source_text)
            sample_rate, audio = self.text_to_speech(translated_text)

            return {
                "source_text": source_text,
                "translated_text": translated_text,
                "output_audio": (sample_rate, audio),
                "performance": "Translation successful"
            }
        except Exception as e:
            return {
                "source_text": "Error",
                "translated_text": "Error",
                "output_audio": (16000, np.zeros(1000, dtype=np.int16)),
                "performance": f"Error: {str(e)}"
            }

    def translate_text_only(self, text: str) -> Dict:
        """Translate text and synthesize speech (skips the STT stage)"""
        try:
            translated_text = self.translate_text(text)
            sample_rate, audio = self.text_to_speech(translated_text)

            return {
                "source_text": text,
                "translated_text": translated_text,
                "output_audio": (sample_rate, audio),
                "performance": "Translation successful"
            }
        except Exception as e:
            return {
                "source_text": text,
                "translated_text": "Error",
                "output_audio": (16000, np.zeros(1000, dtype=np.int16)),
                "performance": f"Error: {str(e)}"
            }

class TranslatorSingleton:
    _instance = None

    @classmethod
    def get_instance(cls):
        if cls._instance is None:
            cls._instance = TalklasTranslator()
        return cls._instance

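# The singleton keeps a single TalklasTranslator (and its loaded models) in memory,
# so the Gradio callbacks below reuse it instead of reloading models on every request.
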
def process_audio(audio_path, source_lang, target_lang):
    """Process audio through the full translation pipeline"""
    # Validate input
    if not audio_path:
        return None, "No audio provided", "No translation available", "Please provide audio input"

    # Update languages
    source_code = TalklasTranslator.LANGUAGE_MAPPING[source_lang]
    target_code = TalklasTranslator.LANGUAGE_MAPPING[target_lang]

    translator = TranslatorSingleton.get_instance()
    status = translator.update_languages(source_code, target_code)

    # Process the audio
    results = translator.translate_speech(audio_path)

    return results["output_audio"], results["source_text"], results["translated_text"], results["performance"]

def process_text(text, source_lang, target_lang):
    """Process text through the translation pipeline"""
    # Validate input
    if not text:
        return None, "No text provided", "No translation available", "Please provide text input"

    # Update languages
    source_code = TalklasTranslator.LANGUAGE_MAPPING[source_lang]
    target_code = TalklasTranslator.LANGUAGE_MAPPING[target_lang]

    translator = TranslatorSingleton.get_instance()
    status = translator.update_languages(source_code, target_code)

    # Process the text
    results = translator.translate_text_only(text)

    return results["output_audio"], results["source_text"], results["translated_text"], results["performance"]

def create_gradio_interface():
    """Create the Gradio interface (launched from the __main__ block)"""
    # Define language options
    languages = list(TalklasTranslator.LANGUAGE_MAPPING.keys())

    # Define the interface
    demo = gr.Blocks(title="Talklas - Speech & Text Translation")

    with demo:
        gr.Markdown("# Talklas: Speech-to-Speech Translation System")
        gr.Markdown("### Translate between Philippine Languages and English")

        with gr.Row():
            with gr.Column():
                source_lang = gr.Dropdown(
                    choices=languages,
                    value="English",
                    label="Source Language"
                )

                target_lang = gr.Dropdown(
                    choices=languages,
                    value="Tagalog",
                    label="Target Language"
                )

                language_status = gr.Textbox(label="Language Status")
                update_btn = gr.Button("Update Languages")

        with gr.Tabs():
            with gr.TabItem("Audio Input"):
                with gr.Row():
                    with gr.Column():
                        gr.Markdown("### Audio Input")
                        audio_input = gr.Audio(
                            type="filepath",
                            label="Upload Audio File"
                        )
                        audio_translate_btn = gr.Button("Translate Audio", variant="primary")

                    with gr.Column():
                        gr.Markdown("### Output")
                        audio_output = gr.Audio(
                            label="Translated Speech",
                            type="numpy",
                            autoplay=True
                        )

            with gr.TabItem("Text Input"):
                with gr.Row():
                    with gr.Column():
                        gr.Markdown("### Text Input")
                        text_input = gr.Textbox(
                            label="Enter text to translate",
                            lines=3
                        )
                        text_translate_btn = gr.Button("Translate Text", variant="primary")

                    with gr.Column():
                        gr.Markdown("### Output")
                        text_output = gr.Audio(
                            label="Translated Speech",
                            type="numpy",
                            autoplay=True
                        )

        with gr.Row():
            with gr.Column():
                source_text = gr.Textbox(label="Source Text")
                translated_text = gr.Textbox(label="Translated Text")
                performance_info = gr.Textbox(label="Performance Metrics")

        # Set up events
        update_btn.click(
            lambda source_lang, target_lang: TranslatorSingleton.get_instance().update_languages(
                TalklasTranslator.LANGUAGE_MAPPING[source_lang],
                TalklasTranslator.LANGUAGE_MAPPING[target_lang]
            ),
            inputs=[source_lang, target_lang],
            outputs=[language_status]
        )

        # Audio translate button click
        audio_translate_btn.click(
            process_audio,
            inputs=[audio_input, source_lang, target_lang],
            outputs=[audio_output, source_text, translated_text, performance_info]
        ).then(
            None,
            None,
            None,
            js="""() => {
                const audioElements = document.querySelectorAll('audio');
                if (audioElements.length > 0) {
                    const lastAudio = audioElements[audioElements.length - 1];
                    lastAudio.play().catch(error => {
                        console.warn('Autoplay failed:', error);
                        alert('Audio may require user interaction to play');
                    });
                }
            }"""
        )

        # Text translate button click
        text_translate_btn.click(
            process_text,
            inputs=[text_input, source_lang, target_lang],
            outputs=[text_output, source_text, translated_text, performance_info]
        ).then(
            None,
            None,
            None,
            js="""() => {
                const audioElements = document.querySelectorAll('audio');
                if (audioElements.length > 0) {
                    const lastAudio = audioElements[audioElements.length - 1];
                    lastAudio.play().catch(error => {
                        console.warn('Autoplay failed:', error);
                        alert('Audio may require user interaction to play');
                    });
                }
            }"""
        )

    return demo

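# Hugging Face Spaces expects the app on port 7860; binding to 0.0.0.0 makes it
# reachable from outside the container.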
if __name__ == "__main__":
    demo = create_gradio_interface()
    demo.launch(server_name="0.0.0.0", server_port=7860)