Jerich committed on
Commit
28e1a88
·
verified ·
1 Parent(s): 0ef092a

Updated the code to use the Whisper model if the source language is English or Tagalog; otherwise, it will use MMS. Additionally, the link to the synthesized speech has been updated to match the current space.

Browse files
Files changed (1) hide show
  1. app.py +31 -10
app.py CHANGED
@@ -420,7 +420,7 @@ async def translate_text(text: str = Form(...), source_lang: str = Form(...), ta
420
  output_path, error = synthesize_speech(translated_text, target_code)
421
  if output_path:
422
  output_filename = os.path.basename(output_path)
423
- output_audio_url = f"https://jerich-talklasapp.hf.space/audio_output/{output_filename}"
424
  logger.info("TTS conversion completed")
425
  except Exception as e:
426
  logger.error(f"Error during TTS conversion: {str(e)}")
@@ -448,7 +448,7 @@ async def translate_audio(audio: UploadFile = File(...), source_lang: str = Form
448
  logger.info(f"Translate-audio requested: {audio.filename} from {source_lang} to {target_lang}")
449
  request_id = str(uuid.uuid4())
450
 
451
- # Check if STT model is loaded
452
  if model_status["stt"] not in ["loaded_mms", "loaded_mms_default", "loaded_whisper"] or stt_processor is None or stt_model is None:
453
  logger.warning("STT model not loaded, returning placeholder response")
454
  return {
@@ -499,23 +499,44 @@ async def translate_audio(audio: UploadFile = File(...), source_lang: str = Form
499
  # Step 3: Transcribe the audio (STT)
500
  device = "cuda" if torch.cuda.is_available() else "cpu"
501
  logger.info(f"Using device: {device}")
 
 
 
 
 
 
 
 
 
502
  inputs = stt_processor(waveform.numpy(), sampling_rate=16000, return_tensors="pt").to(device)
503
  logger.info("Audio processed, generating transcription...")
504
 
505
  with torch.no_grad():
506
- if model_status["stt"] == "loaded_whisper":
507
- # Whisper model
508
- generated_ids = stt_model.generate(**inputs, language="en")
 
509
  transcription = stt_processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
510
- else:
511
- # MMS model
 
512
  logits = stt_model(**inputs).logits
513
  predicted_ids = torch.argmax(logits, dim=-1)
514
  transcription = stt_processor.batch_decode(predicted_ids)[0]
 
 
 
 
 
 
 
 
 
 
 
515
  logger.info(f"Transcription completed: {transcription}")
516
 
517
  # Step 4: Translate the transcribed text (MT)
518
- source_code = LANGUAGE_MAPPING[source_lang]
519
  target_code = LANGUAGE_MAPPING[target_lang]
520
 
521
  if model_status["mt"] == "loaded" and mt_model is not None and mt_tokenizer is not None:
@@ -549,7 +570,7 @@ async def translate_audio(audio: UploadFile = File(...), source_lang: str = Form
549
  output_path, error = synthesize_speech(translated_text, target_code)
550
  if output_path:
551
  output_filename = os.path.basename(output_path)
552
- output_audio_url = f"https://jerich-talklasapp.hf.space/audio_output/{output_filename}"
553
  logger.info("TTS conversion completed")
554
  except Exception as e:
555
  logger.error(f"Error during TTS conversion: {str(e)}")
@@ -603,7 +624,7 @@ async def text_to_speech(text: str = Form(...), target_lang: str = Form(...)):
603
  output_path, error = synthesize_speech(text, target_code)
604
  if output_path:
605
  output_filename = os.path.basename(output_path)
606
- output_audio_url = f"https://jerich-talklasapp.hf.space/audio_output/{output_filename}"
607
  logger.info("TTS conversion completed")
608
  else:
609
  logger.error(f"TTS conversion failed: {error}")
 
420
  output_path, error = synthesize_speech(translated_text, target_code)
421
  if output_path:
422
  output_filename = os.path.basename(output_path)
423
+ output_audio_url = f"https://jerich-talklasapp2.hf.space/audio_output/{output_filename}"
424
  logger.info("TTS conversion completed")
425
  except Exception as e:
426
  logger.error(f"Error during TTS conversion: {str(e)}")
 
448
  logger.info(f"Translate-audio requested: {audio.filename} from {source_lang} to {target_lang}")
449
  request_id = str(uuid.uuid4())
450
 
451
+ # Check if STT models are loaded
452
  if model_status["stt"] not in ["loaded_mms", "loaded_mms_default", "loaded_whisper"] or stt_processor is None or stt_model is None:
453
  logger.warning("STT model not loaded, returning placeholder response")
454
  return {
 
499
  # Step 3: Transcribe the audio (STT)
500
  device = "cuda" if torch.cuda.is_available() else "cpu"
501
  logger.info(f"Using device: {device}")
502
+
503
+ # Determine which model to use based on source language
504
+ source_code = LANGUAGE_MAPPING[source_lang]
505
+ use_whisper = source_code in ["eng", "tgl"] # Use Whisper for English and Tagalog
506
+ use_mms = not use_whisper # Use MMS for other Philippine languages
507
+
508
+ logger.info(f"Source language: {source_lang} ({source_code}), Using Whisper: {use_whisper}, Using MMS: {use_mms}")
509
+
510
+ # Process with appropriate model
511
  inputs = stt_processor(waveform.numpy(), sampling_rate=16000, return_tensors="pt").to(device)
512
  logger.info("Audio processed, generating transcription...")
513
 
514
  with torch.no_grad():
515
+ if use_whisper and model_status["stt"] == "loaded_whisper":
516
+ # Whisper model for English and Tagalog
517
+ logger.info(f"Using Whisper model for {source_lang}")
518
+ generated_ids = stt_model.generate(**inputs, language="en" if source_code == "eng" else "tl")
519
  transcription = stt_processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
520
+ elif model_status["stt"] in ["loaded_mms", "loaded_mms_default"]:
521
+ # MMS model for other Philippine languages
522
+ logger.info(f"Using MMS model for {source_lang}")
523
  logits = stt_model(**inputs).logits
524
  predicted_ids = torch.argmax(logits, dim=-1)
525
  transcription = stt_processor.batch_decode(predicted_ids)[0]
526
+ else:
527
+ # Fallback to any available model
528
+ logger.info(f"Preferred model not available, using fallback model")
529
+ if model_status["stt"] == "loaded_whisper":
530
+ generated_ids = stt_model.generate(**inputs, language="en")
531
+ transcription = stt_processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
532
+ else:
533
+ logits = stt_model(**inputs).logits
534
+ predicted_ids = torch.argmax(logits, dim=-1)
535
+ transcription = stt_processor.batch_decode(predicted_ids)[0]
536
+
537
  logger.info(f"Transcription completed: {transcription}")
538
 
539
  # Step 4: Translate the transcribed text (MT)
 
540
  target_code = LANGUAGE_MAPPING[target_lang]
541
 
542
  if model_status["mt"] == "loaded" and mt_model is not None and mt_tokenizer is not None:
 
570
  output_path, error = synthesize_speech(translated_text, target_code)
571
  if output_path:
572
  output_filename = os.path.basename(output_path)
573
+ output_audio_url = f"https://jerich-talklasapp2.hf.space/audio_output/{output_filename}"
574
  logger.info("TTS conversion completed")
575
  except Exception as e:
576
  logger.error(f"Error during TTS conversion: {str(e)}")
 
624
  output_path, error = synthesize_speech(text, target_code)
625
  if output_path:
626
  output_filename = os.path.basename(output_path)
627
+ output_audio_url = f"https://jerich-talklasapp2.hf.space/audio_output/{output_filename}"
628
  logger.info("TTS conversion completed")
629
  else:
630
  logger.error(f"TTS conversion failed: {error}")