Jerich committed · verified
Commit 4763326 · 1 Parent(s): fff0177

Modified the code to expose API endpoints

Files changed (1): app.py (+112 −343)
app.py CHANGED
@@ -1,6 +1,5 @@
 import os
 import torch
-import gradio as gr
 import numpy as np
 import soundfile as sf
 from transformers import (
@@ -12,14 +11,14 @@ from transformers import (
     WhisperProcessor,
     WhisperForConditionalGeneration
 )
-from typing import Optional, Tuple, Dict, List
+from typing import Optional, Tuple, Dict
+from fastapi import FastAPI, UploadFile, File, Form, HTTPException
+from fastapi.responses import JSONResponse
+import tempfile
 
-class TalklasTranslator:
-    """
-    Speech-to-Speech translation pipeline for Philippine languages.
-    Uses MMS/Whisper for STT, NLLB for MT, and MMS for TTS.
-    """
-
+app = FastAPI(title="Talklas API")
+
+class TalklasTranslator:
     LANGUAGE_MAPPING = {
         "English": "eng",
         "Tagalog": "tgl",
@@ -38,381 +37,151 @@ class TalklasTranslator:
         "pag": "pag_Latn"
     }
 
-    def __init__(
-        self,
-        source_lang: str = "eng",
-        target_lang: str = "tgl",
-        device: Optional[str] = None
-    ):
-        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
+    def __init__(self, source_lang: str = "eng", target_lang: str = "tgl"):
+        self.device = "cuda" if torch.cuda.is_available() else "cpu"
         self.source_lang = source_lang
         self.target_lang = target_lang
         self.sample_rate = 16000
-
-        print(f"Initializing Talklas Translator on {self.device}")
-
-        # Initialize models
         self._initialize_stt_model()
         self._initialize_mt_model()
         self._initialize_tts_model()
 
     def _initialize_stt_model(self):
-        """Initialize speech-to-text model with fallback to Whisper"""
         try:
-            print("Loading STT model...")
-            try:
-                # Try loading MMS model first
-                self.stt_processor = AutoProcessor.from_pretrained("facebook/mms-1b-all")
-                self.stt_model = AutoModelForCTC.from_pretrained("facebook/mms-1b-all")
-
-                # Set language if available
-                if self.source_lang in self.stt_processor.tokenizer.vocab.keys():
-                    self.stt_processor.tokenizer.set_target_lang(self.source_lang)
-                    self.stt_model.load_adapter(self.source_lang)
-                    print(f"Loaded MMS STT model for {self.source_lang}")
-                else:
-                    print(f"Language {self.source_lang} not in MMS, using default")
-
-            except Exception as mms_error:
-                print(f"MMS loading failed: {mms_error}")
-                # Fallback to Whisper
-                print("Loading Whisper as fallback...")
-                self.stt_processor = WhisperProcessor.from_pretrained("openai/whisper-small")
-                self.stt_model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-small")
-                print("Loaded Whisper STT model")
-
+            self.stt_processor = WhisperProcessor.from_pretrained("openai/whisper-small")
+            self.stt_model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-small")
             self.stt_model.to(self.device)
-
         except Exception as e:
-            print(f"STT model initialization failed: {e}")
-            raise RuntimeError("Could not initialize STT model")
+            raise RuntimeError(f"STT model initialization failed: {e}")
 
     def _initialize_mt_model(self):
-        """Initialize machine translation model"""
         try:
-            print("Loading NLLB Translation model...")
             self.mt_model = AutoModelForSeq2SeqLM.from_pretrained("facebook/nllb-200-distilled-600M")
             self.mt_tokenizer = AutoTokenizer.from_pretrained("facebook/nllb-200-distilled-600M")
             self.mt_model.to(self.device)
-            print("NLLB Translation model loaded")
         except Exception as e:
-            print(f"MT model initialization failed: {e}")
-            raise
+            raise RuntimeError(f"MT model initialization failed: {e}")
 
     def _initialize_tts_model(self):
-        """Initialize text-to-speech model"""
         try:
-            print("Loading TTS model...")
-            try:
-                self.tts_model = VitsModel.from_pretrained(f"facebook/mms-tts-{self.target_lang}")
-                self.tts_tokenizer = AutoTokenizer.from_pretrained(f"facebook/mms-tts-{self.target_lang}")
-                print(f"Loaded TTS model for {self.target_lang}")
-            except Exception as tts_error:
-                print(f"Target language TTS failed: {tts_error}")
-                print("Falling back to English TTS")
-                self.tts_model = VitsModel.from_pretrained("facebook/mms-tts-eng")
-                self.tts_tokenizer = AutoTokenizer.from_pretrained("facebook/mms-tts-eng")
-
+            self.tts_model = VitsModel.from_pretrained(f"facebook/mms-tts-{self.target_lang}")
+            self.tts_tokenizer = AutoTokenizer.from_pretrained(f"facebook/mms-tts-{self.target_lang}")
+            self.tts_model.to(self.device)
+        except Exception:
+            self.tts_model = VitsModel.from_pretrained("facebook/mms-tts-eng")
+            self.tts_tokenizer = AutoTokenizer.from_pretrained("facebook/mms-tts-eng")
             self.tts_model.to(self.device)
-        except Exception as e:
-            print(f"TTS model initialization failed: {e}")
-            raise
 
-    def update_languages(self, source_lang: str, target_lang: str) -> str:
-        """Update languages and reinitialize models if needed"""
-        if source_lang == self.source_lang and target_lang == self.target_lang:
-            return "Languages already set"
-
+    def update_languages(self, source_lang: str, target_lang: str):
         self.source_lang = source_lang
         self.target_lang = target_lang
-
-        # Only reinitialize models that depend on language
         self._initialize_stt_model()
         self._initialize_tts_model()
-
         return f"Languages updated to {source_lang} → {target_lang}"
 
     def speech_to_text(self, audio_path: str) -> str:
-        """Convert speech to text using loaded STT model"""
-        try:
-            waveform, sample_rate = sf.read(audio_path)
-
-            if sample_rate != 16000:
-                import librosa
-                waveform = librosa.resample(waveform, orig_sr=sample_rate, target_sr=16000)
-
-            inputs = self.stt_processor(
-                waveform,
-                sampling_rate=16000,
-                return_tensors="pt"
-            ).to(self.device)
-
-            with torch.no_grad():
-                if isinstance(self.stt_model, WhisperForConditionalGeneration):  # Whisper model
-                    generated_ids = self.stt_model.generate(**inputs)
-                    transcription = self.stt_processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
-                else:  # MMS model (Wav2Vec2ForCTC)
-                    logits = self.stt_model(**inputs).logits
-                    predicted_ids = torch.argmax(logits, dim=-1)
-                    transcription = self.stt_processor.batch_decode(predicted_ids)[0]
-
-            return transcription
-
-        except Exception as e:
-            print(f"Speech recognition failed: {e}")
-            raise RuntimeError("Speech recognition failed")
+        waveform, sample_rate = sf.read(audio_path)
+        if sample_rate != 16000:
+            import librosa
+            waveform = librosa.resample(waveform, orig_sr=sample_rate, target_sr=16000)
+        inputs = self.stt_processor(waveform, sampling_rate=16000, return_tensors="pt").to(self.device)
+        with torch.no_grad():
+            generated_ids = self.stt_model.generate(**inputs)
+        transcription = self.stt_processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
+        return transcription
 
     def translate_text(self, text: str) -> str:
-        """Translate text using NLLB model"""
-        try:
-            source_code = self.NLLB_LANGUAGE_CODES[self.source_lang]
-            target_code = self.NLLB_LANGUAGE_CODES[self.target_lang]
-
-            self.mt_tokenizer.src_lang = source_code
-            inputs = self.mt_tokenizer(text, return_tensors="pt").to(self.device)
-
-            with torch.no_grad():
-                generated_tokens = self.mt_model.generate(
-                    **inputs,
-                    forced_bos_token_id=self.mt_tokenizer.convert_tokens_to_ids(target_code),
-                    max_length=448
-                )
-
-            return self.mt_tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0]
-
-        except Exception as e:
-            print(f"Translation failed: {e}")
-            raise RuntimeError("Text translation failed")
+        source_code = self.NLLB_LANGUAGE_CODES[self.source_lang]
+        target_code = self.NLLB_LANGUAGE_CODES[self.target_lang]
+        self.mt_tokenizer.src_lang = source_code
+        inputs = self.mt_tokenizer(text, return_tensors="pt").to(self.device)
+        with torch.no_grad():
+            generated_tokens = self.mt_model.generate(
+                **inputs,
+                forced_bos_token_id=self.mt_tokenizer.convert_tokens_to_ids(target_code),
+                max_length=448
+            )
+        return self.mt_tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0]
 
     def text_to_speech(self, text: str) -> Tuple[int, np.ndarray]:
-        """Convert text to speech"""
-        try:
-            inputs = self.tts_tokenizer(text, return_tensors="pt").to(self.device)
-
-            with torch.no_grad():
-                output = self.tts_model(**inputs)
-
-            speech = output.waveform.cpu().numpy().squeeze()
-            speech = (speech * 32767).astype(np.int16)
-
-            return self.tts_model.config.sampling_rate, speech
-
-        except Exception as e:
-            print(f"Speech synthesis failed: {e}")
-            raise RuntimeError("Speech synthesis failed")
+        inputs = self.tts_tokenizer(text, return_tensors="pt").to(self.device)
+        with torch.no_grad():
+            output = self.tts_model(**inputs)
+        speech = output.waveform.cpu().numpy().squeeze()
+        speech = (speech * 32767).astype(np.int16)
+        return self.tts_model.config.sampling_rate, speech
 
     def translate_speech(self, audio_path: str) -> Dict:
-        """Full speech-to-speech translation"""
-        try:
-            source_text = self.speech_to_text(audio_path)
-            translated_text = self.translate_text(source_text)
-            sample_rate, audio = self.text_to_speech(translated_text)
-
-            return {
-                "source_text": source_text,
-                "translated_text": translated_text,
-                "output_audio": (sample_rate, audio),
-                "performance": "Translation successful"
-            }
-        except Exception as e:
-            return {
-                "source_text": "Error",
-                "translated_text": "Error",
-                "output_audio": (16000, np.zeros(1000, dtype=np.int16)),
-                "performance": f"Error: {str(e)}"
-            }
+        source_text = self.speech_to_text(audio_path)
+        translated_text = self.translate_text(source_text)
+        sample_rate, audio = self.text_to_speech(translated_text)
+        return {
+            "source_text": source_text,
+            "translated_text": translated_text,
+            "output_audio": (sample_rate, audio.tolist()),  # Convert numpy array to list for JSON
+            "performance": "Translation successful"
+        }
 
     def translate_text_only(self, text: str) -> Dict:
-        """Text-to-speech translation"""
-        try:
-            translated_text = self.translate_text(text)
-            sample_rate, audio = self.text_to_speech(translated_text)
-
-            return {
-                "source_text": text,
-                "translated_text": translated_text,
-                "output_audio": (sample_rate, audio),
-                "performance": "Translation successful"
-            }
-        except Exception as e:
-            return {
-                "source_text": text,
-                "translated_text": "Error",
-                "output_audio": (16000, np.zeros(1000, dtype=np.int16)),
-                "performance": f"Error: {str(e)}"
-            }
-
-class TranslatorSingleton:
-    _instance = None
-
-    @classmethod
-    def get_instance(cls):
-        if cls._instance is None:
-            cls._instance = TalklasTranslator()
-        return cls._instance
-
-def process_audio(audio_path, source_lang, target_lang):
-    """Process audio through the full translation pipeline"""
-    # Validate input
-    if not audio_path:
-        return None, "No audio provided", "No translation available", "Please provide audio input"
-
-    # Update languages
-    source_code = TalklasTranslator.LANGUAGE_MAPPING[source_lang]
-    target_code = TalklasTranslator.LANGUAGE_MAPPING[target_lang]
-
-    translator = TranslatorSingleton.get_instance()
-    status = translator.update_languages(source_code, target_code)
-
-    # Process the audio
-    results = translator.translate_speech(audio_path)
-
-    return results["output_audio"], results["source_text"], results["translated_text"], results["performance"]
-
-def process_text(text, source_lang, target_lang):
-    """Process text through the translation pipeline"""
-    # Validate input
-    if not text:
-        return None, "No text provided", "No translation available", "Please provide text input"
-
-    # Update languages
-    source_code = TalklasTranslator.LANGUAGE_MAPPING[source_lang]
-    target_code = TalklasTranslator.LANGUAGE_MAPPING[target_lang]
-
-    translator = TranslatorSingleton.get_instance()
-    status = translator.update_languages(source_code, target_code)
-
-    # Process the text
-    results = translator.translate_text_only(text)
-
-    return results["output_audio"], results["source_text"], results["translated_text"], results["performance"]
-
-def create_gradio_interface():
-    """Create and launch Gradio interface"""
-    # Define language options
-    languages = list(TalklasTranslator.LANGUAGE_MAPPING.keys())
-
-    # Define the interface
-    demo = gr.Blocks(title="Talklas - Speech & Text Translation")
-
-    with demo:
-        gr.Markdown("# Talklas: Speech-to-Speech Translation System")
-        gr.Markdown("### Translate between Philippine Languages and English")
-
-        with gr.Row():
-            with gr.Column():
-                source_lang = gr.Dropdown(
-                    choices=languages,
-                    value="English",
-                    label="Source Language"
-                )
-
-                target_lang = gr.Dropdown(
-                    choices=languages,
-                    value="Tagalog",
-                    label="Target Language"
-                )
-
-                language_status = gr.Textbox(label="Language Status")
-                update_btn = gr.Button("Update Languages")
-
-        with gr.Tabs():
-            with gr.TabItem("Audio Input"):
-                with gr.Row():
-                    with gr.Column():
-                        gr.Markdown("### Audio Input")
-                        audio_input = gr.Audio(
-                            type="filepath",
-                            label="Upload Audio File"
-                        )
-                        audio_translate_btn = gr.Button("Translate Audio", variant="primary")
-
-                    with gr.Column():
-                        gr.Markdown("### Output")
-                        audio_output = gr.Audio(
-                            label="Translated Speech",
-                            type="numpy",
-                            autoplay=True
-                        )
-
-            with gr.TabItem("Text Input"):
-                with gr.Row():
-                    with gr.Column():
-                        gr.Markdown("### Text Input")
-                        text_input = gr.Textbox(
-                            label="Enter text to translate",
-                            lines=3
-                        )
-                        text_translate_btn = gr.Button("Translate Text", variant="primary")
-
-                    with gr.Column():
-                        gr.Markdown("### Output")
-                        text_output = gr.Audio(
-                            label="Translated Speech",
-                            type="numpy",
-                            autoplay=True
-                        )
-
-        with gr.Row():
-            with gr.Column():
-                source_text = gr.Textbox(label="Source Text")
-                translated_text = gr.Textbox(label="Translated Text")
-                performance_info = gr.Textbox(label="Performance Metrics")
-
-        # Set up events
-        update_btn.click(
-            lambda source_lang, target_lang: TranslatorSingleton.get_instance().update_languages(
-                TalklasTranslator.LANGUAGE_MAPPING[source_lang],
-                TalklasTranslator.LANGUAGE_MAPPING[target_lang]
-            ),
-            inputs=[source_lang, target_lang],
-            outputs=[language_status]
-        )
-
-        # Audio translate button click
-        audio_translate_btn.click(
-            process_audio,
-            inputs=[audio_input, source_lang, target_lang],
-            outputs=[audio_output, source_text, translated_text, performance_info]
-        ).then(
-            None,
-            None,
-            None,
-            js="""() => {
-                const audioElements = document.querySelectorAll('audio');
-                if (audioElements.length > 0) {
-                    const lastAudio = audioElements[audioElements.length - 1];
-                    lastAudio.play().catch(error => {
-                        console.warn('Autoplay failed:', error);
-                        alert('Audio may require user interaction to play');
-                    });
-                }
-            }"""
-        )
-
-        # Text translate button click
-        text_translate_btn.click(
-            process_text,
-            inputs=[text_input, source_lang, target_lang],
-            outputs=[text_output, source_text, translated_text, performance_info]
-        ).then(
-            None,
-            None,
-            None,
-            js="""() => {
-                const audioElements = document.querySelectorAll('audio');
-                if (audioElements.length > 0) {
-                    const lastAudio = audioElements[audioElements.length - 1];
-                    lastAudio.play().catch(error => {
-                        console.warn('Autoplay failed:', error);
-                        alert('Audio may require user interaction to play');
-                    });
-                }
-            }"""
-        )
+        translated_text = self.translate_text(text)
+        sample_rate, audio = self.text_to_speech(translated_text)
+        return {
+            "source_text": text,
+            "translated_text": translated_text,
+            "output_audio": (sample_rate, audio.tolist()),
+            "performance": "Translation successful"
+        }
+
+# Singleton instance
+translator = TalklasTranslator()
+
+# API Endpoints
+@app.post("/update-languages")
+async def update_languages(source_lang: str = Form(...), target_lang: str = Form(...)):
+    if source_lang not in TalklasTranslator.LANGUAGE_MAPPING or target_lang not in TalklasTranslator.LANGUAGE_MAPPING:
+        raise HTTPException(status_code=400, detail="Invalid language selected")
+    status = translator.update_languages(
+        TalklasTranslator.LANGUAGE_MAPPING[source_lang],
+        TalklasTranslator.LANGUAGE_MAPPING[target_lang]
+    )
+    return {"status": status}
+
+@app.post("/translate-audio")
+async def translate_audio(audio: UploadFile = File(...), source_lang: str = Form(...), target_lang: str = Form(...)):
+    if not audio:
+        raise HTTPException(status_code=400, detail="No audio file provided")
+    if source_lang not in TalklasTranslator.LANGUAGE_MAPPING or target_lang not in TalklasTranslator.LANGUAGE_MAPPING:
+        raise HTTPException(status_code=400, detail="Invalid language selected")
+
+    # Save uploaded audio temporarily
+    with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as temp_file:
+        temp_file.write(await audio.read())
+        temp_path = temp_file.name
+
+    try:
+        translator.update_languages(
+            TalklasTranslator.LANGUAGE_MAPPING[source_lang],
+            TalklasTranslator.LANGUAGE_MAPPING[target_lang]
+        )
+        result = translator.translate_speech(temp_path)
+        return JSONResponse(content=result)
+    finally:
+        os.unlink(temp_path)  # Clean up temporary file
 
-    return demo
+@app.post("/translate-text")
+async def translate_text(text: str = Form(...), source_lang: str = Form(...), target_lang: str = Form(...)):
+    if not text:
+        raise HTTPException(status_code=400, detail="No text provided")
+    if source_lang not in TalklasTranslator.LANGUAGE_MAPPING or target_lang not in TalklasTranslator.LANGUAGE_MAPPING:
+        raise HTTPException(status_code=400, detail="Invalid language selected")
+
+    translator.update_languages(
+        TalklasTranslator.LANGUAGE_MAPPING[source_lang],
+        TalklasTranslator.LANGUAGE_MAPPING[target_lang]
+    )
+    result = translator.translate_text_only(text)
+    return JSONResponse(content=result)
 
 if __name__ == "__main__":
-    demo = create_gradio_interface()
-    demo.launch(share=True, debug=True)
+    import uvicorn
+    uvicorn.run(app, host="0.0.0.0", port=8000)
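The new endpoints accept multipart form data and return the translation payload as JSON. A minimal client sketch, assuming the server is running locally on port 8000 (as in the uvicorn.run call above) and that the requests package is installed; the sample sentence and the file name input.wav are placeholders:

import requests

BASE_URL = "http://localhost:8000"  # assumed local uvicorn instance

# Text translation: form fields mirror the /translate-text endpoint
resp = requests.post(
    f"{BASE_URL}/translate-text",
    data={"text": "Good morning", "source_lang": "English", "target_lang": "Tagalog"},
)
resp.raise_for_status()
result = resp.json()
print(result["source_text"], "->", result["translated_text"])

# Speech translation: upload a WAV file to /translate-audio
with open("input.wav", "rb") as f:  # placeholder file name
    resp = requests.post(
        f"{BASE_URL}/translate-audio",
        files={"audio": ("input.wav", f, "audio/wav")},
        data={"source_lang": "Tagalog", "target_lang": "English"},
    )
print(resp.json()["translated_text"])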
 
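Note that translate_speech and translate_text_only serialize output_audio as (sample_rate, samples), with the int16 samples converted to a plain list via .tolist(). Continuing from the client sketch above, a caller can rebuild a playable WAV using the same numpy/soundfile pair the app already imports; the output file name is a placeholder:

import numpy as np
import soundfile as sf

# "result" is the parsed JSON payload from /translate-text or /translate-audio
sample_rate, samples = result["output_audio"]
audio = np.asarray(samples, dtype=np.int16)  # samples arrive as a plain Python list
sf.write("translated.wav", audio, sample_rate)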