Spaces:
Paused
Paused
Fix TTS language issue by dynamically updating model in /translate-audio endpoint
Browse files
app.py
CHANGED
@@ -39,6 +39,7 @@ model_status = {
|
|
39 |
"tts": "not_loaded"
|
40 |
}
|
41 |
error_message = None
|
|
|
42 |
|
43 |
# Model instances
|
44 |
stt_processor = None
|
@@ -179,7 +180,7 @@ def load_models_task():
|
|
179 |
error_message = f"MT model loading failed: {str(e)}"
|
180 |
return
|
181 |
|
182 |
-
# Load TTS model (default to Tagalog, will be updated)
|
183 |
logger.info("Starting to load TTS model...")
|
184 |
from transformers import VitsModel, AutoTokenizer
|
185 |
|
@@ -201,6 +202,7 @@ def load_models_task():
|
|
201 |
tts_model.to(device)
|
202 |
logger.info("Fallback TTS model loaded successfully")
|
203 |
model_status["tts"] = "loaded (fallback)"
|
|
|
204 |
except Exception as e2:
|
205 |
logger.error(f"Failed to load fallback TTS model: {str(e2)}")
|
206 |
model_status["tts"] = "failed"
|
@@ -259,7 +261,7 @@ async def health_check():
|
|
259 |
|
260 |
@app.post("/update-languages")
|
261 |
async def update_languages(source_lang: str = Form(...), target_lang: str = Form(...)):
|
262 |
-
global stt_processor, stt_model, tts_model, tts_tokenizer
|
263 |
|
264 |
if source_lang not in LANGUAGE_MAPPING or target_lang not in LANGUAGE_MAPPING:
|
265 |
raise HTTPException(status_code=400, detail="Invalid language selected")
|
@@ -314,6 +316,7 @@ async def update_languages(source_lang: str = Form(...), target_lang: str = Form
|
|
314 |
tts_model = VitsModel.from_pretrained(f"facebook/mms-tts-{target_code}")
|
315 |
tts_tokenizer = AutoTokenizer.from_pretrained(f"facebook/mms-tts-{target_code}")
|
316 |
tts_model.to(device)
|
|
|
317 |
logger.info(f"TTS model updated to {target_code}")
|
318 |
model_status["tts"] = "loaded"
|
319 |
except Exception as e:
|
@@ -323,6 +326,7 @@ async def update_languages(source_lang: str = Form(...), target_lang: str = Form
|
|
323 |
tts_model = VitsModel.from_pretrained("facebook/mms-tts-eng")
|
324 |
tts_tokenizer = AutoTokenizer.from_pretrained("facebook/mms-tts-eng")
|
325 |
tts_model.to(device)
|
|
|
326 |
logger.info("Fallback TTS model loaded successfully")
|
327 |
model_status["tts"] = "loaded (fallback)"
|
328 |
except Exception as e2:
|
@@ -337,7 +341,7 @@ async def update_languages(source_lang: str = Form(...), target_lang: str = Form
|
|
337 |
@app.post("/translate-text")
|
338 |
async def translate_text(text: str = Form(...), source_lang: str = Form(...), target_lang: str = Form(...)):
|
339 |
"""Endpoint to translate text and convert to speech"""
|
340 |
-
global mt_model, mt_tokenizer, tts_model, tts_tokenizer
|
341 |
|
342 |
if not text:
|
343 |
raise HTTPException(status_code=400, detail="No text provided")
|
@@ -373,6 +377,31 @@ async def translate_text(text: str = Form(...), source_lang: str = Form(...), ta
|
|
373 |
else:
|
374 |
logger.warning("MT model not loaded, skipping translation")
|
375 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
376 |
# Convert translated text to speech
|
377 |
output_audio_url = None
|
378 |
if model_status["tts"].startswith("loaded") and tts_model is not None and tts_tokenizer is not None:
|
@@ -409,7 +438,7 @@ async def translate_text(text: str = Form(...), source_lang: str = Form(...), ta
|
|
409 |
@app.post("/translate-audio")
|
410 |
async def translate_audio(audio: UploadFile = File(...), source_lang: str = Form(...), target_lang: str = Form(...)):
|
411 |
"""Endpoint to transcribe, translate, and convert audio to speech"""
|
412 |
-
global stt_processor, stt_model, mt_model, mt_tokenizer, tts_model, tts_tokenizer
|
413 |
|
414 |
if not audio:
|
415 |
raise HTTPException(status_code=400, detail="No audio file provided")
|
@@ -506,7 +535,32 @@ async def translate_audio(audio: UploadFile = File(...), source_lang: str = Form
|
|
506 |
else:
|
507 |
logger.warning("MT model not loaded, skipping translation")
|
508 |
|
509 |
-
# Step 5: Convert translated text to speech (TTS)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
510 |
if model_status["tts"].startswith("loaded") and tts_model is not None and tts_tokenizer is not None:
|
511 |
try:
|
512 |
inputs = tts_tokenizer(translated_text, return_tensors="pt").to(device)
|
|
|
39 |
"tts": "not_loaded"
|
40 |
}
|
41 |
error_message = None
|
42 |
+
current_tts_language = "tgl" # Track the current TTS language
|
43 |
|
44 |
# Model instances
|
45 |
stt_processor = None
|
|
|
180 |
error_message = f"MT model loading failed: {str(e)}"
|
181 |
return
|
182 |
|
183 |
+
# Load TTS model (default to Tagalog, will be updated dynamically)
|
184 |
logger.info("Starting to load TTS model...")
|
185 |
from transformers import VitsModel, AutoTokenizer
|
186 |
|
|
|
202 |
tts_model.to(device)
|
203 |
logger.info("Fallback TTS model loaded successfully")
|
204 |
model_status["tts"] = "loaded (fallback)"
|
205 |
+
current_tts_language = "eng"
|
206 |
except Exception as e2:
|
207 |
logger.error(f"Failed to load fallback TTS model: {str(e2)}")
|
208 |
model_status["tts"] = "failed"
|
|
|
261 |
|
262 |
@app.post("/update-languages")
|
263 |
async def update_languages(source_lang: str = Form(...), target_lang: str = Form(...)):
|
264 |
+
global stt_processor, stt_model, tts_model, tts_tokenizer, current_tts_language
|
265 |
|
266 |
if source_lang not in LANGUAGE_MAPPING or target_lang not in LANGUAGE_MAPPING:
|
267 |
raise HTTPException(status_code=400, detail="Invalid language selected")
|
|
|
316 |
tts_model = VitsModel.from_pretrained(f"facebook/mms-tts-{target_code}")
|
317 |
tts_tokenizer = AutoTokenizer.from_pretrained(f"facebook/mms-tts-{target_code}")
|
318 |
tts_model.to(device)
|
319 |
+
current_tts_language = target_code
|
320 |
logger.info(f"TTS model updated to {target_code}")
|
321 |
model_status["tts"] = "loaded"
|
322 |
except Exception as e:
|
|
|
326 |
tts_model = VitsModel.from_pretrained("facebook/mms-tts-eng")
|
327 |
tts_tokenizer = AutoTokenizer.from_pretrained("facebook/mms-tts-eng")
|
328 |
tts_model.to(device)
|
329 |
+
current_tts_language = "eng"
|
330 |
logger.info("Fallback TTS model loaded successfully")
|
331 |
model_status["tts"] = "loaded (fallback)"
|
332 |
except Exception as e2:
|
|
|
341 |
@app.post("/translate-text")
|
342 |
async def translate_text(text: str = Form(...), source_lang: str = Form(...), target_lang: str = Form(...)):
|
343 |
"""Endpoint to translate text and convert to speech"""
|
344 |
+
global mt_model, mt_tokenizer, tts_model, tts_tokenizer, current_tts_language
|
345 |
|
346 |
if not text:
|
347 |
raise HTTPException(status_code=400, detail="No text provided")
|
|
|
377 |
else:
|
378 |
logger.warning("MT model not loaded, skipping translation")
|
379 |
|
380 |
+
# Update TTS model if the target language doesn't match the current TTS language
|
381 |
+
if current_tts_language != target_code:
|
382 |
+
try:
|
383 |
+
logger.info(f"Updating TTS model for {target_code}...")
|
384 |
+
from transformers import VitsModel, AutoTokenizer
|
385 |
+
tts_model = VitsModel.from_pretrained(f"facebook/mms-tts-{target_code}")
|
386 |
+
tts_tokenizer = AutoTokenizer.from_pretrained(f"facebook/mms-tts-{target_code}")
|
387 |
+
tts_model.to(device)
|
388 |
+
current_tts_language = target_code
|
389 |
+
logger.info(f"TTS model updated to {target_code}")
|
390 |
+
model_status["tts"] = "loaded"
|
391 |
+
except Exception as e:
|
392 |
+
logger.error(f"Failed to load TTS model for {target_code}: {str(e)}")
|
393 |
+
try:
|
394 |
+
logger.info("Falling back to MMS-TTS English model...")
|
395 |
+
tts_model = VitsModel.from_pretrained("facebook/mms-tts-eng")
|
396 |
+
tts_tokenizer = AutoTokenizer.from_pretrained("facebook/mms-tts-eng")
|
397 |
+
tts_model.to(device)
|
398 |
+
current_tts_language = "eng"
|
399 |
+
logger.info("Fallback TTS model loaded successfully")
|
400 |
+
model_status["tts"] = "loaded (fallback)"
|
401 |
+
except Exception as e2:
|
402 |
+
logger.error(f"Failed to load fallback TTS model: {str(e2)}")
|
403 |
+
model_status["tts"] = "failed"
|
404 |
+
|
405 |
# Convert translated text to speech
|
406 |
output_audio_url = None
|
407 |
if model_status["tts"].startswith("loaded") and tts_model is not None and tts_tokenizer is not None:
|
|
|
438 |
@app.post("/translate-audio")
|
439 |
async def translate_audio(audio: UploadFile = File(...), source_lang: str = Form(...), target_lang: str = Form(...)):
|
440 |
"""Endpoint to transcribe, translate, and convert audio to speech"""
|
441 |
+
global stt_processor, stt_model, mt_model, mt_tokenizer, tts_model, tts_tokenizer, current_tts_language
|
442 |
|
443 |
if not audio:
|
444 |
raise HTTPException(status_code=400, detail="No audio file provided")
|
|
|
535 |
else:
|
536 |
logger.warning("MT model not loaded, skipping translation")
|
537 |
|
538 |
+
# Step 5: Update TTS model if the target language doesn't match the current TTS language
|
539 |
+
if current_tts_language != target_code:
|
540 |
+
try:
|
541 |
+
logger.info(f"Updating TTS model for {target_code}...")
|
542 |
+
from transformers import VitsModel, AutoTokenizer
|
543 |
+
tts_model = VitsModel.from_pretrained(f"facebook/mms-tts-{target_code}")
|
544 |
+
tts_tokenizer = AutoTokenizer.from_pretrained(f"facebook/mms-tts-{target_code}")
|
545 |
+
tts_model.to(device)
|
546 |
+
current_tts_language = target_code
|
547 |
+
logger.info(f"TTS model updated to {target_code}")
|
548 |
+
model_status["tts"] = "loaded"
|
549 |
+
except Exception as e:
|
550 |
+
logger.error(f"Failed to load TTS model for {target_code}: {str(e)}")
|
551 |
+
try:
|
552 |
+
logger.info("Falling back to MMS-TTS English model...")
|
553 |
+
tts_model = VitsModel.from_pretrained("facebook/mms-tts-eng")
|
554 |
+
tts_tokenizer = AutoTokenizer.from_pretrained("facebook/mms-tts-eng")
|
555 |
+
tts_model.to(device)
|
556 |
+
current_tts_language = "eng"
|
557 |
+
logger.info("Fallback TTS model loaded successfully")
|
558 |
+
model_status["tts"] = "loaded (fallback)"
|
559 |
+
except Exception as e2:
|
560 |
+
logger.error(f"Failed to load fallback TTS model: {str(e2)}")
|
561 |
+
model_status["tts"] = "failed"
|
562 |
+
|
563 |
+
# Step 6: Convert translated text to speech (TTS)
|
564 |
if model_status["tts"].startswith("loaded") and tts_model is not None and tts_tokenizer is not None:
|
565 |
try:
|
566 |
inputs = tts_tokenizer(translated_text, return_tensors="pt").to(device)
|