Jerich committed on
Commit 224fa8d · verified · 1 Parent(s): e978acd

Expose the Hugging Face Code as an API

Files changed (1): app.py +91 -209
app.py CHANGED
@@ -1,8 +1,10 @@
 import os
 import torch
-import gradio as gr
 import numpy as np
 import soundfile as sf
+from fastapi import FastAPI, File, UploadFile, HTTPException
+from fastapi.responses import JSONResponse
+from pydantic import BaseModel
 from transformers import (
     AutoModelForSeq2SeqLM,
     AutoTokenizer,
@@ -13,13 +15,11 @@ from transformers import (
     WhisperForConditionalGeneration
 )
 from typing import Optional, Tuple, Dict, List
+import base64
+import io
 
+# Your existing TalklasTranslator class (unchanged)
 class TalklasTranslator:
-    """
-    Speech-to-Speech translation pipeline for Philippine languages.
-    Uses MMS/Whisper for STT, NLLB for MT, and MMS for TTS.
-    """
-
     LANGUAGE_MAPPING = {
         "English": "eng",
         "Tagalog": "tgl",
@@ -50,45 +50,34 @@ class TalklasTranslator:
         self.sample_rate = 16000
 
         print(f"Initializing Talklas Translator on {self.device}")
-
-        # Initialize models
         self._initialize_stt_model()
         self._initialize_mt_model()
         self._initialize_tts_model()
 
     def _initialize_stt_model(self):
-        """Initialize speech-to-text model with fallback to Whisper"""
         try:
             print("Loading STT model...")
             try:
-                # Try loading MMS model first
                 self.stt_processor = AutoProcessor.from_pretrained("facebook/mms-1b-all")
                 self.stt_model = AutoModelForCTC.from_pretrained("facebook/mms-1b-all")
-
-                # Set language if available
                 if self.source_lang in self.stt_processor.tokenizer.vocab.keys():
                     self.stt_processor.tokenizer.set_target_lang(self.source_lang)
                     self.stt_model.load_adapter(self.source_lang)
                     print(f"Loaded MMS STT model for {self.source_lang}")
                 else:
                     print(f"Language {self.source_lang} not in MMS, using default")
-
             except Exception as mms_error:
                 print(f"MMS loading failed: {mms_error}")
-                # Fallback to Whisper
                 print("Loading Whisper as fallback...")
                 self.stt_processor = WhisperProcessor.from_pretrained("openai/whisper-small")
                 self.stt_model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-small")
                 print("Loaded Whisper STT model")
-
             self.stt_model.to(self.device)
-
         except Exception as e:
             print(f"STT model initialization failed: {e}")
             raise RuntimeError("Could not initialize STT model")
 
     def _initialize_mt_model(self):
-        """Initialize machine translation model"""
         try:
             print("Loading NLLB Translation model...")
             self.mt_model = AutoModelForSeq2SeqLM.from_pretrained("facebook/nllb-200-distilled-600M")
@@ -100,7 +89,6 @@ class TalklasTranslator:
             raise
 
     def _initialize_tts_model(self):
-        """Initialize text-to-speech model"""
         try:
             print("Loading TTS model...")
             try:
@@ -112,102 +100,78 @@ class TalklasTranslator:
                 print("Falling back to English TTS")
                 self.tts_model = VitsModel.from_pretrained("facebook/mms-tts-eng")
                 self.tts_tokenizer = AutoTokenizer.from_pretrained("facebook/mms-tts-eng")
-
             self.tts_model.to(self.device)
         except Exception as e:
             print(f"TTS model initialization failed: {e}")
             raise
 
     def update_languages(self, source_lang: str, target_lang: str) -> str:
-        """Update languages and reinitialize models if needed"""
         if source_lang == self.source_lang and target_lang == self.target_lang:
             return "Languages already set"
-
         self.source_lang = source_lang
         self.target_lang = target_lang
-
-        # Only reinitialize models that depend on language
         self._initialize_stt_model()
         self._initialize_tts_model()
-
         return f"Languages updated to {source_lang} → {target_lang}"
 
     def speech_to_text(self, audio_path: str) -> str:
-        """Convert speech to text using loaded STT model"""
         try:
             waveform, sample_rate = sf.read(audio_path)
-
             if sample_rate != 16000:
                 import librosa
                 waveform = librosa.resample(waveform, orig_sr=sample_rate, target_sr=16000)
-
             inputs = self.stt_processor(
                 waveform,
                 sampling_rate=16000,
                 return_tensors="pt"
             ).to(self.device)
-
             with torch.no_grad():
-                if isinstance(self.stt_model, WhisperForConditionalGeneration):  # Whisper model
+                if isinstance(self.stt_model, WhisperForConditionalGeneration):
                     generated_ids = self.stt_model.generate(**inputs)
                     transcription = self.stt_processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
-                else:  # MMS model (Wav2Vec2ForCTC)
+                else:
                     logits = self.stt_model(**inputs).logits
                     predicted_ids = torch.argmax(logits, dim=-1)
                     transcription = self.stt_processor.batch_decode(predicted_ids)[0]
-
             return transcription
-
         except Exception as e:
             print(f"Speech recognition failed: {e}")
             raise RuntimeError("Speech recognition failed")
 
     def translate_text(self, text: str) -> str:
-        """Translate text using NLLB model"""
         try:
             source_code = self.NLLB_LANGUAGE_CODES[self.source_lang]
             target_code = self.NLLB_LANGUAGE_CODES[self.target_lang]
-
             self.mt_tokenizer.src_lang = source_code
             inputs = self.mt_tokenizer(text, return_tensors="pt").to(self.device)
-
             with torch.no_grad():
                 generated_tokens = self.mt_model.generate(
                     **inputs,
                     forced_bos_token_id=self.mt_tokenizer.convert_tokens_to_ids(target_code),
                     max_length=448
                 )
-
             return self.mt_tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0]
-
         except Exception as e:
             print(f"Translation failed: {e}")
             raise RuntimeError("Text translation failed")
 
     def text_to_speech(self, text: str) -> Tuple[int, np.ndarray]:
-        """Convert text to speech"""
         try:
             inputs = self.tts_tokenizer(text, return_tensors="pt").to(self.device)
-
             with torch.no_grad():
                 output = self.tts_model(**inputs)
-
             speech = output.waveform.cpu().numpy().squeeze()
             speech = (speech * 32767).astype(np.int16)
-
             return self.tts_model.config.sampling_rate, speech
-
         except Exception as e:
             print(f"Speech synthesis failed: {e}")
             raise RuntimeError("Speech synthesis failed")
 
     def translate_speech(self, audio_path: str) -> Dict:
-        """Full speech-to-speech translation"""
         try:
             source_text = self.speech_to_text(audio_path)
             translated_text = self.translate_text(source_text)
             sample_rate, audio = self.text_to_speech(translated_text)
-
             return {
                 "source_text": source_text,
                 "translated_text": translated_text,
@@ -223,11 +187,9 @@ class TalklasTranslator:
             }
 
     def translate_text_only(self, text: str) -> Dict:
-        """Text-to-speech translation"""
         try:
             translated_text = self.translate_text(text)
             sample_rate, audio = self.text_to_speech(translated_text)
-
             return {
                 "source_text": text,
                 "translated_text": translated_text,
@@ -251,168 +213,88 @@ class TranslatorSingleton:
             cls._instance = TalklasTranslator()
         return cls._instance
 
-def process_audio(audio_path, source_lang, target_lang):
-    """Process audio through the full translation pipeline"""
-    # Validate input
-    if not audio_path:
-        return None, "No audio provided", "No translation available", "Please provide audio input"
-
-    # Update languages
-    source_code = TalklasTranslator.LANGUAGE_MAPPING[source_lang]
-    target_code = TalklasTranslator.LANGUAGE_MAPPING[target_lang]
-
-    translator = TranslatorSingleton.get_instance()
-    status = translator.update_languages(source_code, target_code)
-
-    # Process the audio
-    results = translator.translate_speech(audio_path)
-
-    return results["output_audio"], results["source_text"], results["translated_text"], results["performance"]
-
-def process_text(text, source_lang, target_lang):
-    """Process text through the translation pipeline"""
-    # Validate input
-    if not text:
-        return None, "No text provided", "No translation available", "Please provide text input"
-
-    # Update languages
-    source_code = TalklasTranslator.LANGUAGE_MAPPING[source_lang]
-    target_code = TalklasTranslator.LANGUAGE_MAPPING[target_lang]
-
-    translator = TranslatorSingleton.get_instance()
-    status = translator.update_languages(source_code, target_code)
-
-    # Process the text
-    results = translator.translate_text_only(text)
-
-    return results["output_audio"], results["source_text"], results["translated_text"], results["performance"]
-
-def create_gradio_interface():
-    """Create and launch Gradio interface"""
-    # Define language options
-    languages = list(TalklasTranslator.LANGUAGE_MAPPING.keys())
-
-    # Define the interface
-    demo = gr.Blocks(title="Talklas - Speech & Text Translation")
-
-    with demo:
-        gr.Markdown("# Talklas: Speech-to-Speech Translation System")
-        gr.Markdown("### Translate between Philippine Languages and English")
-
-        with gr.Row():
-            with gr.Column():
-                source_lang = gr.Dropdown(
-                    choices=languages,
-                    value="English",
-                    label="Source Language"
-                )
-
-                target_lang = gr.Dropdown(
-                    choices=languages,
-                    value="Tagalog",
-                    label="Target Language"
-                )
-
-                language_status = gr.Textbox(label="Language Status")
-                update_btn = gr.Button("Update Languages")
-
-        with gr.Tabs():
-            with gr.TabItem("Audio Input"):
-                with gr.Row():
-                    with gr.Column():
-                        gr.Markdown("### Audio Input")
-                        audio_input = gr.Audio(
-                            type="filepath",
-                            label="Upload Audio File"
-                        )
-                        audio_translate_btn = gr.Button("Translate Audio", variant="primary")
-
-                    with gr.Column():
-                        gr.Markdown("### Output")
-                        audio_output = gr.Audio(
-                            label="Translated Speech",
-                            type="numpy",
-                            autoplay=True
-                        )
-
-            with gr.TabItem("Text Input"):
-                with gr.Row():
-                    with gr.Column():
-                        gr.Markdown("### Text Input")
-                        text_input = gr.Textbox(
-                            label="Enter text to translate",
-                            lines=3
-                        )
-                        text_translate_btn = gr.Button("Translate Text", variant="primary")
-
-                    with gr.Column():
-                        gr.Markdown("### Output")
-                        text_output = gr.Audio(
-                            label="Translated Speech",
-                            type="numpy",
-                            autoplay=True
-                        )
-
-        with gr.Row():
-            with gr.Column():
-                source_text = gr.Textbox(label="Source Text")
-                translated_text = gr.Textbox(label="Translated Text")
-                performance_info = gr.Textbox(label="Performance Metrics")
-
-        # Set up events
-        update_btn.click(
-            lambda source_lang, target_lang: TranslatorSingleton.get_instance().update_languages(
-                TalklasTranslator.LANGUAGE_MAPPING[source_lang],
-                TalklasTranslator.LANGUAGE_MAPPING[target_lang]
-            ),
-            inputs=[source_lang, target_lang],
-            outputs=[language_status]
-        )
-
-        # Audio translate button click
-        audio_translate_btn.click(
-            process_audio,
-            inputs=[audio_input, source_lang, target_lang],
-            outputs=[audio_output, source_text, translated_text, performance_info]
-        ).then(
-            None,
-            None,
-            None,
-            js="""() => {
-                const audioElements = document.querySelectorAll('audio');
-                if (audioElements.length > 0) {
-                    const lastAudio = audioElements[audioElements.length - 1];
-                    lastAudio.play().catch(error => {
-                        console.warn('Autoplay failed:', error);
-                        alert('Audio may require user interaction to play');
-                    });
-                }
-            }"""
-        )
-
-        # Text translate button click
-        text_translate_btn.click(
-            process_text,
-            inputs=[text_input, source_lang, target_lang],
-            outputs=[text_output, source_text, translated_text, performance_info]
-        ).then(
-            None,
-            None,
-            None,
-            js="""() => {
-                const audioElements = document.querySelectorAll('audio');
-                if (audioElements.length > 0) {
-                    const lastAudio = audioElements[audioElements.length - 1];
-                    lastAudio.play().catch(error => {
-                        console.warn('Autoplay failed:', error);
-                        alert('Audio may require user interaction to play');
-                    });
-                }
-            }"""
-        )
-
-    return demo
+# FastAPI application
+app = FastAPI(title="Talklas API", description="Speech-to-Speech Translation API")
+
+class TranslationRequest(BaseModel):
+    source_lang: str
+    target_lang: str
+    text: Optional[str] = None
+
+@app.post("/translate/audio")
+async def translate_audio(file: UploadFile = File(...), source_lang: str = "English", target_lang: str = "Tagalog"):
+    try:
+        # Validate languages
+        if source_lang not in TalklasTranslator.LANGUAGE_MAPPING or target_lang not in TalklasTranslator.LANGUAGE_MAPPING:
+            raise HTTPException(status_code=400, detail="Invalid language selection")
+
+        # Save uploaded audio file temporarily
+        audio_path = f"temp_{file.filename}"
+        with open(audio_path, "wb") as f:
+            f.write(await file.read())
+
+        # Update languages
+        source_code = TalklasTranslator.LANGUAGE_MAPPING[source_lang]
+        target_code = TalklasTranslator.LANGUAGE_MAPPING[target_lang]
+        translator = TranslatorSingleton.get_instance()
+        translator.update_languages(source_code, target_code)
+
+        # Process the audio
+        results = translator.translate_speech(audio_path)
+
+        # Clean up temporary file
+        os.remove(audio_path)
+
+        # Convert audio to base64 for response
+        sample_rate, audio = results["output_audio"]
+        buffer = io.BytesIO()
+        sf.write(buffer, audio, sample_rate, format="wav")
+        audio_base64 = base64.b64encode(buffer.getvalue()).decode("utf-8")
+
+        return JSONResponse(content={
+            "source_text": results["source_text"],
+            "translated_text": results["translated_text"],
+            "audio_base64": audio_base64,
+            "sample_rate": sample_rate,
+            "performance": results["performance"]
+        })
+    except Exception as e:
+        raise HTTPException(status_code=500, detail=f"Translation failed: {str(e)}")
+
+@app.post("/translate/text")
+async def translate_text(request: TranslationRequest):
+    try:
+        # Validate input
+        if not request.text:
+            raise HTTPException(status_code=400, detail="Text input is required")
+        if request.source_lang not in TalklasTranslator.LANGUAGE_MAPPING or request.target_lang not in TalklasTranslator.LANGUAGE_MAPPING:
+            raise HTTPException(status_code=400, detail="Invalid language selection")
+
+        # Update languages
+        source_code = TalklasTranslator.LANGUAGE_MAPPING[request.source_lang]
+        target_code = TalklasTranslator.LANGUAGE_MAPPING[request.target_lang]
+        translator = TranslatorSingleton.get_instance()
+        translator.update_languages(source_code, target_code)
+
+        # Process the text
+        results = translator.translate_text_only(request.text)
+
+        # Convert audio to base64 for response
+        sample_rate, audio = results["output_audio"]
+        buffer = io.BytesIO()
+        sf.write(buffer, audio, sample_rate, format="wav")
+        audio_base64 = base64.b64encode(buffer.getvalue()).decode("utf-8")
+
+        return JSONResponse(content={
+            "source_text": results["source_text"],
+            "translated_text": results["translated_text"],
+            "audio_base64": audio_base64,
+            "sample_rate": sample_rate,
+            "performance": results["performance"]
+        })
+    except Exception as e:
+        raise HTTPException(status_code=500, detail=f"Translation failed: {str(e)}")
 
 if __name__ == "__main__":
-    demo = create_gradio_interface()
-    demo.launch(share=True, debug=True)
+    import uvicorn
+    uvicorn.run(app, host="0.0.0.0", port=8000)
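
A quick way to exercise the new text endpoint from a client, as a minimal sketch: it assumes the server is running locally via the uvicorn call above (http://localhost:8000) and that the requests package is installed (it is not a dependency of app.py itself). The JSON fields match TranslationRequest and the JSONResponse built in the handler.

import base64
import requests

# Hypothetical client for POST /translate/text.
resp = requests.post(
    "http://localhost:8000/translate/text",
    json={"source_lang": "English", "target_lang": "Tagalog", "text": "Good morning"},
    timeout=300,  # generous: update_languages may reload STT/TTS models
)
resp.raise_for_status()
payload = resp.json()
print(payload["source_text"], "->", payload["translated_text"])

# The synthesized speech comes back base64-encoded WAV; decode and save it.
with open("translated.wav", "wb") as f:
    f.write(base64.b64decode(payload["audio_base64"]))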
 
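The audio endpoint takes a multipart file upload, with source_lang and target_lang passed as query parameters (FastAPI treats the plain string arguments that way). A matching sketch, again against a local server; the input file name here is just a placeholder:

import base64
import requests

# Hypothetical client for POST /translate/audio.
with open("input.wav", "rb") as f:
    resp = requests.post(
        "http://localhost:8000/translate/audio",
        params={"source_lang": "Tagalog", "target_lang": "English"},
        files={"file": ("input.wav", f, "audio/wav")},
        timeout=300,
    )
resp.raise_for_status()
payload = resp.json()
print(payload["translated_text"])
with open("output.wav", "wb") as f:
    f.write(base64.b64decode(payload["audio_base64"]))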