Jerich committed on
Commit
2b35bda
·
verified ·
1 Parent(s): 49a0a85

Fix TTS language issue by dynamically updating model in /translate-audio endpoint

Browse files
Files changed (1) hide show
  1. app.py +59 -5
app.py CHANGED
@@ -39,6 +39,7 @@ model_status = {
39
  "tts": "not_loaded"
40
  }
41
  error_message = None
 
42
 
43
  # Model instances
44
  stt_processor = None
@@ -179,7 +180,7 @@ def load_models_task():
179
  error_message = f"MT model loading failed: {str(e)}"
180
  return
181
 
182
- # Load TTS model (default to Tagalog, will be updated by /update-languages)
183
  logger.info("Starting to load TTS model...")
184
  from transformers import VitsModel, AutoTokenizer
185
 
@@ -201,6 +202,7 @@ def load_models_task():
201
  tts_model.to(device)
202
  logger.info("Fallback TTS model loaded successfully")
203
  model_status["tts"] = "loaded (fallback)"
 
204
  except Exception as e2:
205
  logger.error(f"Failed to load fallback TTS model: {str(e2)}")
206
  model_status["tts"] = "failed"
@@ -259,7 +261,7 @@ async def health_check():
259
 
260
  @app.post("/update-languages")
261
  async def update_languages(source_lang: str = Form(...), target_lang: str = Form(...)):
262
- global stt_processor, stt_model, tts_model, tts_tokenizer
263
 
264
  if source_lang not in LANGUAGE_MAPPING or target_lang not in LANGUAGE_MAPPING:
265
  raise HTTPException(status_code=400, detail="Invalid language selected")
@@ -314,6 +316,7 @@ async def update_languages(source_lang: str = Form(...), target_lang: str = Form
314
  tts_model = VitsModel.from_pretrained(f"facebook/mms-tts-{target_code}")
315
  tts_tokenizer = AutoTokenizer.from_pretrained(f"facebook/mms-tts-{target_code}")
316
  tts_model.to(device)
 
317
  logger.info(f"TTS model updated to {target_code}")
318
  model_status["tts"] = "loaded"
319
  except Exception as e:
@@ -323,6 +326,7 @@ async def update_languages(source_lang: str = Form(...), target_lang: str = Form
323
  tts_model = VitsModel.from_pretrained("facebook/mms-tts-eng")
324
  tts_tokenizer = AutoTokenizer.from_pretrained("facebook/mms-tts-eng")
325
  tts_model.to(device)
 
326
  logger.info("Fallback TTS model loaded successfully")
327
  model_status["tts"] = "loaded (fallback)"
328
  except Exception as e2:
@@ -337,7 +341,7 @@ async def update_languages(source_lang: str = Form(...), target_lang: str = Form
337
  @app.post("/translate-text")
338
  async def translate_text(text: str = Form(...), source_lang: str = Form(...), target_lang: str = Form(...)):
339
  """Endpoint to translate text and convert to speech"""
340
- global mt_model, mt_tokenizer, tts_model, tts_tokenizer
341
 
342
  if not text:
343
  raise HTTPException(status_code=400, detail="No text provided")
@@ -373,6 +377,31 @@ async def translate_text(text: str = Form(...), source_lang: str = Form(...), ta
373
  else:
374
  logger.warning("MT model not loaded, skipping translation")
375
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
376
  # Convert translated text to speech
377
  output_audio_url = None
378
  if model_status["tts"].startswith("loaded") and tts_model is not None and tts_tokenizer is not None:
@@ -409,7 +438,7 @@ async def translate_text(text: str = Form(...), source_lang: str = Form(...), ta
409
  @app.post("/translate-audio")
410
  async def translate_audio(audio: UploadFile = File(...), source_lang: str = Form(...), target_lang: str = Form(...)):
411
  """Endpoint to transcribe, translate, and convert audio to speech"""
412
- global stt_processor, stt_model, mt_model, mt_tokenizer, tts_model, tts_tokenizer
413
 
414
  if not audio:
415
  raise HTTPException(status_code=400, detail="No audio file provided")
@@ -506,7 +535,32 @@ async def translate_audio(audio: UploadFile = File(...), source_lang: str = Form
506
  else:
507
  logger.warning("MT model not loaded, skipping translation")
508
 
509
- # Step 5: Convert translated text to speech (TTS)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
510
  if model_status["tts"].startswith("loaded") and tts_model is not None and tts_tokenizer is not None:
511
  try:
512
  inputs = tts_tokenizer(translated_text, return_tensors="pt").to(device)
 
39
  "tts": "not_loaded"
40
  }
41
  error_message = None
42
+ current_tts_language = "tgl" # Track the current TTS language
43
 
44
  # Model instances
45
  stt_processor = None
 
180
  error_message = f"MT model loading failed: {str(e)}"
181
  return
182
 
183
+ # Load TTS model (default to Tagalog, will be updated dynamically)
184
  logger.info("Starting to load TTS model...")
185
  from transformers import VitsModel, AutoTokenizer
186
 
 
202
  tts_model.to(device)
203
  logger.info("Fallback TTS model loaded successfully")
204
  model_status["tts"] = "loaded (fallback)"
205
+ current_tts_language = "eng"
206
  except Exception as e2:
207
  logger.error(f"Failed to load fallback TTS model: {str(e2)}")
208
  model_status["tts"] = "failed"
 
261
 
262
  @app.post("/update-languages")
263
  async def update_languages(source_lang: str = Form(...), target_lang: str = Form(...)):
264
+ global stt_processor, stt_model, tts_model, tts_tokenizer, current_tts_language
265
 
266
  if source_lang not in LANGUAGE_MAPPING or target_lang not in LANGUAGE_MAPPING:
267
  raise HTTPException(status_code=400, detail="Invalid language selected")
 
316
  tts_model = VitsModel.from_pretrained(f"facebook/mms-tts-{target_code}")
317
  tts_tokenizer = AutoTokenizer.from_pretrained(f"facebook/mms-tts-{target_code}")
318
  tts_model.to(device)
319
+ current_tts_language = target_code
320
  logger.info(f"TTS model updated to {target_code}")
321
  model_status["tts"] = "loaded"
322
  except Exception as e:
 
326
  tts_model = VitsModel.from_pretrained("facebook/mms-tts-eng")
327
  tts_tokenizer = AutoTokenizer.from_pretrained("facebook/mms-tts-eng")
328
  tts_model.to(device)
329
+ current_tts_language = "eng"
330
  logger.info("Fallback TTS model loaded successfully")
331
  model_status["tts"] = "loaded (fallback)"
332
  except Exception as e2:
 
341
  @app.post("/translate-text")
342
  async def translate_text(text: str = Form(...), source_lang: str = Form(...), target_lang: str = Form(...)):
343
  """Endpoint to translate text and convert to speech"""
344
+ global mt_model, mt_tokenizer, tts_model, tts_tokenizer, current_tts_language
345
 
346
  if not text:
347
  raise HTTPException(status_code=400, detail="No text provided")
 
377
  else:
378
  logger.warning("MT model not loaded, skipping translation")
379
 
380
+ # Update TTS model if the target language doesn't match the current TTS language
381
+ if current_tts_language != target_code:
382
+ try:
383
+ logger.info(f"Updating TTS model for {target_code}...")
384
+ from transformers import VitsModel, AutoTokenizer
385
+ tts_model = VitsModel.from_pretrained(f"facebook/mms-tts-{target_code}")
386
+ tts_tokenizer = AutoTokenizer.from_pretrained(f"facebook/mms-tts-{target_code}")
387
+ tts_model.to(device)
388
+ current_tts_language = target_code
389
+ logger.info(f"TTS model updated to {target_code}")
390
+ model_status["tts"] = "loaded"
391
+ except Exception as e:
392
+ logger.error(f"Failed to load TTS model for {target_code}: {str(e)}")
393
+ try:
394
+ logger.info("Falling back to MMS-TTS English model...")
395
+ tts_model = VitsModel.from_pretrained("facebook/mms-tts-eng")
396
+ tts_tokenizer = AutoTokenizer.from_pretrained("facebook/mms-tts-eng")
397
+ tts_model.to(device)
398
+ current_tts_language = "eng"
399
+ logger.info("Fallback TTS model loaded successfully")
400
+ model_status["tts"] = "loaded (fallback)"
401
+ except Exception as e2:
402
+ logger.error(f"Failed to load fallback TTS model: {str(e2)}")
403
+ model_status["tts"] = "failed"
404
+
405
  # Convert translated text to speech
406
  output_audio_url = None
407
  if model_status["tts"].startswith("loaded") and tts_model is not None and tts_tokenizer is not None:
 
438
  @app.post("/translate-audio")
439
  async def translate_audio(audio: UploadFile = File(...), source_lang: str = Form(...), target_lang: str = Form(...)):
440
  """Endpoint to transcribe, translate, and convert audio to speech"""
441
+ global stt_processor, stt_model, mt_model, mt_tokenizer, tts_model, tts_tokenizer, current_tts_language
442
 
443
  if not audio:
444
  raise HTTPException(status_code=400, detail="No audio file provided")
 
535
  else:
536
  logger.warning("MT model not loaded, skipping translation")
537
 
538
+ # Step 5: Update TTS model if the target language doesn't match the current TTS language
539
+ if current_tts_language != target_code:
540
+ try:
541
+ logger.info(f"Updating TTS model for {target_code}...")
542
+ from transformers import VitsModel, AutoTokenizer
543
+ tts_model = VitsModel.from_pretrained(f"facebook/mms-tts-{target_code}")
544
+ tts_tokenizer = AutoTokenizer.from_pretrained(f"facebook/mms-tts-{target_code}")
545
+ tts_model.to(device)
546
+ current_tts_language = target_code
547
+ logger.info(f"TTS model updated to {target_code}")
548
+ model_status["tts"] = "loaded"
549
+ except Exception as e:
550
+ logger.error(f"Failed to load TTS model for {target_code}: {str(e)}")
551
+ try:
552
+ logger.info("Falling back to MMS-TTS English model...")
553
+ tts_model = VitsModel.from_pretrained("facebook/mms-tts-eng")
554
+ tts_tokenizer = AutoTokenizer.from_pretrained("facebook/mms-tts-eng")
555
+ tts_model.to(device)
556
+ current_tts_language = "eng"
557
+ logger.info("Fallback TTS model loaded successfully")
558
+ model_status["tts"] = "loaded (fallback)"
559
+ except Exception as e2:
560
+ logger.error(f"Failed to load fallback TTS model: {str(e2)}")
561
+ model_status["tts"] = "failed"
562
+
563
+ # Step 6: Convert translated text to speech (TTS)
564
  if model_status["tts"].startswith("loaded") and tts_model is not None and tts_tokenizer is not None:
565
  try:
566
  inputs = tts_tokenizer(translated_text, return_tensors="pt").to(device)