Jerich committed on
Commit eb35d5b · verified · 1 Parent(s): e45bb49

Update app.py

Files changed (1)
  1. app.py +161 -84
app.py CHANGED
@@ -35,16 +35,20 @@ models_loaded = False
 loading_in_progress = False
 loading_thread = None
 model_status = {
-    "stt": "not_loaded",
     "mt": "not_loaded",
     "tts": "not_loaded"
 }
 error_message = None
 current_tts_language = "tgl"  # Track the current TTS language

 # Model instances
-stt_processor = None
-stt_model = None
 mt_model = None
 mt_tokenizer = None
 tts_model = None
@@ -152,38 +156,44 @@ def schedule_cleanup():
 # Function to load models in background
 def load_models_task():
     global models_loaded, loading_in_progress, model_status, error_message
-    global stt_processor, stt_model, mt_model, mt_tokenizer, tts_model, tts_tokenizer

     try:
         loading_in_progress = True

-        # Load STT model (MMS with fallback to Whisper)
-        logger.info("Starting to load STT model...")
         from transformers import AutoProcessor, AutoModelForCTC, WhisperProcessor, WhisperForConditionalGeneration

         try:
             logger.info("Loading MMS STT model...")
-            model_status["stt"] = "loading"
-            stt_processor = AutoProcessor.from_pretrained("facebook/mms-1b-all")
-            stt_model = AutoModelForCTC.from_pretrained("facebook/mms-1b-all")
-            device = "cuda" if torch.cuda.is_available() else "cpu"
-            stt_model.to(device)
             logger.info("MMS STT model loaded successfully")
-            model_status["stt"] = "loaded_mms"
         except Exception as mms_error:
             logger.error(f"Failed to load MMS STT model: {str(mms_error)}")
-            logger.info("Falling back to Whisper STT model...")
-            try:
-                stt_processor = WhisperProcessor.from_pretrained("openai/whisper-tiny")
-                stt_model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")
-                stt_model.to(device)
-                logger.info("Whisper STT model loaded successfully as fallback")
-                model_status["stt"] = "loaded_whisper"
-            except Exception as whisper_error:
-                logger.error(f"Failed to load Whisper STT model: {str(whisper_error)}")
-                model_status["stt"] = "failed"
-                error_message = f"STT model loading failed: MMS error: {str(mms_error)}, Whisper error: {str(whisper_error)}"
-                return

         # Load MT model
         logger.info("Starting to load MT model...")
@@ -203,7 +213,7 @@ def load_models_task():
             error_message = f"MT model loading failed: {str(e)}"
             return

-        # Load TTS model (default to Tagalog, will be updated dynamically)
         logger.info("Starting to load TTS model...")
         from transformers import VitsModel, AutoTokenizer

@@ -217,22 +227,25 @@ def load_models_task():
             model_status["tts"] = "loaded"
         except Exception as e:
             logger.error(f"Failed to load TTS model for Tagalog: {str(e)}")
-            # Fallback to English TTS if the target language fails
             try:
                 logger.info("Falling back to MMS-TTS English model...")
                 tts_model = VitsModel.from_pretrained("facebook/mms-tts-eng")
                 tts_tokenizer = AutoTokenizer.from_pretrained("facebook/mms-tts-eng")
                 tts_model.to(device)
                 logger.info("Fallback TTS model loaded successfully")
                 model_status["tts"] = "loaded (fallback)"
-                current_tts_language = "eng"
             except Exception as e2:
                 logger.error(f"Failed to load fallback TTS model: {str(e2)}")
                 model_status["tts"] = "failed"
                 error_message = f"TTS model loading failed: {str(e)} (fallback also failed: {str(e2)})"
                 return

-        models_loaded = True
         logger.info("Model loading completed successfully")

     except Exception as e:
@@ -284,7 +297,8 @@ async def health_check():

 @app.post("/update-languages")
 async def update_languages(source_lang: str = Form(...), target_lang: str = Form(...)):
-    global stt_processor, stt_model, tts_model, tts_tokenizer, current_tts_language

     if source_lang not in LANGUAGE_MAPPING or target_lang not in LANGUAGE_MAPPING:
         raise HTTPException(status_code=400, detail="Invalid language selected")
@@ -292,43 +306,78 @@ async def update_languages(source_lang: str = Form(...), target_lang: str = Form(
     source_code = LANGUAGE_MAPPING[source_lang]
     target_code = LANGUAGE_MAPPING[target_lang]

-    # Update the STT model based on the source language (MMS or Whisper)
     try:
-        logger.info("Updating STT model for source language...")
         from transformers import AutoProcessor, AutoModelForCTC, WhisperProcessor, WhisperForConditionalGeneration
         device = "cuda" if torch.cuda.is_available() else "cpu"

-        try:
-            logger.info(f"Loading MMS STT model for {source_code}...")
-            stt_processor = AutoProcessor.from_pretrained("facebook/mms-1b-all")
-            stt_model = AutoModelForCTC.from_pretrained("facebook/mms-1b-all")
-            stt_model.to(device)
-            # Set the target language for MMS
-            if source_code in stt_processor.tokenizer.vocab.keys():
-                stt_processor.tokenizer.set_target_lang(source_code)
-                stt_model.load_adapter(source_code)
-                logger.info(f"MMS STT model updated to {source_code}")
-                model_status["stt"] = "loaded_mms"
-            else:
-                logger.warning(f"Language {source_code} not supported by MMS, using default")
-                model_status["stt"] = "loaded_mms_default"
-        except Exception as mms_error:
-            logger.error(f"Failed to load MMS STT model for {source_code}: {str(mms_error)}")
-            logger.info("Falling back to Whisper STT model...")
             try:
-                stt_processor = WhisperProcessor.from_pretrained("openai/whisper-tiny")
-                stt_model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")
-                stt_model.to(device)
-                logger.info("Whisper STT model loaded successfully as fallback")
-                model_status["stt"] = "loaded_whisper"
             except Exception as whisper_error:
-                logger.error(f"Failed to load Whisper STT model: {str(whisper_error)}")
-                model_status["stt"] = "failed"
-                error_message = f"STT model update failed: MMS error: {str(mms_error)}, Whisper error: {str(whisper_error)}"
-                return {"status": "failed", "error": error_message}

     except Exception as e:
         logger.error(f"Error updating STT model: {str(e)}")
-        model_status["stt"] = "failed"
         error_message = f"STT model update failed: {str(e)}"
         return {"status": "failed", "error": error_message}

@@ -466,7 +515,8 @@ async def translate_text(text: str = Form(...), source_lang: str = Form(...), ta
 @app.post("/translate-audio")
 async def translate_audio(audio: UploadFile = File(...), source_lang: str = Form(...), target_lang: str = Form(...)):
     """Endpoint to transcribe, translate, and convert audio to speech"""
-    global stt_processor, stt_model, mt_model, mt_tokenizer, tts_model, tts_tokenizer, current_tts_language

     if not audio:
         raise HTTPException(status_code=400, detail="No audio file provided")
@@ -477,17 +527,37 @@ async def translate_audio(audio: UploadFile = File(...), source_lang: str = Form
     request_id = str(uuid.uuid4())

     # Check if STT model is loaded
-    if model_status["stt"] not in ["loaded_mms", "loaded_mms_default", "loaded_whisper"] or stt_processor is None or stt_model is None:
-        logger.warning("STT model not loaded, returning placeholder response")
-        return {
-            "request_id": request_id,
-            "status": "processing",
-            "message": "STT model not loaded yet. Please try again later.",
-            "source_text": "Transcription not available",
-            "translated_text": "Translation not available",
-            "is_inappropriate": False,
-            "output_audio": None
-        }

     # Save the uploaded audio to a temporary file
     with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as temp_file:
@@ -526,24 +596,30 @@ async def translate_audio(audio: UploadFile = File(...), source_lang: str = Form

         # Step 3: Transcribe the audio (STT)
         device = "cuda" if torch.cuda.is_available() else "cpu"
-        logger.info(f"Using device: {device}")
-        inputs = stt_processor(waveform.numpy(), sampling_rate=16000, return_tensors="pt").to(device)
-        logger.info("Audio processed, generating transcription...")

-        with torch.no_grad():
-            if model_status["stt"] == "loaded_whisper":
-                # Whisper model
-                generated_ids = stt_model.generate(**inputs, language="en")
-                transcription = stt_processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
-            else:
-                # MMS model
-                logits = stt_model(**inputs).logits
                 predicted_ids = torch.argmax(logits, dim=-1)
-                transcription = stt_processor.batch_decode(predicted_ids)[0]
         logger.info(f"Transcription completed: {transcription}")

         # Step 4: Translate the transcribed text (MT)
-        source_code = LANGUAGE_MAPPING[source_lang]
         target_code = LANGUAGE_MAPPING[target_lang]

         if model_status["mt"] == "loaded" and mt_model is not None and mt_tokenizer is not None:
@@ -618,7 +694,8 @@ async def translate_audio(audio: UploadFile = File(...), source_lang: str = Form
         except Exception as e:
             logger.error(f"Error during TTS conversion: {str(e)}")
             output_audio_url = None
-        return {
             "request_id": request_id,
             "status": "completed",
             "message": "Transcription, translation, and TTS completed (or partially completed).",
 
app.py after the change (new side of the same hunks):

@@ +35,20 @@
 loading_in_progress = False
 loading_thread = None
 model_status = {
+    "stt_mms": "not_loaded",
+    "stt_whisper": "not_loaded",
     "mt": "not_loaded",
     "tts": "not_loaded"
 }
 error_message = None
 current_tts_language = "tgl"  # Track the current TTS language
+current_stt_model = None  # Track which STT model is active ("mms" or "whisper")

 # Model instances
+stt_mms_processor = None
+stt_mms_model = None
+stt_whisper_processor = None
+stt_whisper_model = None
 mt_model = None
 mt_tokenizer = None
 tts_model = None
 
@@ +156,44 @@
 # Function to load models in background
 def load_models_task():
     global models_loaded, loading_in_progress, model_status, error_message
+    global stt_mms_processor, stt_mms_model, stt_whisper_processor, stt_whisper_model
+    global mt_model, mt_tokenizer, tts_model, tts_tokenizer

     try:
         loading_in_progress = True
+        device = "cuda" if torch.cuda.is_available() else "cpu"

+        # Load STT models (MMS and Whisper Small)
+        logger.info("Starting to load STT models...")
         from transformers import AutoProcessor, AutoModelForCTC, WhisperProcessor, WhisperForConditionalGeneration

+        # Load MMS STT model
         try:
             logger.info("Loading MMS STT model...")
+            model_status["stt_mms"] = "loading"
+            stt_mms_processor = AutoProcessor.from_pretrained("facebook/mms-1b-all")
+            stt_mms_model = AutoModelForCTC.from_pretrained("facebook/mms-1b-all")
+            stt_mms_model.to(device)
             logger.info("MMS STT model loaded successfully")
+            model_status["stt_mms"] = "loaded"
         except Exception as mms_error:
             logger.error(f"Failed to load MMS STT model: {str(mms_error)}")
+            model_status["stt_mms"] = "failed"
+            error_message = f"MMS STT model loading failed: {str(mms_error)}"
+
+        # Load Whisper Small STT model
+        try:
+            logger.info("Loading Whisper Small STT model...")
+            model_status["stt_whisper"] = "loading"
+            stt_whisper_processor = WhisperProcessor.from_pretrained("openai/whisper-small")
+            stt_whisper_model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-small")
+            stt_whisper_model.to(device)
+            logger.info("Whisper Small STT model loaded successfully")
+            model_status["stt_whisper"] = "loaded"
+        except Exception as whisper_error:
+            logger.error(f"Failed to load Whisper Small STT model: {str(whisper_error)}")
+            model_status["stt_whisper"] = "failed"
+            error_message = f"Whisper Small STT model loading failed: {str(whisper_error)}"

         # Load MT model
         logger.info("Starting to load MT model...")
 
@@ +213,7 @@
             error_message = f"MT model loading failed: {str(e)}"
             return

+        # Load TTS model (default to Tagalog)
         logger.info("Starting to load TTS model...")
         from transformers import VitsModel, AutoTokenizer

 
@@ +227,25 @@
             model_status["tts"] = "loaded"
         except Exception as e:
             logger.error(f"Failed to load TTS model for Tagalog: {str(e)}")
             try:
                 logger.info("Falling back to MMS-TTS English model...")
                 tts_model = VitsModel.from_pretrained("facebook/mms-tts-eng")
                 tts_tokenizer = AutoTokenizer.from_pretrained("facebook/mms-tts-eng")
                 tts_model.to(device)
+                current_tts_language = "eng"
                 logger.info("Fallback TTS model loaded successfully")
                 model_status["tts"] = "loaded (fallback)"
             except Exception as e2:
                 logger.error(f"Failed to load fallback TTS model: {str(e2)}")
                 model_status["tts"] = "failed"
                 error_message = f"TTS model loading failed: {str(e)} (fallback also failed: {str(e2)})"
                 return

+        # Check if critical models are loaded
+        stt_loaded = model_status["stt_mms"] == "loaded" or model_status["stt_whisper"] == "loaded"
+        mt_loaded = model_status["mt"] == "loaded"
+        tts_loaded = model_status["tts"].startswith("loaded")
+        models_loaded = stt_loaded and mt_loaded and tts_loaded
         logger.info("Model loading completed successfully")

     except Exception as e:
 
@@ +297,8 @@

 @app.post("/update-languages")
 async def update_languages(source_lang: str = Form(...), target_lang: str = Form(...)):
+    global stt_mms_processor, stt_mms_model, stt_whisper_processor, stt_whisper_model
+    global tts_model, tts_tokenizer, current_tts_language, current_stt_model

     if source_lang not in LANGUAGE_MAPPING or target_lang not in LANGUAGE_MAPPING:
         raise HTTPException(status_code=400, detail="Invalid language selected")
 
@@ +306,78 @@
     source_code = LANGUAGE_MAPPING[source_lang]
     target_code = LANGUAGE_MAPPING[target_lang]

+    # Update the STT model based on the source language
     try:
+        logger.info(f"Updating STT model for source language {source_code}...")
         from transformers import AutoProcessor, AutoModelForCTC, WhisperProcessor, WhisperForConditionalGeneration
         device = "cuda" if torch.cuda.is_available() else "cpu"

+        # Use Whisper Small for English or Tagalog, MMS for others
+        if source_code in ["eng", "tgl"]:
             try:
+                logger.info(f"Loading Whisper Small STT model for {source_code}...")
+                if model_status["stt_whisper"] != "loaded":
+                    stt_whisper_processor = WhisperProcessor.from_pretrained("openai/whisper-small")
+                    stt_whisper_model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-small")
+                    stt_whisper_model.to(device)
+                    model_status["stt_whisper"] = "loaded"
+                current_stt_model = "whisper"
+                logger.info("Whisper Small STT model selected")
             except Exception as whisper_error:
+                logger.error(f"Failed to load Whisper Small STT model: {str(whisper_error)}")
+                try:
+                    logger.info(f"Falling back to MMS STT model for {source_code}...")
+                    if model_status["stt_mms"] != "loaded":
+                        stt_mms_processor = AutoProcessor.from_pretrained("facebook/mms-1b-all")
+                        stt_mms_model = AutoModelForCTC.from_pretrained("facebook/mms-1b-all")
+                        stt_mms_model.to(device)
+                        model_status["stt_mms"] = "loaded"
+                    if source_code in stt_mms_processor.tokenizer.vocab.keys():
+                        stt_mms_processor.tokenizer.set_target_lang(source_code)
+                        stt_mms_model.load_adapter(source_code)
+                    current_stt_model = "mms"
+                    logger.info("MMS STT model selected as fallback")
+                except Exception as mms_error:
+                    logger.error(f"Failed to load MMS STT model: {str(mms_error)}")
+                    model_status["stt_mms"] = "failed"
+                    model_status["stt_whisper"] = "failed"
+                    error_message = f"STT model update failed: Whisper error: {str(whisper_error)}, MMS error: {str(mms_error)}"
+                    return {"status": "failed", "error": error_message}
+        else:
+            try:
+                logger.info(f"Loading MMS STT model for {source_code}...")
+                if model_status["stt_mms"] != "loaded":
+                    stt_mms_processor = AutoProcessor.from_pretrained("facebook/mms-1b-all")
+                    stt_mms_model = AutoModelForCTC.from_pretrained("facebook/mms-1b-all")
+                    stt_mms_model.to(device)
+                    model_status["stt_mms"] = "loaded"
+                if source_code in stt_mms_processor.tokenizer.vocab.keys():
+                    stt_mms_processor.tokenizer.set_target_lang(source_code)
+                    stt_mms_model.load_adapter(source_code)
+                current_stt_model = "mms"
+                logger.info(f"MMS STT model selected for {source_code}")
+            except Exception as mms_error:
+                logger.error(f"Failed to load MMS STT model: {str(mms_error)}")
+                try:
+                    logger.info(f"Falling back to Whisper Small STT model for {source_code}...")
+                    if model_status["stt_whisper"] != "loaded":
+                        stt_whisper_processor = WhisperProcessor.from_pretrained("openai/whisper-small")
+                        stt_whisper_model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-small")
+                        stt_whisper_model.to(device)
+                        model_status["stt_whisper"] = "loaded"
+                    current_stt_model = "whisper"
+                    logger.info("Whisper Small STT model selected as fallback")
+                except Exception as whisper_error:
+                    logger.error(f"Failed to load Whisper Small STT model: {str(whisper_error)}")
+                    model_status["stt_mms"] = "failed"
+                    model_status["stt_whisper"] = "failed"
+                    error_message = f"STT model update failed: MMS error: {str(mms_error)}, Whisper error: {str(whisper_error)}"
+                    return {"status": "failed", "error": error_message}
+
     except Exception as e:
         logger.error(f"Error updating STT model: {str(e)}")
+        model_status["stt_mms"] = "failed"
+        model_status["stt_whisper"] = "failed"
         error_message = f"STT model update failed: {str(e)}"
         return {"status": "failed", "error": error_message}
 
@@ +515,8 @@
 @app.post("/translate-audio")
 async def translate_audio(audio: UploadFile = File(...), source_lang: str = Form(...), target_lang: str = Form(...)):
     """Endpoint to transcribe, translate, and convert audio to speech"""
+    global stt_mms_processor, stt_mms_model, stt_whisper_processor, stt_whisper_model
+    global mt_model, mt_tokenizer, tts_model, tts_tokenizer, current_tts_language, current_stt_model

     if not audio:
         raise HTTPException(status_code=400, detail="No audio file provided")
 
@@ +527,37 @@
     request_id = str(uuid.uuid4())

     # Check if STT model is loaded
+    source_code = LANGUAGE_MAPPING[source_lang]
+    use_whisper = source_code in ["eng", "tgl"]
+
+    if use_whisper and (model_status["stt_whisper"] != "loaded" or stt_whisper_processor is None or stt_whisper_model is None):
+        logger.warning("Whisper Small STT model not loaded, falling back to MMS")
+        if model_status["stt_mms"] != "loaded" or stt_mms_processor is None or stt_mms_model is None:
+            logger.warning("MMS STT model not loaded either, returning placeholder response")
+            return {
+                "request_id": request_id,
+                "status": "processing",
+                "message": "STT models not loaded yet. Please try again later.",
+                "source_text": "Transcription not available",
+                "translated_text": "Translation not available",
+                "is_inappropriate": False,
+                "output_audio": None
+            }
+        use_whisper = False
+    elif not use_whisper and (model_status["stt_mms"] != "loaded" or stt_mms_processor is None or stt_mms_model is None):
+        logger.warning("MMS STT model not loaded, falling back to Whisper Small")
+        if model_status["stt_whisper"] != "loaded" or stt_whisper_processor is None or stt_whisper_model is None:
+            logger.warning("Whisper Small STT model not loaded either, returning placeholder response")
+            return {
+                "request_id": request_id,
+                "status": "processing",
+                "message": "STT models not loaded yet. Please try again later.",
+                "source_text": "Transcription not available",
+                "translated_text": "Translation not available",
+                "is_inappropriate": False,
+                "output_audio": None
+            }
+        use_whisper = True

     # Save the uploaded audio to a temporary file
     with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as temp_file:
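The readiness checks above boil down to a small routing decision. A compact restatement, not part of the commit and ignoring the extra `is None` checks on the processor/model globals, that could be unit-tested in isolation:

from typing import Optional

def choose_stt(source_code: str, status: dict) -> Optional[str]:
    # Prefer Whisper Small for English/Tagalog, MMS otherwise; fall back if unloaded.
    prefer_whisper = source_code in ["eng", "tgl"]
    whisper_ok = status.get("stt_whisper") == "loaded"
    mms_ok = status.get("stt_mms") == "loaded"
    if prefer_whisper:
        return "whisper" if whisper_ok else ("mms" if mms_ok else None)
    return "mms" if mms_ok else ("whisper" if whisper_ok else None)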
 
@@ +596,30 @@

         # Step 3: Transcribe the audio (STT)
         device = "cuda" if torch.cuda.is_available() else "cpu"
+        logger.info(f"Using device: {device} with {'Whisper Small' if use_whisper else 'MMS'} model")

+        if use_whisper:
+            processor = stt_whisper_processor
+            model = stt_whisper_model
+            inputs = processor(waveform.numpy()[0], sampling_rate=16000, return_tensors="pt").to(device)
+            with torch.no_grad():
+                language = "en" if source_code == "eng" else "tl" if source_code == "tgl" else None
+                generated_ids = model.generate(**inputs, language=language, task="transcribe")
+                transcription = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
+        else:
+            processor = stt_mms_processor
+            model = stt_mms_model
+            if source_code in processor.tokenizer.vocab.keys():
+                processor.tokenizer.set_target_lang(source_code)
+                model.load_adapter(source_code)
+            inputs = processor(waveform.numpy(), sampling_rate=16000, return_tensors="pt").to(device)
+            with torch.no_grad():
+                logits = model(**inputs).logits
                 predicted_ids = torch.argmax(logits, dim=-1)
+                transcription = processor.batch_decode(predicted_ids)[0]
         logger.info(f"Transcription completed: {transcription}")

         # Step 4: Translate the transcribed text (MT)
         target_code = LANGUAGE_MAPPING[target_lang]

         if model_status["mt"] == "loaded" and mt_model is not None and mt_tokenizer is not None:
 
@@ +694,8 @@
         except Exception as e:
             logger.error(f"Error during TTS conversion: {str(e)}")
             output_audio_url = None
+
+        return {
             "request_id": request_id,
             "status": "completed",
             "message": "Transcription, translation, and TTS completed (or partially completed).",