Jerich commited on
Commit
e45bb49
·
verified ·
1 Parent(s): 798e9af

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +458 -474
app.py CHANGED
@@ -31,20 +31,24 @@ os.makedirs(AUDIO_DIR, exist_ok=True)
31
  app.mount("/audio_output", StaticFiles(directory=AUDIO_DIR), name="audio_output")
32
 
33
  # Global variables to track application state
34
- model_cache = {
35
- "stt_whisper": {"model": None, "processor": None, "status": "not_loaded"},
36
- "stt_mms": {"model": None, "processor": None, "status": "not_loaded"},
37
- "mt": {"model": None, "tokenizer": None, "status": "not_loaded"},
38
- "tts": {"model": None, "tokenizer": None, "status": "not_loaded", "language": None}
 
 
39
  }
 
 
40
 
41
- # Track loading status
42
- loading_locks = {
43
- "stt_whisper": threading.Lock(),
44
- "stt_mms": threading.Lock(),
45
- "mt": threading.Lock(),
46
- "tts": threading.Lock()
47
- }
48
 
49
  # Define the valid languages and mappings
50
  LANGUAGE_MAPPING = {
@@ -65,26 +69,28 @@ NLLB_LANGUAGE_CODES = {
65
  "pag": "pag_Latn"
66
  }
67
 
68
- # Inappropriate words list - this is a basic implementation
69
- # In a production environment, you would use a more comprehensive solution
70
  INAPPROPRIATE_WORDS = [
71
- "putang", "tang ina", "gago", "puta", "bobo", "ulol", "pakyu", "tae",
72
- "obscenity", "profanity", "explicit", "nsfw", "offensive"
73
  ]
74
 
75
-
76
- # Function to detect inappropriate content
77
- def detect_inappropriate_content(text: str) -> bool:
78
  """
79
- Checks if the text contains any inappropriate words
 
80
  """
 
 
 
81
  text_lower = text.lower()
82
  for word in INAPPROPRIATE_WORDS:
83
- if word in text_lower:
 
84
  return True
85
  return False
86
 
87
-
88
  # Function to save PCM data as a WAV file
89
  def save_pcm_to_wav(pcm_data: list, sample_rate: int, output_path: str):
90
  # Convert pcm_data to a NumPy array of 16-bit integers
@@ -98,7 +104,6 @@ def save_pcm_to_wav(pcm_data: list, sample_rate: int, output_path: str):
98
  # Write the 16-bit PCM data as bytes (little-endian)
99
  wav_file.writeframes(pcm_array.tobytes())
100
 
101
-
102
  # Function to detect speech using an energy-based approach
103
  def detect_speech(waveform: torch.Tensor, sample_rate: int, threshold: float = 0.01, min_speech_duration: float = 0.5) -> bool:
104
  """
@@ -123,7 +128,6 @@ def detect_speech(waveform: torch.Tensor, sample_rate: int, threshold: float = 0
123
  # For now, we assume if RMS is above threshold, there is speech
124
  return True
125
 
126
-
127
  # Function to clean up old audio files
128
  def cleanup_old_audio_files():
129
  logger.info("Starting cleanup of old audio files...")
@@ -139,157 +143,112 @@ def cleanup_old_audio_files():
139
  except Exception as e:
140
  logger.error(f"Error deleting file {file_path}: {str(e)}")
141
 
142
-
143
  # Background task to periodically clean up audio files
144
  def schedule_cleanup():
145
  while True:
146
  cleanup_old_audio_files()
147
  time.sleep(300) # Run every 5 minutes (300 seconds)
148
 
149
-
150
- # Function to load the Whisper STT model on demand
151
- def load_whisper_model():
152
- if model_cache["stt_whisper"]["status"] == "loaded":
153
- return True
154
-
155
- # Use lock to prevent multiple threads from loading the model simultaneously
156
- if not loading_locks["stt_whisper"].acquire(blocking=False):
157
- logger.info("Whisper model loading already in progress")
158
- return False
159
-
160
- try:
161
- logger.info("Loading Whisper small model...")
162
- model_cache["stt_whisper"]["status"] = "loading"
163
-
164
- from transformers import WhisperProcessor, WhisperForConditionalGeneration
165
- device = "cuda" if torch.cuda.is_available() else "cpu"
166
-
167
- model_cache["stt_whisper"]["processor"] = WhisperProcessor.from_pretrained("openai/whisper-small")
168
- model_cache["stt_whisper"]["model"] = WhisperForConditionalGeneration.from_pretrained("openai/whisper-small")
169
- model_cache["stt_whisper"]["model"].to(device)
170
-
171
- model_cache["stt_whisper"]["status"] = "loaded"
172
- logger.info("Whisper small model loaded successfully")
173
- return True
174
- except Exception as e:
175
- model_cache["stt_whisper"]["status"] = "failed"
176
- logger.error(f"Failed to load Whisper model: {str(e)}")
177
- return False
178
- finally:
179
- loading_locks["stt_whisper"].release()
180
-
181
-
182
- # Function to load the MMS STT model on demand
183
- def load_mms_stt_model():
184
- if model_cache["stt_mms"]["status"] == "loaded":
185
- return True
186
-
187
- if not loading_locks["stt_mms"].acquire(blocking=False):
188
- logger.info("MMS STT model loading already in progress")
189
- return False
190
 
191
  try:
192
- logger.info("Loading MMS STT model...")
193
- model_cache["stt_mms"]["status"] = "loading"
194
 
195
- from transformers import AutoProcessor, AutoModelForCTC
196
- device = "cuda" if torch.cuda.is_available() else "cpu"
197
-
198
- model_cache["stt_mms"]["processor"] = AutoProcessor.from_pretrained("facebook/mms-1b-all")
199
- model_cache["stt_mms"]["model"] = AutoModelForCTC.from_pretrained("facebook/mms-1b-all")
200
- model_cache["stt_mms"]["model"].to(device)
201
-
202
- model_cache["stt_mms"]["status"] = "loaded"
203
- logger.info("MMS STT model loaded successfully")
204
- return True
205
- except Exception as e:
206
- model_cache["stt_mms"]["status"] = "failed"
207
- logger.error(f"Failed to load MMS STT model: {str(e)}")
208
- return False
209
- finally:
210
- loading_locks["stt_mms"].release()
211
-
212
-
213
- # Function to load the MT model on demand
214
- def load_mt_model():
215
- if model_cache["mt"]["status"] == "loaded":
216
- return True
217
-
218
- if not loading_locks["mt"].acquire(blocking=False):
219
- logger.info("MT model loading already in progress")
220
- return False
221
-
222
- try:
223
- logger.info("Loading NLLB-200-distilled-600M model...")
224
- model_cache["mt"]["status"] = "loading"
225
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
226
  from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
227
- device = "cuda" if torch.cuda.is_available() else "cpu"
228
 
229
- model_cache["mt"]["model"] = AutoModelForSeq2SeqLM.from_pretrained("facebook/nllb-200-distilled-600M")
230
- model_cache["mt"]["tokenizer"] = AutoTokenizer.from_pretrained("facebook/nllb-200-distilled-600M")
231
- model_cache["mt"]["model"].to(device)
232
-
233
- model_cache["mt"]["status"] = "loaded"
234
- logger.info("MT model loaded successfully")
235
- return True
236
- except Exception as e:
237
- model_cache["mt"]["status"] = "failed"
238
- logger.error(f"Failed to load MT model: {str(e)}")
239
- return False
240
- finally:
241
- loading_locks["mt"].release()
242
-
243
 
244
- # Function to load the TTS model for a specific language on demand
245
- def load_tts_model(language_code: str):
246
- # If the model is already loaded for this language, return immediately
247
- if (model_cache["tts"]["status"] == "loaded" and
248
- model_cache["tts"]["language"] == language_code):
249
- return True
250
-
251
- if not loading_locks["tts"].acquire(blocking=False):
252
- logger.info("TTS model loading already in progress")
253
- return False
254
-
255
- try:
256
- logger.info(f"Loading MMS-TTS model for {language_code}...")
257
- model_cache["tts"]["status"] = "loading"
258
-
259
  from transformers import VitsModel, AutoTokenizer
260
- device = "cuda" if torch.cuda.is_available() else "cpu"
261
 
262
  try:
263
- model_cache["tts"]["model"] = VitsModel.from_pretrained(f"facebook/mms-tts-{language_code}")
264
- model_cache["tts"]["tokenizer"] = AutoTokenizer.from_pretrained(f"facebook/mms-tts-{language_code}")
265
- model_cache["tts"]["model"].to(device)
266
- model_cache["tts"]["language"] = language_code
267
- model_cache["tts"]["status"] = "loaded"
268
- logger.info(f"TTS model for {language_code} loaded successfully")
269
- return True
270
  except Exception as e:
271
- logger.error(f"Failed to load TTS model for {language_code}: {str(e)}")
272
  # Fallback to English TTS if the target language fails
273
  try:
274
  logger.info("Falling back to MMS-TTS English model...")
275
- model_cache["tts"]["model"] = VitsModel.from_pretrained("facebook/mms-tts-eng")
276
- model_cache["tts"]["tokenizer"] = AutoTokenizer.from_pretrained("facebook/mms-tts-eng")
277
- model_cache["tts"]["model"].to(device)
278
- model_cache["tts"]["language"] = "eng"
279
- model_cache["tts"]["status"] = "loaded (fallback)"
280
  logger.info("Fallback TTS model loaded successfully")
281
- return True
 
282
  except Exception as e2:
283
- model_cache["tts"]["status"] = "failed"
284
  logger.error(f"Failed to load fallback TTS model: {str(e2)}")
285
- return False
 
 
 
 
 
 
286
  except Exception as e:
287
- model_cache["tts"]["status"] = "failed"
288
- logger.error(f"Failed to setup TTS model: {str(e)}")
289
- return False
290
  finally:
291
- loading_locks["tts"].release()
292
 
 
 
 
 
 
 
 
 
293
 
294
  # Start the background cleanup task
295
  def start_cleanup_task():
@@ -297,130 +256,116 @@ def start_cleanup_task():
297
  cleanup_thread.daemon = True
298
  cleanup_thread.start()
299
 
300
-
301
  # Start the background processes when the app starts
302
  @app.on_event("startup")
303
  async def startup_event():
304
  logger.info("Application starting up...")
 
305
  start_cleanup_task()
306
 
307
-
308
  @app.get("/")
309
  async def root():
310
  """Root endpoint for default health check"""
311
  logger.info("Root endpoint requested")
312
  return {"status": "healthy"}
313
 
314
-
315
  @app.get("/health")
316
  async def health_check():
317
  """Health check endpoint that always returns successfully"""
 
318
  logger.info("Health check requested")
319
  return {
320
  "status": "healthy",
321
- "model_status": {
322
- "stt_whisper": model_cache["stt_whisper"]["status"],
323
- "stt_mms": model_cache["stt_mms"]["status"],
324
- "mt": model_cache["mt"]["status"],
325
- "tts": model_cache["tts"]["status"],
326
- "tts_language": model_cache["tts"]["language"]
327
- }
328
  }
329
 
330
-
331
  @app.post("/update-languages")
332
  async def update_languages(source_lang: str = Form(...), target_lang: str = Form(...)):
333
- """
334
- Update the language settings for translation services
335
- Will trigger loading of necessary models if not already loaded
336
- """
337
  if source_lang not in LANGUAGE_MAPPING or target_lang not in LANGUAGE_MAPPING:
338
  raise HTTPException(status_code=400, detail="Invalid language selected")
339
 
340
  source_code = LANGUAGE_MAPPING[source_lang]
341
  target_code = LANGUAGE_MAPPING[target_lang]
342
 
343
- # Determine which STT model to use based on the source language
344
- if source_code in ["eng", "tgl"]:
345
- # Load Whisper for English or Tagalog
346
- if not load_whisper_model():
347
- return {"status": "pending", "message": "Whisper model loading in progress"}
348
- else:
349
- # Load MMS for other Philippine languages
350
- if not load_mms_stt_model():
351
- return {"status": "pending", "message": "MMS STT model loading in progress"}
352
-
353
- # Load the MT model if not already loaded
354
- if not load_mt_model():
355
- return {"status": "pending", "message": "MT model loading in progress"}
356
-
357
- # Load the appropriate TTS model for the target language
358
- if not load_tts_model(target_code):
359
- return {"status": "pending", "message": "TTS model loading in progress"}
360
-
361
- logger.info(f"Languages updated to {source_lang} → {target_lang}")
362
- return {"status": "success", "message": f"Languages updated to {source_lang} → {target_lang}"}
363
-
364
-
365
- @app.post("/synthesize-speech")
366
- async def synthesize_speech(text: str = Form(...), language: str = Form(...)):
367
- """Endpoint to synthesize speech from text without translation"""
368
- if language not in LANGUAGE_MAPPING:
369
- raise HTTPException(status_code=400, detail="Invalid language selected")
370
-
371
- language_code = LANGUAGE_MAPPING[language]
372
- request_id = str(uuid.uuid4())
373
-
374
- # Load the TTS model for the requested language
375
- if not load_tts_model(language_code):
376
- return {
377
- "request_id": request_id,
378
- "status": "pending",
379
- "message": "TTS model loading in progress. Please try again in a moment."
380
- }
381
-
382
  try:
 
 
383
  device = "cuda" if torch.cuda.is_available() else "cpu"
384
- inputs = model_cache["tts"]["tokenizer"](text, return_tensors="pt").to(device)
385
 
386
- with torch.no_grad():
387
- output = model_cache["tts"]["model"](**inputs)
388
-
389
- speech = output.waveform.cpu().numpy().squeeze()
390
- speech = (speech * 32767).astype(np.int16)
391
- sample_rate = model_cache["tts"]["model"].config.sampling_rate
392
-
393
- # Save the audio as a WAV file
394
- output_filename = f"{request_id}.wav"
395
- output_path = os.path.join(AUDIO_DIR, output_filename)
396
- save_pcm_to_wav(speech.tolist(), sample_rate, output_path)
397
- logger.info(f"Saved synthesized audio to {output_path}")
398
-
399
- # Generate a URL to the WAV file
400
- output_audio_url = f"https://jerich-talklasapp.hf.space/audio_output/{output_filename}"
401
-
402
- return {
403
- "request_id": request_id,
404
- "status": "completed",
405
- "message": "Speech synthesis completed successfully",
406
- "text": text,
407
- "output_audio": output_audio_url
408
- }
 
 
 
 
 
 
 
 
 
 
409
 
 
 
 
 
 
 
 
 
 
 
410
  except Exception as e:
411
- logger.error(f"Error during speech synthesis: {str(e)}")
412
- return {
413
- "request_id": request_id,
414
- "status": "failed",
415
- "message": f"Speech synthesis failed: {str(e)}",
416
- "text": text,
417
- "output_audio": None
418
- }
419
-
 
 
 
 
 
 
 
 
420
 
421
  @app.post("/translate-text")
422
  async def translate_text(text: str = Form(...), source_lang: str = Form(...), target_lang: str = Form(...)):
423
  """Endpoint to translate text and convert to speech"""
 
 
424
  if not text:
425
  raise HTTPException(status_code=400, detail="No text provided")
426
  if source_lang not in LANGUAGE_MAPPING or target_lang not in LANGUAGE_MAPPING:
@@ -429,107 +374,100 @@ async def translate_text(text: str = Form(...), source_lang: str = Form(...), ta
429
  logger.info(f"Translate-text requested: {text} from {source_lang} to {target_lang}")
430
  request_id = str(uuid.uuid4())
431
 
432
- # Load the MT model if not already loaded
433
- if not load_mt_model():
434
- return {
435
- "request_id": request_id,
436
- "status": "pending",
437
- "message": "MT model loading in progress. Please try again in a moment.",
438
- "source_text": text,
439
- "translated_text": "Translation not available yet",
440
- "output_audio": None,
441
- "contains_inappropriate_content": False
442
- }
443
-
444
  # Translate the text
445
  source_code = LANGUAGE_MAPPING[source_lang]
446
  target_code = LANGUAGE_MAPPING[target_lang]
447
  translated_text = "Translation not available"
448
- contains_inappropriate = False
449
 
450
- try:
451
- source_nllb_code = NLLB_LANGUAGE_CODES[source_code]
452
- target_nllb_code = NLLB_LANGUAGE_CODES[target_code]
453
- model_cache["mt"]["tokenizer"].src_lang = source_nllb_code
454
- device = "cuda" if torch.cuda.is_available() else "cpu"
455
- inputs = model_cache["mt"]["tokenizer"](text, return_tensors="pt").to(device)
456
- with torch.no_grad():
457
- generated_tokens = model_cache["mt"]["model"].generate(
458
- **inputs,
459
- forced_bos_token_id=model_cache["mt"]["tokenizer"].convert_tokens_to_ids(target_nllb_code),
460
- max_length=448
461
- )
462
- translated_text = model_cache["mt"]["tokenizer"].batch_decode(generated_tokens, skip_special_tokens=True)[0]
463
- logger.info(f"Translation completed: {translated_text}")
464
-
465
- # Check for inappropriate content
466
- contains_inappropriate = detect_inappropriate_content(translated_text)
467
- if contains_inappropriate:
468
- logger.warning(f"Inappropriate content detected in translation")
469
-
470
- except Exception as e:
471
- logger.error(f"Error during translation: {str(e)}")
472
- translated_text = f"Translation failed: {str(e)}"
473
- return {
474
- "request_id": request_id,
475
- "status": "failed",
476
- "message": f"Translation failed: {str(e)}",
477
- "source_text": text,
478
- "translated_text": translated_text,
479
- "output_audio": None,
480
- "contains_inappropriate_content": contains_inappropriate
481
- }
482
 
483
- # Load the TTS model for the target language
484
- if not load_tts_model(target_code):
485
- return {
486
- "request_id": request_id,
487
- "status": "partial",
488
- "message": "Translation completed, but TTS model is loading. Please try again for audio.",
489
- "source_text": text,
490
- "translated_text": translated_text,
491
- "output_audio": None,
492
- "contains_inappropriate_content": contains_inappropriate
493
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
494
 
495
  # Convert translated text to speech
496
  output_audio_url = None
497
- try:
498
- inputs = model_cache["tts"]["tokenizer"](translated_text, return_tensors="pt").to(device)
499
- with torch.no_grad():
500
- output = model_cache["tts"]["model"](**inputs)
501
- speech = output.waveform.cpu().numpy().squeeze()
502
- speech = (speech * 32767).astype(np.int16)
503
- sample_rate = model_cache["tts"]["model"].config.sampling_rate
 
504
 
505
- # Save the audio as a WAV file
506
- output_filename = f"{request_id}.wav"
507
- output_path = os.path.join(AUDIO_DIR, output_filename)
508
- save_pcm_to_wav(speech.tolist(), sample_rate, output_path)
509
- logger.info(f"Saved synthesized audio to {output_path}")
510
 
511
- # Generate a URL to the WAV file
512
- output_audio_url = f"https://jerich-talklasapp.hf.space/audio_output/{output_filename}"
513
- logger.info("TTS conversion completed")
514
- except Exception as e:
515
- logger.error(f"Error during TTS conversion: {str(e)}")
516
- output_audio_url = None
517
 
518
  return {
519
  "request_id": request_id,
520
- "status": "completed" if output_audio_url else "partial",
521
- "message": "Translation and TTS completed" if output_audio_url else
522
- "Translation completed but TTS failed",
523
  "source_text": text,
524
  "translated_text": translated_text,
525
- "output_audio": output_audio_url,
526
- "contains_inappropriate_content": contains_inappropriate
527
  }
528
 
529
-
530
  @app.post("/translate-audio")
531
  async def translate_audio(audio: UploadFile = File(...), source_lang: str = Form(...), target_lang: str = Form(...)):
532
  """Endpoint to transcribe, translate, and convert audio to speech"""
 
 
533
  if not audio:
534
  raise HTTPException(status_code=400, detail="No audio file provided")
535
  if source_lang not in LANGUAGE_MAPPING or target_lang not in LANGUAGE_MAPPING:
@@ -538,35 +476,18 @@ async def translate_audio(audio: UploadFile = File(...), source_lang: str = Form
538
  logger.info(f"Translate-audio requested: {audio.filename} from {source_lang} to {target_lang}")
539
  request_id = str(uuid.uuid4())
540
 
541
- source_code = LANGUAGE_MAPPING[source_lang]
542
- target_code = LANGUAGE_MAPPING[target_lang]
543
-
544
- # Determine which STT model to use based on source language
545
- use_whisper = source_code in ["eng", "tgl"]
546
-
547
- # Ensure the appropriate STT model is loaded
548
- if use_whisper:
549
- if not load_whisper_model():
550
- return {
551
- "request_id": request_id,
552
- "status": "pending",
553
- "message": "Whisper STT model loading in progress. Please try again in a moment.",
554
- "source_text": "Transcription not available yet",
555
- "translated_text": "Translation not available yet",
556
- "output_audio": None,
557
- "contains_inappropriate_content": False
558
- }
559
- else:
560
- if not load_mms_stt_model():
561
- return {
562
- "request_id": request_id,
563
- "status": "pending",
564
- "message": "MMS STT model loading in progress. Please try again in a moment.",
565
- "source_text": "Transcription not available yet",
566
- "translated_text": "Translation not available yet",
567
- "output_audio": None,
568
- "contains_inappropriate_content": False
569
- }
570
 
571
  # Save the uploaded audio to a temporary file
572
  with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as temp_file:
@@ -576,7 +497,7 @@ async def translate_audio(audio: UploadFile = File(...), source_lang: str = Form
576
  transcription = "Transcription not available"
577
  translated_text = "Translation not available"
578
  output_audio_url = None
579
- contains_inappropriate = False
580
 
581
  try:
582
  # Step 1: Load and resample the audio using torchaudio
@@ -599,133 +520,112 @@ async def translate_audio(audio: UploadFile = File(...), source_lang: str = Form
599
  "message": "No speech detected in the audio.",
600
  "source_text": "No speech detected",
601
  "translated_text": "No translation available",
602
- "output_audio": None,
603
- "contains_inappropriate_content": False
604
  }
605
 
606
  # Step 3: Transcribe the audio (STT)
607
  device = "cuda" if torch.cuda.is_available() else "cpu"
608
  logger.info(f"Using device: {device}")
 
 
609
 
610
- if use_whisper:
611
- # Use Whisper for English/Tagalog
612
- stt_processor = model_cache["stt_whisper"]["processor"]
613
- stt_model = model_cache["stt_whisper"]["model"]
614
-
615
- inputs = stt_processor(waveform.numpy(), sampling_rate=16000, return_tensors="pt").to(device)
616
- logger.info("Audio processed with Whisper, generating transcription...")
617
-
618
- with torch.no_grad():
619
- generated_ids = stt_model.generate(**inputs, language="en" if source_code == "eng" else "tl")
620
  transcription = stt_processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
621
- else:
622
- # Use MMS for other Philippine languages
623
- stt_processor = model_cache["stt_mms"]["processor"]
624
- stt_model = model_cache["stt_mms"]["model"]
625
-
626
- # Set the target language for MMS if supported
627
- if source_code in stt_processor.tokenizer.vocab.keys():
628
- stt_processor.tokenizer.set_target_lang(source_code)
629
- stt_model.load_adapter(source_code)
630
-
631
- inputs = stt_processor(waveform.numpy(), sampling_rate=16000, return_tensors="pt").to(device)
632
- logger.info("Audio processed with MMS, generating transcription...")
633
-
634
- with torch.no_grad():
635
  logits = stt_model(**inputs).logits
636
  predicted_ids = torch.argmax(logits, dim=-1)
637
  transcription = stt_processor.batch_decode(predicted_ids)[0]
638
-
639
  logger.info(f"Transcription completed: {transcription}")
640
 
641
- # Step 4: Load the MT model if not already loaded
642
- if not load_mt_model():
643
- return {
644
- "request_id": request_id,
645
- "status": "partial",
646
- "message": "Transcription completed, but MT model is loading. Please try again for translation.",
647
- "source_text": transcription,
648
- "translated_text": "Translation not available yet",
649
- "output_audio": None,
650
- "contains_inappropriate_content": False
651
- }
652
-
653
- # Step 5: Translate the transcribed text (MT)
654
- try:
655
- source_nllb_code = NLLB_LANGUAGE_CODES[source_code]
656
- target_nllb_code = NLLB_LANGUAGE_CODES[target_code]
657
- model_cache["mt"]["tokenizer"].src_lang = source_nllb_code
658
-
659
- inputs = model_cache["mt"]["tokenizer"](transcription, return_tensors="pt").to(device)
660
- with torch.no_grad():
661
- generated_tokens = model_cache["mt"]["model"].generate(
662
- **inputs,
663
- forced_bos_token_id=model_cache["mt"]["tokenizer"].convert_tokens_to_ids(target_nllb_code),
664
- max_length=448
665
- )
666
- translated_text = model_cache["mt"]["tokenizer"].batch_decode(generated_tokens, skip_special_tokens=True)[0]
667
- logger.info(f"Translation completed: {translated_text}")
668
-
669
- # Check for inappropriate content
670
- contains_inappropriate = detect_inappropriate_content(translated_text)
671
- if contains_inappropriate:
672
- logger.warning(f"Inappropriate content detected in translation")
673
 
674
- except Exception as e:
675
- logger.error(f"Error during translation: {str(e)}")
676
- translated_text = f"Translation failed: {str(e)}"
677
- return {
678
- "request_id": request_id,
679
- "status": "partial",
680
- "message": f"Transcription completed but translation failed: {str(e)}",
681
- "source_text": transcription,
682
- "translated_text": translated_text,
683
- "output_audio": None,
684
- "contains_inappropriate_content": False
685
- }
686
-
687
- # Step 6: Load the TTS model for the target language
688
- if not load_tts_model(target_code):
689
- return {
690
- "request_id": request_id,
691
- "status": "partial",
692
- "message": "Transcription and translation completed, but TTS model is loading.",
693
- "source_text": transcription,
694
- "translated_text": translated_text,
695
- "output_audio": None,
696
- "contains_inappropriate_content": contains_inappropriate
697
- }
698
-
699
- # Step 7: Convert translated text to speech (TTS)
700
- try:
701
- inputs = model_cache["tts"]["tokenizer"](translated_text, return_tensors="pt").to(device)
702
- with torch.no_grad():
703
- output = model_cache["tts"]["model"](**inputs)
704
- speech = output.waveform.cpu().numpy().squeeze()
705
- speech = (speech * 32767).astype(np.int16)
706
- sample_rate = model_cache["tts"]["model"].config.sampling_rate
707
- # Save the audio as a WAV file
708
- output_filename = f"{request_id}.wav"
709
- output_path = os.path.join(AUDIO_DIR, output_filename)
710
- save_pcm_to_wav(speech.tolist(), sample_rate, output_path)
711
- logger.info(f"Saved synthesized audio to {output_path}")
712
 
713
- # Generate a URL to the WAV file
714
- output_audio_url = f"https://jerich-talklasapp.hf.space/audio_output/{output_filename}"
715
- logger.info("TTS conversion completed")
716
- except Exception as e:
717
- logger.error(f"Error during TTS conversion: {str(e)}")
718
- output_audio_url = None
719
-
720
- return {
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
721
  "request_id": request_id,
722
- "status": "completed" if output_audio_url else "partial",
723
- "message": "Transcription, translation, and TTS completed" if output_audio_url else
724
- "Transcription and translation completed but TTS failed",
725
  "source_text": transcription,
726
  "translated_text": translated_text,
727
- "output_audio": output_audio_url,
728
- "contains_inappropriate_content": contains_inappropriate
729
  }
730
  except Exception as e:
731
  logger.error(f"Error during processing: {str(e)}")
@@ -735,29 +635,113 @@ async def translate_audio(audio: UploadFile = File(...), source_lang: str = Form
735
  "message": f"Processing failed: {str(e)}",
736
  "source_text": transcription,
737
  "translated_text": translated_text,
738
- "output_audio": output_audio_url,
739
- "contains_inappropriate_content": contains_inappropriate
740
  }
741
  finally:
742
  logger.info(f"Cleaning up temporary file: {temp_path}")
743
- try:
744
- os.unlink(temp_path)
745
- except Exception as e:
746
- logger.error(f"Error deleting temporary file: {str(e)}")
747
-
748
 
749
- # Add a method to check if text contains inappropriate content
750
- @app.post("/check-content")
751
- async def check_content(text: str = Form(...)):
752
- """
753
- Check if the provided text contains inappropriate content
754
- """
755
- contains_inappropriate = detect_inappropriate_content(text)
756
- return {
757
- "text": text,
758
- "contains_inappropriate_content": contains_inappropriate
759
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
760
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
761
 
762
  if __name__ == "__main__":
763
  import uvicorn
 
31
  app.mount("/audio_output", StaticFiles(directory=AUDIO_DIR), name="audio_output")
32
 
33
  # Global variables to track application state
34
+ models_loaded = False
35
+ loading_in_progress = False
36
+ loading_thread = None
37
+ model_status = {
38
+ "stt": "not_loaded",
39
+ "mt": "not_loaded",
40
+ "tts": "not_loaded"
41
  }
42
+ error_message = None
43
+ current_tts_language = "tgl" # Track the current TTS language
44
 
45
+ # Model instances
46
+ stt_processor = None
47
+ stt_model = None
48
+ mt_model = None
49
+ mt_tokenizer = None
50
+ tts_model = None
51
+ tts_tokenizer = None
52
 
53
  # Define the valid languages and mappings
54
  LANGUAGE_MAPPING = {
 
69
  "pag": "pag_Latn"
70
  }
71
 
72
+ # Define a list of inappropriate words for content filtering
 
73
  INAPPROPRIATE_WORDS = [
74
+ "profanity", "obscenity", "obscene", "offensive", "vulgar", "explicit",
75
+ # Add more words as needed or load from a separate file
76
  ]
77
 
78
+ # Function to check if text contains inappropriate content
79
+ def check_inappropriate_content(text: str) -> bool:
 
80
  """
81
+ Check if the given text contains inappropriate words.
82
+ Returns True if inappropriate content is detected, False otherwise.
83
  """
84
+ if not text:
85
+ return False
86
+
87
  text_lower = text.lower()
88
  for word in INAPPROPRIATE_WORDS:
89
+ if re.search(r'\b' + re.escape(word) + r'\b', text_lower):
90
+ logger.warning(f"Inappropriate content detected: '{word}' in text")
91
  return True
92
  return False
93
 
 
94
  # Function to save PCM data as a WAV file
95
  def save_pcm_to_wav(pcm_data: list, sample_rate: int, output_path: str):
96
  # Convert pcm_data to a NumPy array of 16-bit integers
 
104
  # Write the 16-bit PCM data as bytes (little-endian)
105
  wav_file.writeframes(pcm_array.tobytes())
106
 
 
107
  # Function to detect speech using an energy-based approach
108
  def detect_speech(waveform: torch.Tensor, sample_rate: int, threshold: float = 0.01, min_speech_duration: float = 0.5) -> bool:
109
  """
 
128
  # For now, we assume if RMS is above threshold, there is speech
129
  return True
130
 
 
131
  # Function to clean up old audio files
132
  def cleanup_old_audio_files():
133
  logger.info("Starting cleanup of old audio files...")
 
143
  except Exception as e:
144
  logger.error(f"Error deleting file {file_path}: {str(e)}")
145
 
 
146
  # Background task to periodically clean up audio files
147
  def schedule_cleanup():
148
  while True:
149
  cleanup_old_audio_files()
150
  time.sleep(300) # Run every 5 minutes (300 seconds)
151
 
152
+ # Function to load models in background
153
def load_models_task():
    """Load the STT, MT and TTS models and record progress in module globals.

    The three stages run in order and the function stops at the first stage
    whose primary and fallback loads both fail:

    * STT: ``facebook/mms-1b-all``, falling back to ``openai/whisper-tiny``.
    * MT: ``facebook/nllb-200-distilled-600M`` (no fallback).
    * TTS: ``facebook/mms-tts-tgl``, falling back to ``facebook/mms-tts-eng``.

    Per-stage state is written to ``model_status`` and a failure description
    to ``error_message``. ``models_loaded`` becomes True only when all three
    stages succeed; ``loading_in_progress`` is always reset to False on exit.
    """
    global models_loaded, loading_in_progress, model_status, error_message
    global stt_processor, stt_model, mt_model, mt_tokenizer, tts_model, tts_tokenizer
    # Declared global so the language set below is visible to the endpoints
    # that compare it against the requested target language.
    global current_tts_language

    try:
        loading_in_progress = True

        # Resolve the device once, before any load is attempted, so every
        # fallback path can rely on it even when the primary load fails early.
        device = "cuda" if torch.cuda.is_available() else "cpu"

        # Load STT model (MMS with fallback to Whisper)
        logger.info("Starting to load STT model...")
        from transformers import AutoProcessor, AutoModelForCTC, WhisperProcessor, WhisperForConditionalGeneration

        try:
            logger.info("Loading MMS STT model...")
            model_status["stt"] = "loading"
            stt_processor = AutoProcessor.from_pretrained("facebook/mms-1b-all")
            stt_model = AutoModelForCTC.from_pretrained("facebook/mms-1b-all")
            stt_model.to(device)
            logger.info("MMS STT model loaded successfully")
            model_status["stt"] = "loaded_mms"
        except Exception as mms_error:
            logger.error(f"Failed to load MMS STT model: {str(mms_error)}")
            logger.info("Falling back to Whisper STT model...")
            try:
                stt_processor = WhisperProcessor.from_pretrained("openai/whisper-tiny")
                stt_model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")
                stt_model.to(device)
                logger.info("Whisper STT model loaded successfully as fallback")
                model_status["stt"] = "loaded_whisper"
            except Exception as whisper_error:
                logger.error(f"Failed to load Whisper STT model: {str(whisper_error)}")
                model_status["stt"] = "failed"
                error_message = f"STT model loading failed: MMS error: {str(mms_error)}, Whisper error: {str(whisper_error)}"
                return

        # Load MT model
        logger.info("Starting to load MT model...")
        from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

        try:
            logger.info("Loading NLLB-200-distilled-600M model...")
            model_status["mt"] = "loading"
            mt_model = AutoModelForSeq2SeqLM.from_pretrained("facebook/nllb-200-distilled-600M")
            mt_tokenizer = AutoTokenizer.from_pretrained("facebook/nllb-200-distilled-600M")
            mt_model.to(device)
            logger.info("MT model loaded successfully")
            model_status["mt"] = "loaded"
        except Exception as e:
            logger.error(f"Failed to load MT model: {str(e)}")
            model_status["mt"] = "failed"
            error_message = f"MT model loading failed: {str(e)}"
            return

        # Load TTS model (default to Tagalog, will be updated dynamically)
        logger.info("Starting to load TTS model...")
        from transformers import VitsModel, AutoTokenizer

        try:
            logger.info("Loading MMS-TTS model for Tagalog...")
            model_status["tts"] = "loading"
            tts_model = VitsModel.from_pretrained("facebook/mms-tts-tgl")
            tts_tokenizer = AutoTokenizer.from_pretrained("facebook/mms-tts-tgl")
            tts_model.to(device)
            # Record which language the loaded TTS model actually speaks.
            current_tts_language = "tgl"
            logger.info("TTS model loaded successfully")
            model_status["tts"] = "loaded"
        except Exception as e:
            logger.error(f"Failed to load TTS model for Tagalog: {str(e)}")
            # Fallback to English TTS if the target language fails
            try:
                logger.info("Falling back to MMS-TTS English model...")
                tts_model = VitsModel.from_pretrained("facebook/mms-tts-eng")
                tts_tokenizer = AutoTokenizer.from_pretrained("facebook/mms-tts-eng")
                tts_model.to(device)
                logger.info("Fallback TTS model loaded successfully")
                model_status["tts"] = "loaded (fallback)"
                current_tts_language = "eng"
            except Exception as e2:
                logger.error(f"Failed to load fallback TTS model: {str(e2)}")
                model_status["tts"] = "failed"
                error_message = f"TTS model loading failed: {str(e)} (fallback also failed: {str(e2)})"
                return

        models_loaded = True
        logger.info("Model loading completed successfully")

    except Exception as e:
        error_message = str(e)
        logger.error(f"Error in model loading task: {str(e)}")
    finally:
        # Runs on every exit path, including the early returns above.
        loading_in_progress = False
243
 
244
+ # Start loading models in background
245
def start_model_loading():
    """Launch ``load_models_task`` on a daemon thread.

    Does nothing when a load is already running or the models are already
    loaded. The thread object is kept in the ``loading_thread`` global.
    """
    global loading_thread, loading_in_progress
    if loading_in_progress or models_loaded:
        return

    # Flag the load as running before the thread starts.
    loading_in_progress = True
    worker = threading.Thread(target=load_models_task, daemon=True)
    loading_thread = worker
    worker.start()
252
 
253
  # Start the background cleanup task
254
  def start_cleanup_task():
 
256
  cleanup_thread.daemon = True
257
  cleanup_thread.start()
258
 
 
259
  # Start the background processes when the app starts
260
@app.on_event("startup")
async def startup_event():
    """Start the model-loading and audio-cleanup workers at app startup."""
    logger.info("Application starting up...")
    # Model loading is launched first, then the cleanup task.
    for launch in (start_model_loading, start_cleanup_task):
        launch()
265
 
 
266
@app.get("/")
async def root():
    """Default health-check endpoint; returns a fixed healthy status."""
    logger.info("Root endpoint requested")
    payload = {"status": "healthy"}
    return payload
271
 
 
272
@app.get("/health")
async def health_check():
    """Return a healthy status together with the current model-loading state.

    The response carries the ``models_loaded`` and ``loading_in_progress``
    flags, the per-model ``model_status`` mapping, and the last recorded
    ``error_message`` under the key ``"error"``.
    """
    logger.info("Health check requested")
    report = {"status": "healthy"}
    # Keys are added in the same order as the response has always used.
    report.update(
        models_loaded=models_loaded,
        loading_in_progress=loading_in_progress,
        model_status=model_status,
        error=error_message,
    )
    return report
284
 
 
285
@app.post("/update-languages")
async def update_languages(source_lang: str = Form(...), target_lang: str = Form(...)):
    """Reload the STT model for ``source_lang`` and the TTS model for ``target_lang``.

    STT tries ``facebook/mms-1b-all`` (selecting the language adapter when the
    tokenizer knows the language code) and falls back to
    ``openai/whisper-tiny``. TTS tries ``facebook/mms-tts-<code>`` and falls
    back to ``facebook/mms-tts-eng``.

    Each processor/model (or model/tokenizer) pair is loaded into local
    variables and assigned to the module globals only once the whole pair has
    loaded. A failure part-way through therefore never leaves the globals
    holding halves of two different models.

    Raises:
        HTTPException: 400 when either language is not in ``LANGUAGE_MAPPING``.

    Returns:
        ``{"status": "Languages updated to ..."}`` on success, or
        ``{"status": "failed", "error": ...}`` when a model and its fallback
        both fail to load.
    """
    global stt_processor, stt_model, tts_model, tts_tokenizer, current_tts_language

    if source_lang not in LANGUAGE_MAPPING or target_lang not in LANGUAGE_MAPPING:
        raise HTTPException(status_code=400, detail="Invalid language selected")

    source_code = LANGUAGE_MAPPING[source_lang]
    target_code = LANGUAGE_MAPPING[target_lang]
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Update the STT model based on the source language (MMS or Whisper)
    try:
        logger.info("Updating STT model for source language...")
        from transformers import AutoProcessor, AutoModelForCTC, WhisperProcessor, WhisperForConditionalGeneration

        try:
            logger.info(f"Loading MMS STT model for {source_code}...")
            new_processor = AutoProcessor.from_pretrained("facebook/mms-1b-all")
            new_model = AutoModelForCTC.from_pretrained("facebook/mms-1b-all")
            new_model.to(device)
            # Set the target language for MMS
            if source_code in new_processor.tokenizer.vocab:
                new_processor.tokenizer.set_target_lang(source_code)
                new_model.load_adapter(source_code)
                logger.info(f"MMS STT model updated to {source_code}")
                new_status = "loaded_mms"
            else:
                logger.warning(f"Language {source_code} not supported by MMS, using default")
                new_status = "loaded_mms_default"
        except Exception as mms_error:
            logger.error(f"Failed to load MMS STT model for {source_code}: {str(mms_error)}")
            logger.info("Falling back to Whisper STT model...")
            try:
                new_processor = WhisperProcessor.from_pretrained("openai/whisper-tiny")
                new_model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")
                new_model.to(device)
                logger.info("Whisper STT model loaded successfully as fallback")
                new_status = "loaded_whisper"
            except Exception as whisper_error:
                logger.error(f"Failed to load Whisper STT model: {str(whisper_error)}")
                model_status["stt"] = "failed"
                error_message = f"STT model update failed: MMS error: {str(mms_error)}, Whisper error: {str(whisper_error)}"
                return {"status": "failed", "error": error_message}

        # Publish the matched pair only after both halves loaded.
        stt_processor, stt_model = new_processor, new_model
        model_status["stt"] = new_status
    except Exception as e:
        logger.error(f"Error updating STT model: {str(e)}")
        model_status["stt"] = "failed"
        error_message = f"STT model update failed: {str(e)}"
        return {"status": "failed", "error": error_message}

    # Update the TTS model based on the target language
    try:
        logger.info(f"Loading MMS-TTS model for {target_code}...")
        from transformers import VitsModel, AutoTokenizer
        new_tts_model = VitsModel.from_pretrained(f"facebook/mms-tts-{target_code}")
        new_tts_tokenizer = AutoTokenizer.from_pretrained(f"facebook/mms-tts-{target_code}")
        new_tts_model.to(device)
        tts_model, tts_tokenizer = new_tts_model, new_tts_tokenizer
        current_tts_language = target_code
        logger.info(f"TTS model updated to {target_code}")
        model_status["tts"] = "loaded"
    except Exception as e:
        logger.error(f"Failed to load TTS model for {target_code}: {str(e)}")
        try:
            logger.info("Falling back to MMS-TTS English model...")
            new_tts_model = VitsModel.from_pretrained("facebook/mms-tts-eng")
            new_tts_tokenizer = AutoTokenizer.from_pretrained("facebook/mms-tts-eng")
            new_tts_model.to(device)
            tts_model, tts_tokenizer = new_tts_model, new_tts_tokenizer
            current_tts_language = "eng"
            logger.info("Fallback TTS model loaded successfully")
            model_status["tts"] = "loaded (fallback)"
        except Exception as e2:
            logger.error(f"Failed to load fallback TTS model: {str(e2)}")
            model_status["tts"] = "failed"
            error_message = f"TTS model loading failed: {str(e)} (fallback also failed: {str(e2)})"
            return {"status": "failed", "error": error_message}

    logger.info(f"Updating languages: {source_lang} → {target_lang}")
    return {"status": f"Languages updated to {source_lang} → {target_lang}"}
363
 
364
@app.post("/translate-text")
async def translate_text(text: str = Form(...), source_lang: str = Form(...), target_lang: str = Form(...)):
    """Endpoint to translate text and convert to speech.

    Translates ``text`` with the NLLB model when it is loaded, checks the
    result with ``check_inappropriate_content``, switches the TTS model to the
    target language if needed (falling back to English), then synthesizes the
    translated text to a WAV file under ``AUDIO_DIR``.

    Raises:
        HTTPException: 400 when ``text`` is empty or a language is not in
            ``LANGUAGE_MAPPING``.

    Returns:
        A dict with ``request_id``, ``status``, ``message``, ``source_text``,
        ``translated_text``, ``is_inappropriate`` and ``output_audio`` (a URL,
        or None when synthesis was skipped or failed).
    """
    global mt_model, mt_tokenizer, tts_model, tts_tokenizer, current_tts_language

    if not text:
        raise HTTPException(status_code=400, detail="No text provided")
    if source_lang not in LANGUAGE_MAPPING or target_lang not in LANGUAGE_MAPPING:
        raise HTTPException(status_code=400, detail="Invalid language selected")

    logger.info(f"Translate-text requested: {text} from {source_lang} to {target_lang}")
    request_id = str(uuid.uuid4())

    # Translate the text
    source_code = LANGUAGE_MAPPING[source_lang]
    target_code = LANGUAGE_MAPPING[target_lang]
    translated_text = "Translation not available"

    # Resolve the device up front: the TTS steps below need it even when the
    # MT branch is skipped or fails before reaching its own device lookup.
    device = "cuda" if torch.cuda.is_available() else "cpu"

    if model_status["mt"] == "loaded" and mt_model is not None and mt_tokenizer is not None:
        try:
            source_nllb_code = NLLB_LANGUAGE_CODES[source_code]
            target_nllb_code = NLLB_LANGUAGE_CODES[target_code]
            mt_tokenizer.src_lang = source_nllb_code
            inputs = mt_tokenizer(text, return_tensors="pt").to(device)
            with torch.no_grad():
                generated_tokens = mt_model.generate(
                    **inputs,
                    forced_bos_token_id=mt_tokenizer.convert_tokens_to_ids(target_nllb_code),
                    max_length=448
                )
            translated_text = mt_tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0]
            logger.info(f"Translation completed: {translated_text}")
        except Exception as e:
            logger.error(f"Error during translation: {str(e)}")
            translated_text = f"Translation failed: {str(e)}"
    else:
        logger.warning("MT model not loaded, skipping translation")

    # Check for inappropriate content in the translated text
    is_inappropriate = check_inappropriate_content(translated_text)
    logger.info(f"Inappropriate content check: {is_inappropriate}")

    # Update TTS model if the target language doesn't match the current TTS language
    if current_tts_language != target_code:
        try:
            logger.info(f"Updating TTS model for {target_code}...")
            from transformers import VitsModel, AutoTokenizer
            tts_model = VitsModel.from_pretrained(f"facebook/mms-tts-{target_code}")
            tts_tokenizer = AutoTokenizer.from_pretrained(f"facebook/mms-tts-{target_code}")
            tts_model.to(device)
            current_tts_language = target_code
            logger.info(f"TTS model updated to {target_code}")
            model_status["tts"] = "loaded"
        except Exception as e:
            logger.error(f"Failed to load TTS model for {target_code}: {str(e)}")
            try:
                logger.info("Falling back to MMS-TTS English model...")
                tts_model = VitsModel.from_pretrained("facebook/mms-tts-eng")
                tts_tokenizer = AutoTokenizer.from_pretrained("facebook/mms-tts-eng")
                tts_model.to(device)
                current_tts_language = "eng"
                logger.info("Fallback TTS model loaded successfully")
                model_status["tts"] = "loaded (fallback)"
            except Exception as e2:
                logger.error(f"Failed to load fallback TTS model: {str(e2)}")
                model_status["tts"] = "failed"

    # Convert translated text to speech
    output_audio_url = None
    if model_status["tts"].startswith("loaded") and tts_model is not None and tts_tokenizer is not None:
        try:
            inputs = tts_tokenizer(translated_text, return_tensors="pt").to(device)
            with torch.no_grad():
                output = tts_model(**inputs)
            speech = output.waveform.cpu().numpy().squeeze()
            # Scale the waveform to the 16-bit integer range for WAV output.
            speech = (speech * 32767).astype(np.int16)
            sample_rate = tts_model.config.sampling_rate

            # Save the audio as a WAV file
            output_filename = f"{request_id}.wav"
            output_path = os.path.join(AUDIO_DIR, output_filename)
            save_pcm_to_wav(speech.tolist(), sample_rate, output_path)
            logger.info(f"Saved synthesized audio to {output_path}")

            # Generate a URL to the WAV file
            output_audio_url = f"https://jerich-talklasapp.hf.space/audio_output/{output_filename}"
            logger.info("TTS conversion completed")
        except Exception as e:
            logger.error(f"Error during TTS conversion: {str(e)}")
            output_audio_url = None

    return {
        "request_id": request_id,
        "status": "completed",
        "message": "Translation and TTS completed (or partially completed).",
        "source_text": text,
        "translated_text": translated_text,
        "is_inappropriate": is_inappropriate,
        "output_audio": output_audio_url
    }
465
 
 
466
  @app.post("/translate-audio")
467
  async def translate_audio(audio: UploadFile = File(...), source_lang: str = Form(...), target_lang: str = Form(...)):
468
  """Endpoint to transcribe, translate, and convert audio to speech"""
469
+ global stt_processor, stt_model, mt_model, mt_tokenizer, tts_model, tts_tokenizer, current_tts_language
470
+
471
  if not audio:
472
  raise HTTPException(status_code=400, detail="No audio file provided")
473
  if source_lang not in LANGUAGE_MAPPING or target_lang not in LANGUAGE_MAPPING:
 
476
  logger.info(f"Translate-audio requested: {audio.filename} from {source_lang} to {target_lang}")
477
  request_id = str(uuid.uuid4())
478
 
479
+ # Check if STT model is loaded
480
+ if model_status["stt"] not in ["loaded_mms", "loaded_mms_default", "loaded_whisper"] or stt_processor is None or stt_model is None:
481
+ logger.warning("STT model not loaded, returning placeholder response")
482
+ return {
483
+ "request_id": request_id,
484
+ "status": "processing",
485
+ "message": "STT model not loaded yet. Please try again later.",
486
+ "source_text": "Transcription not available",
487
+ "translated_text": "Translation not available",
488
+ "is_inappropriate": False,
489
+ "output_audio": None
490
+ }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
491
 
492
  # Save the uploaded audio to a temporary file
493
  with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as temp_file:
 
497
  transcription = "Transcription not available"
498
  translated_text = "Translation not available"
499
  output_audio_url = None
500
+ is_inappropriate = False
501
 
502
  try:
503
  # Step 1: Load and resample the audio using torchaudio
 
520
  "message": "No speech detected in the audio.",
521
  "source_text": "No speech detected",
522
  "translated_text": "No translation available",
523
+ "is_inappropriate": False,
524
+ "output_audio": None
525
  }
526
 
527
  # Step 3: Transcribe the audio (STT)
528
  device = "cuda" if torch.cuda.is_available() else "cpu"
529
  logger.info(f"Using device: {device}")
530
+ inputs = stt_processor(waveform.numpy(), sampling_rate=16000, return_tensors="pt").to(device)
531
+ logger.info("Audio processed, generating transcription...")
532
 
533
+ with torch.no_grad():
534
+ if model_status["stt"] == "loaded_whisper":
535
+ # Whisper model
536
+ generated_ids = stt_model.generate(**inputs, language="en")
 
 
 
 
 
 
537
  transcription = stt_processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
538
+ else:
539
+ # MMS model
 
 
 
 
 
 
 
 
 
 
 
 
540
  logits = stt_model(**inputs).logits
541
  predicted_ids = torch.argmax(logits, dim=-1)
542
  transcription = stt_processor.batch_decode(predicted_ids)[0]
 
543
  logger.info(f"Transcription completed: {transcription}")
544
 
545
+ # Step 4: Translate the transcribed text (MT)
546
+ source_code = LANGUAGE_MAPPING[source_lang]
547
+ target_code = LANGUAGE_MAPPING[target_lang]
548
+
549
+ if model_status["mt"] == "loaded" and mt_model is not None and mt_tokenizer is not None:
550
+ try:
551
+ source_nllb_code = NLLB_LANGUAGE_CODES[source_code]
552
+ target_nllb_code = NLLB_LANGUAGE_CODES[target_code]
553
+ mt_tokenizer.src_lang = source_nllb_code
554
+ inputs = mt_tokenizer(transcription, return_tensors="pt").to(device)
555
+ with torch.no_grad():
556
+ generated_tokens = mt_model.generate(
557
+ **inputs,
558
+ forced_bos_token_id=mt_tokenizer.convert_tokens_to_ids(target_nllb_code),
559
+ max_length=448
560
+ )
561
+ translated_text = mt_tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0]
562
+ logger.info(f"Translation completed: {translated_text}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
563
 
564
+ # Check for inappropriate content in the translated text
565
+ is_inappropriate = check_inappropriate_content(translated_text)
566
+ logger.info(f"Inappropriate content check: {is_inappropriate}")
567
+
568
+ except Exception as e:
569
+ logger.error(f"Error during translation: {str(e)}")
570
+ translated_text = f"Translation failed: {str(e)}"
571
+ else:
572
+ logger.warning("MT model not loaded, skipping translation")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
573
 
574
+ # Step 5: Update TTS model if the target language doesn't match the current TTS language
575
+ if current_tts_language != target_code:
576
+ try:
577
+ logger.info(f"Updating TTS model for {target_code}...")
578
+ from transformers import VitsModel, AutoTokenizer
579
+ tts_model = VitsModel.from_pretrained(f"facebook/mms-tts-{target_code}")
580
+ tts_tokenizer = AutoTokenizer.from_pretrained(f"facebook/mms-tts-{target_code}")
581
+ tts_model.to(device)
582
+ current_tts_language = target_code
583
+ logger.info(f"TTS model updated to {target_code}")
584
+ model_status["tts"] = "loaded"
585
+ except Exception as e:
586
+ logger.error(f"Failed to load TTS model for {target_code}: {str(e)}")
587
+ try:
588
+ logger.info("Falling back to MMS-TTS English model...")
589
+ tts_model = VitsModel.from_pretrained("facebook/mms-tts-eng")
590
+ tts_tokenizer = AutoTokenizer.from_pretrained("facebook/mms-tts-eng")
591
+ tts_model.to(device)
592
+ current_tts_language = "eng"
593
+ logger.info("Fallback TTS model loaded successfully")
594
+ model_status["tts"] = "loaded (fallback)"
595
+ except Exception as e2:
596
+ logger.error(f"Failed to load fallback TTS model: {str(e2)}")
597
+ model_status["tts"] = "failed"
598
+
599
+ # Step 6: Convert translated text to speech (TTS)
600
+ if model_status["tts"].startswith("loaded") and tts_model is not None and tts_tokenizer is not None:
601
+ try:
602
+ inputs = tts_tokenizer(translated_text, return_tensors="pt").to(device)
603
+ with torch.no_grad():
604
+ output = tts_model(**inputs)
605
+ speech = output.waveform.cpu().numpy().squeeze()
606
+ speech = (speech * 32767).astype(np.int16)
607
+ sample_rate = tts_model.config.sampling_rate
608
+
609
+ # Save the audio as a WAV file
610
+ output_filename = f"{request_id}.wav"
611
+ output_path = os.path.join(AUDIO_DIR, output_filename)
612
+ save_pcm_to_wav(speech.tolist(), sample_rate, output_path)
613
+ logger.info(f"Saved synthesized audio to {output_path}")
614
+
615
+ # Generate a URL to the WAV file
616
+ output_audio_url = f"https://jerich-talklasapp.hf.space/audio_output/{output_filename}"
617
+ logger.info("TTS conversion completed")
618
+ except Exception as e:
619
+ logger.error(f"Error during TTS conversion: {str(e)}")
620
+ output_audio_url = None
621
+ return {
622
  "request_id": request_id,
623
+ "status": "completed",
624
+ "message": "Transcription, translation, and TTS completed (or partially completed).",
 
625
  "source_text": transcription,
626
  "translated_text": translated_text,
627
+ "is_inappropriate": is_inappropriate,
628
+ "output_audio": output_audio_url
629
  }
630
  except Exception as e:
631
  logger.error(f"Error during processing: {str(e)}")
 
635
  "message": f"Processing failed: {str(e)}",
636
  "source_text": transcription,
637
  "translated_text": translated_text,
638
+ "is_inappropriate": is_inappropriate,
639
+ "output_audio": output_audio_url
640
  }
641
  finally:
642
  logger.info(f"Cleaning up temporary file: {temp_path}")
643
+ os.unlink(temp_path)
 
 
 
 
644
 
645
@app.post("/synthesize-speech")
async def synthesize_speech(text: str = Form(...), target_lang: str = Form(...)):
    """Synthesize ``text`` as speech in ``target_lang`` and return a WAV URL.

    Switches the TTS model to the target language first when it differs from
    ``current_tts_language``, falling back to the English model if that load
    fails. The audio is written to ``AUDIO_DIR`` as ``<request_id>.wav``.

    Raises:
        HTTPException: 400 when ``text`` is empty or ``target_lang`` is not in
            ``LANGUAGE_MAPPING``.

    Returns:
        A dict whose ``status`` is ``"completed"``, ``"processing"`` (TTS
        model not loaded) or ``"failed"``; ``output_audio`` holds the URL on
        success and None otherwise.
    """
    global tts_model, tts_tokenizer, current_tts_language

    if not text:
        raise HTTPException(status_code=400, detail="No text provided")
    if target_lang not in LANGUAGE_MAPPING:
        raise HTTPException(status_code=400, detail="Invalid language selected")

    logger.info(f"Synthesize-speech requested: '{text}' in {target_lang}")
    request_id = str(uuid.uuid4())

    # Bail out early when no usable TTS model is available.
    tts_ready = (
        model_status["tts"].startswith("loaded")
        and tts_model is not None
        and tts_tokenizer is not None
    )
    if not tts_ready:
        logger.warning("TTS model not loaded, returning error response")
        return {
            "request_id": request_id,
            "status": "processing",
            "message": "TTS model not loaded yet. Please try again later.",
            "output_audio": None
        }

    target_code = LANGUAGE_MAPPING[target_lang]

    try:
        device = "cuda" if torch.cuda.is_available() else "cpu"

        # Swap in the model for the requested language when it is not current.
        if current_tts_language != target_code:
            try:
                logger.info(f"Updating TTS model for {target_code}...")
                from transformers import VitsModel, AutoTokenizer
                repo_id = f"facebook/mms-tts-{target_code}"
                tts_model = VitsModel.from_pretrained(repo_id)
                tts_tokenizer = AutoTokenizer.from_pretrained(repo_id)
                tts_model.to(device)
                current_tts_language = target_code
                logger.info(f"TTS model updated to {target_code}")
                model_status["tts"] = "loaded"
            except Exception as primary_error:
                logger.error(f"Failed to load TTS model for {target_code}: {primary_error}")
                try:
                    logger.info("Falling back to MMS-TTS English model...")
                    tts_model = VitsModel.from_pretrained("facebook/mms-tts-eng")
                    tts_tokenizer = AutoTokenizer.from_pretrained("facebook/mms-tts-eng")
                    tts_model.to(device)
                    current_tts_language = "eng"
                    logger.info("Fallback TTS model loaded successfully")
                    model_status["tts"] = "loaded (fallback)"
                except Exception as fallback_error:
                    logger.error(f"Failed to load fallback TTS model: {fallback_error}")
                    model_status["tts"] = "failed"
                    error_message = (
                        f"TTS model loading failed: {primary_error} "
                        f"(fallback also failed: {fallback_error})"
                    )
                    return {
                        "request_id": request_id,
                        "status": "failed",
                        "message": error_message,
                        "output_audio": None
                    }

        # Check the input text for inappropriate content.
        is_inappropriate = check_inappropriate_content(text)
        logger.info(f"Inappropriate content check: {is_inappropriate}")

        # Tokenize and run the TTS model without tracking gradients.
        encoded = tts_tokenizer(text, return_tensors="pt").to(device)
        logger.info("Generating speech...")
        with torch.no_grad():
            result = tts_model(**encoded)

        # Scale the waveform to the 16-bit integer range for WAV output.
        waveform = result.waveform.cpu().numpy().squeeze()
        pcm16 = (waveform * 32767).astype(np.int16)
        rate = tts_model.config.sampling_rate

        # Write the WAV file named after this request.
        wav_name = f"{request_id}.wav"
        wav_path = os.path.join(AUDIO_DIR, wav_name)
        save_pcm_to_wav(pcm16.tolist(), rate, wav_path)
        logger.info(f"Saved synthesized audio to {wav_path}")

        # Build the public URL for the saved file.
        output_audio_url = f"https://jerich-talklasapp.hf.space/audio_output/{wav_name}"
        logger.info("TTS conversion completed")

        return {
            "request_id": request_id,
            "status": "completed",
            "message": "Text-to-speech conversion completed successfully.",
            "text": text,
            "is_inappropriate": is_inappropriate,
            "output_audio": output_audio_url
        }
    except Exception as synth_error:
        logger.error(f"Error during speech synthesis: {synth_error}")
        return {
            "request_id": request_id,
            "status": "failed",
            "message": f"Speech synthesis failed: {synth_error}",
            "output_audio": None
        }
745
 
746
  if __name__ == "__main__":
747
  import uvicorn