Jerich committed
Commit 0ec414d · verified · 1 Parent(s): dec41a9

Update app.py

Files changed (1)
  1. app.py +160 -66
app.py CHANGED
@@ -17,7 +17,7 @@ from fastapi import FastAPI, HTTPException, UploadFile, File, Form, BackgroundTa
  from fastapi.responses import JSONResponse
  from fastapi.staticfiles import StaticFiles
  from typing import Dict, Any, Optional, Tuple, List
- from datetime import (datetime, timedelta)
+ from datetime import datetime, timedelta
 
  # Configure logging
  logging.basicConfig(level=logging.INFO)
@@ -35,7 +35,8 @@ models_loaded = False
  loading_in_progress = False
  loading_thread = None
  model_status = {
-     "stt": "not_loaded",
+     "stt_whisper": "not_loaded",
+     "stt_mms": "not_loaded",
      "mt": "not_loaded",
      "tts": "not_loaded"
  }
@@ -43,10 +44,10 @@ error_message = None
  current_tts_language = "tgl"  # Track the current TTS language
 
  # Model instances
- stt_processor_whisper = None
- stt_model_whisper = None
- stt_processor_mms = None
- stt_model_mms = None
+ whisper_processor = None
+ whisper_model = None
+ mms_processor = None
+ mms_model = None
  mt_model = None
  mt_tokenizer = None
  tts_model = None
@@ -62,11 +63,9 @@ LANGUAGE_MAPPING = {
      "Pangasinan": "pag"
  }
 
- # Mapping for Whisper language names
- WHISPER_LANGUAGE_MAPPING = {
-     "eng": "english",
-     "tgl": "tagalog"
- }
+ # Define which languages use Whisper vs MMS for STT
+ WHISPER_LANGUAGES = {"eng", "tgl"}  # English and Tagalog use Whisper
+ MMS_LANGUAGES = {"ceb", "ilo", "war", "pag"}  # Other Philippine languages use MMS
 
  NLLB_LANGUAGE_CODES = {
      "eng": "eng_Latn",
@@ -93,39 +92,60 @@ def check_inappropriate_content(text: str) -> bool:
      Check if the text contains inappropriate content.
      Returns True if inappropriate content is detected, False otherwise.
      """
+     # Convert to lowercase for case-insensitive matching
      text_lower = text.lower()
+ 
+     # Check for inappropriate words
      for word in INAPPROPRIATE_WORDS:
+         # Use word boundary matching to avoid false positives
          pattern = r'\b' + re.escape(word) + r'\b'
          if re.search(pattern, text_lower):
              logger.warning(f"Inappropriate content detected: {word}")
              return True
+ 
      return False
 
  # Function to save PCM data as a WAV file
  def save_pcm_to_wav(pcm_data: list, sample_rate: int, output_path: str):
+     # Convert pcm_data to a NumPy array of 16-bit integers
      pcm_array = np.array(pcm_data, dtype=np.int16)
+ 
      with wave.open(output_path, 'wb') as wav_file:
+         # Set WAV parameters: 1 channel (mono), 2 bytes per sample (16-bit), sample rate
          wav_file.setnchannels(1)
-         wav_file.setsampwidth(2)
+         wav_file.setsampwidth(2)  # 16-bit audio
          wav_file.setframerate(sample_rate)
+         # Write the 16-bit PCM data as bytes (little-endian)
          wav_file.writeframes(pcm_array.tobytes())
 
  # Function to detect speech using an energy-based approach
  def detect_speech(waveform: torch.Tensor, sample_rate: int, threshold: float = 0.01, min_speech_duration: float = 0.5) -> bool:
+     """
+     Detects if the audio contains speech using an energy-based approach.
+     Returns True if speech is detected, False otherwise.
+     """
+     # Convert waveform to numpy array
      waveform_np = waveform.numpy()
      if waveform_np.ndim > 1:
-         waveform_np = waveform_np.mean(axis=0)
+         waveform_np = waveform_np.mean(axis=0)  # Convert stereo to mono
+ 
+     # Compute RMS energy
      rms = np.sqrt(np.mean(waveform_np**2))
      logger.info(f"RMS energy: {rms}")
+ 
+     # Check if RMS energy exceeds the threshold
      if rms < threshold:
          logger.info("No speech detected: RMS energy below threshold")
          return False
+ 
+     # Optionally, check for minimum speech duration (requires more sophisticated VAD)
+     # For now, we assume if RMS is above threshold, there is speech
      return True
 
  # Function to clean up old audio files
  def cleanup_old_audio_files():
      logger.info("Starting cleanup of old audio files...")
-     expiration_time = datetime.now() - timedelta(minutes=10)
+     expiration_time = datetime.now() - timedelta(minutes=10)  # Files older than 10 minutes
      for filename in os.listdir(AUDIO_DIR):
          file_path = os.path.join(AUDIO_DIR, filename)
          if os.path.isfile(file_path):
@@ -141,49 +161,53 @@ def cleanup_old_audio_files():
  def schedule_cleanup():
      while True:
          cleanup_old_audio_files()
-         time.sleep(300)
+         time.sleep(300)  # Run every 5 minutes (300 seconds)
 
  # Function to load models in background
  def load_models_task():
      global models_loaded, loading_in_progress, model_status, error_message
-     global stt_processor_whisper, stt_model_whisper, stt_processor_mms, stt_model_mms
+     global whisper_processor, whisper_model, mms_processor, mms_model
      global mt_model, mt_tokenizer, tts_model, tts_tokenizer
 
      try:
          loading_in_progress = True
+         device = "cuda" if torch.cuda.is_available() else "cpu"
 
-         # Load STT models
-         logger.info("Starting to load STT models...")
-         from transformers import AutoProcessor, AutoModelForCTC, WhisperProcessor, WhisperForConditionalGeneration
+         # Load Whisper STT model for English and Tagalog
+         logger.info("Starting to load Whisper STT model...")
+         from transformers import WhisperProcessor, WhisperForConditionalGeneration
 
          try:
              logger.info("Loading Whisper STT model...")
-             model_status["stt"] = "loading"
-             stt_processor_whisper = WhisperProcessor.from_pretrained("openai/whisper-tiny")
-             stt_model_whisper = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")
-             device = "cuda" if torch.cuda.is_available() else "cpu"
-             stt_model_whisper.to(device)
+             model_status["stt_whisper"] = "loading"
+             whisper_processor = WhisperProcessor.from_pretrained("openai/whisper-small")
+             whisper_model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-small")
+             whisper_model.to(device)
              logger.info("Whisper STT model loaded successfully")
-             model_status["stt"] = "loaded_whisper"
-         except Exception as e:
-             logger.error(f"Failed to load Whisper STT model: {str(e)}")
-             model_status["stt"] = "failed"
-             error_message = f"Whisper STT model loading failed: {str(e)}"
+             model_status["stt_whisper"] = "loaded"
+         except Exception as whisper_error:
+             logger.error(f"Failed to load Whisper STT model: {str(whisper_error)}")
+             model_status["stt_whisper"] = "failed"
+             error_message = f"Whisper STT model loading failed: {str(whisper_error)}"
              return
- 
+ 
+         # Load MMS STT model for other Philippine languages
+         logger.info("Starting to load MMS STT model...")
+         from transformers import AutoProcessor, AutoModelForCTC
+ 
          try:
              logger.info("Loading MMS STT model...")
-             stt_processor_mms = AutoProcessor.from_pretrained("facebook/mms-1b-all")
-             stt_model_mms = AutoModelForCTC.from_pretrained("facebook/mms-1b-all")
-             stt_model_mms.to(device)
+             model_status["stt_mms"] = "loading"
+             mms_processor = AutoProcessor.from_pretrained("facebook/mms-1b-all")
+             mms_model = AutoModelForCTC.from_pretrained("facebook/mms-1b-all")
+             mms_model.to(device)
              logger.info("MMS STT model loaded successfully")
-             model_status["stt"] = "loaded_both" if model_status["stt"] == "loaded_whisper" else "loaded_mms"
-         except Exception as e:
-             logger.error(f"Failed to load MMS STT model: {str(e)}")
-             if model_status["stt"] != "loaded_whisper":
-                 model_status["stt"] = "failed"
-             error_message = f"MMS STT model loading failed: {str(e)}"
-             return
+             model_status["stt_mms"] = "loaded"
+         except Exception as mms_error:
+             logger.error(f"Failed to load MMS STT model: {str(mms_error)}")
+             model_status["stt_mms"] = "failed"
+             error_message = f"MMS STT model loading failed: {str(mms_error)}"
+             return
 
          # Load MT model
          logger.info("Starting to load MT model...")
@@ -203,7 +227,7 @@ def load_models_task():
              error_message = f"MT model loading failed: {str(e)}"
              return
 
-         # Load TTS model (default to Tagalog)
+         # Load TTS model (default to Tagalog, will be updated dynamically)
          logger.info("Starting to load TTS model...")
          from transformers import VitsModel, AutoTokenizer
 
@@ -217,6 +241,7 @@ def load_models_task():
              model_status["tts"] = "loaded"
          except Exception as e:
              logger.error(f"Failed to load TTS model for Tagalog: {str(e)}")
+             # Fallback to English TTS if the target language fails
              try:
                  logger.info("Falling back to MMS-TTS English model...")
                  tts_model = VitsModel.from_pretrained("facebook/mms-tts-eng")
@@ -257,13 +282,21 @@ def start_cleanup_task():
 
  # Function to load or update TTS model for a specific language
  def load_tts_model_for_language(target_code: str) -> bool:
+     """
+     Load or update the TTS model for the specified language.
+     Returns True if successful, False otherwise.
+     """
      global tts_model, tts_tokenizer, current_tts_language, model_status
+ 
      if target_code not in LANGUAGE_MAPPING.values():
          logger.error(f"Invalid language code: {target_code}")
          return False
+ 
+     # Skip if the model is already loaded for the target language
      if current_tts_language == target_code and model_status["tts"].startswith("loaded"):
          logger.info(f"TTS model for {target_code} is already loaded.")
          return True
+ 
      device = "cuda" if torch.cuda.is_available() else "cpu"
      try:
          logger.info(f"Loading MMS-TTS model for {target_code}...")
@@ -293,21 +326,32 @@ def load_tts_model_for_language(target_code: str) -> bool:
 
  # Function to synthesize speech from text
  def synthesize_speech(text: str, target_code: str) -> Tuple[Optional[str], Optional[str]]:
+     """
+     Convert text to speech for the specified language.
+     Returns a tuple of (output_path, error_message).
+     """
      global tts_model, tts_tokenizer
+ 
      request_id = str(uuid.uuid4())
      output_path = os.path.join(AUDIO_DIR, f"{request_id}.wav")
+ 
+     # Make sure the TTS model is loaded for the target language
      if not load_tts_model_for_language(target_code):
          return None, "Failed to load TTS model for the target language"
+ 
      device = "cuda" if torch.cuda.is_available() else "cpu"
      try:
-         inputs = tts_tokenizer(text, return_tensors="pt").to(device)
+         inputs = tts_tokenizer(text, return_tensors="pt").to(device)
          with torch.no_grad():
              output = tts_model(**inputs)
          speech = output.waveform.cpu().numpy().squeeze()
          speech = (speech * 32767).astype(np.int16)
          sample_rate = tts_model.config.sampling_rate
+ 
+         # Save the audio as a WAV file
          save_pcm_to_wav(speech.tolist(), sample_rate, output_path)
          logger.info(f"Saved synthesized audio to {output_path}")
+ 
         return output_path, None
      except Exception as e:
          error_msg = f"Error during TTS conversion: {str(e)}"
@@ -323,11 +367,14 @@ async def startup_event():
 
  @app.get("/")
  async def root():
+     """Root endpoint for default health check"""
      logger.info("Root endpoint requested")
      return {"status": "healthy"}
 
  @app.get("/health")
  async def health_check():
+     """Health check endpoint that always returns successfully"""
+     global models_loaded, loading_in_progress, model_status, error_message
      logger.info("Health check requested")
      return {
          "status": "healthy",
@@ -339,16 +386,22 @@ async def health_check():
 
  @app.post("/translate-text")
  async def translate_text(text: str = Form(...), source_lang: str = Form(...), target_lang: str = Form(...)):
+     """Endpoint to translate text and convert to speech"""
      global mt_model, mt_tokenizer, tts_model, tts_tokenizer, current_tts_language
+ 
      if not text:
          raise HTTPException(status_code=400, detail="No text provided")
      if source_lang not in LANGUAGE_MAPPING or target_lang not in LANGUAGE_MAPPING:
          raise HTTPException(status_code=400, detail="Invalid language selected")
+ 
      logger.info(f"Translate-text requested: {text} from {source_lang} to {target_lang}")
      request_id = str(uuid.uuid4())
+ 
+     # Translate the text
      source_code = LANGUAGE_MAPPING[source_lang]
      target_code = LANGUAGE_MAPPING[target_lang]
      translated_text = "Translation not available"
+ 
      if model_status["mt"] == "loaded" and mt_model is not None and mt_tokenizer is not None:
          try:
              source_nllb_code = NLLB_LANGUAGE_CODES[source_code]
@@ -369,20 +422,26 @@ async def translate_text(text: str = Form(...), source_lang: str = Form(...), ta
              translated_text = f"Translation failed: {str(e)}"
      else:
          logger.warning("MT model not loaded, skipping translation")
+ 
+     # Check for inappropriate content in the source text and translated text
      is_inappropriate = check_inappropriate_content(text) or check_inappropriate_content(translated_text)
      if is_inappropriate:
          logger.warning("Inappropriate content detected in translation request")
+ 
+     # Convert translated text to speech
      output_audio_url = None
      if model_status["tts"].startswith("loaded"):
+         # Load or update TTS model for the target language
          if load_tts_model_for_language(target_code):
              try:
                  output_path, error = synthesize_speech(translated_text, target_code)
                  if output_path:
                      output_filename = os.path.basename(output_path)
-                     output_audio_url = f"https://jerich-talklasapp2.hf.space/audio_output/{output_filename}"
+                     output_audio_url = f"https://jerich-talklasapp.hf.space/audio_output/{output_filename}"
                  logger.info("TTS conversion completed")
              except Exception as e:
                  logger.error(f"Error during TTS conversion: {str(e)}")
+ 
      return {
          "request_id": request_id,
          "status": "completed",
@@ -395,7 +454,8 @@ async def translate_text(text: str = Form(...), source_lang: str = Form(...), ta
 
  @app.post("/translate-audio")
  async def translate_audio(audio: UploadFile = File(...), source_lang: str = Form(...), target_lang: str = Form(...)):
-     global stt_processor_whisper, stt_model_whisper, stt_processor_mms, stt_model_mms
+     """Endpoint to transcribe, translate, and convert audio to speech"""
+     global whisper_processor, whisper_model, mms_processor, mms_model
      global mt_model, mt_tokenizer, tts_model, tts_tokenizer, current_tts_language
 
      if not audio:
@@ -403,15 +463,19 @@ async def translate_audio(audio: UploadFile = File(...), source_lang: str = Form
      if source_lang not in LANGUAGE_MAPPING or target_lang not in LANGUAGE_MAPPING:
          raise HTTPException(status_code=400, detail="Invalid language selected")
 
-     logger.info(f"Translate-audio requested: {audio.filename} from {source_lang} to {target_lang}")
+     source_code = LANGUAGE_MAPPING[source_lang]
+     target_code = LANGUAGE_MAPPING[target_lang]
+ 
+     logger.info(f"Translate-audio requested: {audio.filename} from {source_lang} ({source_code}) to {target_lang} ({target_code})")
      request_id = str(uuid.uuid4())
 
-     source_code = LANGUAGE_MAPPING[source_lang]
-     use_whisper = source_code in ["eng", "tgl"]
- 
-     # Check if appropriate STT model is loaded
-     if use_whisper and (stt_processor_whisper is None or stt_model_whisper is None):
-         logger.warning("Whisper STT model not loaded, returning placeholder response")
+     # Determine which STT model to use based on source language
+     use_whisper = source_code in WHISPER_LANGUAGES
+     use_mms = source_code in MMS_LANGUAGES
+ 
+     # Check if the appropriate STT model is loaded
+     if use_whisper and (model_status["stt_whisper"] != "loaded" or whisper_processor is None or whisper_model is None):
+         logger.warning("Whisper STT model not loaded for English/Tagalog, returning placeholder response")
          return {
              "request_id": request_id,
              "status": "processing",
@@ -421,8 +485,9 @@ async def translate_audio(audio: UploadFile = File(...), source_lang: str = Form
              "output_audio": None,
              "is_inappropriate": False
          }
-     elif not use_whisper and (stt_processor_mms is None or stt_model_mms is None):
-         logger.warning("MMS STT model not loaded, returning placeholder response")
+ 
+     if use_mms and (model_status["stt_mms"] != "loaded" or mms_processor is None or mms_model is None):
+         logger.warning("MMS STT model not loaded for Philippine languages, returning placeholder response")
          return {
              "request_id": request_id,
              "status": "processing",
@@ -433,6 +498,7 @@ async def translate_audio(audio: UploadFile = File(...), source_lang: str = Form
              "is_inappropriate": False
          }
 
+     # Save the uploaded audio to a temporary file
      with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as temp_file:
          temp_file.write(await audio.read())
          temp_path = temp_file.name
@@ -443,16 +509,19 @@ async def translate_audio(audio: UploadFile = File(...), source_lang: str = Form
      is_inappropriate = False
 
      try:
+         # Step 1: Load and resample the audio using torchaudio
          logger.info(f"Reading audio file: {temp_path}")
          waveform, sample_rate = torchaudio.load(temp_path)
          logger.info(f"Audio loaded: sample_rate={sample_rate}, waveform_shape={waveform.shape}")
 
+         # Resample to 16 kHz if needed (required by Whisper and MMS models)
          if sample_rate != 16000:
              logger.info(f"Resampling audio from {sample_rate} Hz to 16000 Hz")
              resampler = torchaudio.transforms.Resample(sample_rate, 16000)
              waveform = resampler(waveform)
              sample_rate = 16000
 
+         # Step 2: Detect speech
          if not detect_speech(waveform, sample_rate):
              return {
                  "request_id": request_id,
@@ -464,26 +533,45 @@ async def translate_audio(audio: UploadFile = File(...), source_lang: str = Form
                  "is_inappropriate": False
              }
 
+         # Step 3: Transcribe the audio (STT)
          device = "cuda" if torch.cuda.is_available() else "cpu"
-         logger.info(f"Using device: {device}")
+         logger.info(f"Using device: {device} for STT")
 
          if use_whisper:
-             logger.info("Using Whisper model for transcription")
-             whisper_lang = WHISPER_LANGUAGE_MAPPING.get(source_code, "english")  # Default to English if not mapped
-             inputs = stt_processor_whisper(waveform.numpy(), sampling_rate=16000, return_tensors="pt").to(device)
+             # Use Whisper model for English and Tagalog
+             logger.info(f"Using Whisper model for language: {source_code}")
+ 
+             # Prepare audio for Whisper
+             inputs = whisper_processor(waveform.numpy(), sampling_rate=16000, return_tensors="pt").to(device)
+             logger.info("Audio processed for Whisper, generating transcription...")
+ 
              with torch.no_grad():
-                 generated_ids = stt_model_whisper.generate(**inputs, language=whisper_lang)
-             transcription = stt_processor_whisper.batch_decode(generated_ids, skip_special_tokens=True)[0]
+                 # For English, we can specify the language; for Tagalog we use 'tl'
+                 forced_language = "en" if source_code == "eng" else "tl"
+                 generated_ids = whisper_model.generate(
+                     **inputs,
+                     language=forced_language,
+                     task="transcribe"
+                 )
+             transcription = whisper_processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
+ 
          else:
-             logger.info("Using MMS model for transcription")
-             inputs = stt_processor_mms(waveform.numpy(), sampling_rate=16000, return_tensors="pt").to(device)
+             # Use MMS model for other Philippine languages
+             logger.info(f"Using MMS model for language: {source_code}")
+ 
+             # Prepare audio for MMS
+             inputs = mms_processor(waveform.numpy(), sampling_rate=16000, return_tensors="pt").to(device)
+             logger.info("Audio processed for MMS, generating transcription...")
+ 
             with torch.no_grad():
-                 logits = stt_model_mms(**inputs).logits
+                 # Process with MMS
+                 logits = mms_model(**inputs).logits
              predicted_ids = torch.argmax(logits, dim=-1)
-             transcription = stt_processor_mms.batch_decode(predicted_ids)[0]
+             transcription = mms_processor.batch_decode(predicted_ids)[0]
+ 
          logger.info(f"Transcription completed: {transcription}")
 
-         target_code = LANGUAGE_MAPPING[target_lang]
+         # Step 4: Translate the transcribed text (MT)
          if model_status["mt"] == "loaded" and mt_model is not None and mt_tokenizer is not None:
              try:
                  source_nllb_code = NLLB_LANGUAGE_CODES[source_code]
@@ -504,16 +592,18 @@ async def translate_audio(audio: UploadFile = File(...), source_lang: str = Form
          else:
              logger.warning("MT model not loaded, skipping translation")
 
+         # Step 5: Check for inappropriate content
          is_inappropriate = check_inappropriate_content(transcription) or check_inappropriate_content(translated_text)
          if is_inappropriate:
              logger.warning("Inappropriate content detected in audio transcription or translation")
 
+         # Step 6: Convert translated text to speech (TTS)
          if load_tts_model_for_language(target_code):
              try:
                  output_path, error = synthesize_speech(translated_text, target_code)
                  if output_path:
                      output_filename = os.path.basename(output_path)
-                     output_audio_url = f"https://jerich-talklasapp2.hf.space/audio_output/{output_filename}"
+                     output_audio_url = f"https://jerich-talklasapp.hf.space/audio_output/{output_filename}"
                  logger.info("TTS conversion completed")
              except Exception as e:
                  logger.error(f"Error during TTS conversion: {str(e)}")
@@ -544,6 +634,7 @@ async def translate_audio(audio: UploadFile = File(...), source_lang: str = Form
 
  @app.post("/text-to-speech")
  async def text_to_speech(text: str = Form(...), target_lang: str = Form(...)):
+     """Endpoint to convert text to speech in the specified language"""
      if not text:
          raise HTTPException(status_code=400, detail="No text provided")
      if target_lang not in LANGUAGE_MAPPING:
@@ -553,17 +644,20 @@ async def text_to_speech(text: str = Form(...), target_lang: str = Form(...)):
      request_id = str(uuid.uuid4())
 
      target_code = LANGUAGE_MAPPING[target_lang]
+ 
+     # Check for inappropriate content
      is_inappropriate = check_inappropriate_content(text)
      if is_inappropriate:
          logger.warning("Inappropriate content detected in text-to-speech request")
 
+     # Synthesize speech
      output_audio_url = None
      if model_status["tts"].startswith("loaded") or load_tts_model_for_language(target_code):
          try:
              output_path, error = synthesize_speech(text, target_code)
              if output_path:
                  output_filename = os.path.basename(output_path)
-                 output_audio_url = f"https://jerich-talklasapp2.hf.space/audio_output/{output_filename}"
+                 output_audio_url = f"https://jerich-talklasapp.hf.space/audio_output/{output_filename}"
                  logger.info("TTS conversion completed")
              else:
                  logger.error(f"TTS conversion failed: {error}")
 