Jerich committed on
Commit
798e9af
·
verified ·
1 Parent(s): 0f3ec29

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +456 -350
app.py CHANGED
@@ -31,16 +31,20 @@ os.makedirs(AUDIO_DIR, exist_ok=True)
31
  app.mount("/audio_output", StaticFiles(directory=AUDIO_DIR), name="audio_output")
32
 
33
  # Global variables to track application state
34
- models_loaded = False
35
- loading_in_progress = False
36
- loading_thread = None
37
- model_status = {
38
- "stt_mms": "not_loaded",
39
- "stt_whisper_small": "not_loaded",
40
- "mt": "not_loaded",
41
- "tts": {} # Will store status for each language
 
 
 
 
 
42
  }
43
- error_message = None
44
 
45
  # Define the valid languages and mappings
46
  LANGUAGE_MAPPING = {
@@ -61,30 +65,25 @@ NLLB_LANGUAGE_CODES = {
61
  "pag": "pag_Latn"
62
  }
63
 
64
- # Model dictionaries for different languages
65
- stt_models = {
66
- "mms": None,
67
- "mms_processor": None,
68
- "whisper_small": None,
69
- "whisper_small_processor": None
70
- }
71
 
72
- mt_model = None
73
- mt_tokenizer = None
74
 
75
- tts_models = {} # Will store models for each language
76
- tts_tokenizers = {} # Will store tokenizers for each language
 
 
 
 
 
 
 
 
77
 
78
- # List of inappropriate words/phrases for content filtering
79
- INAPPROPRIATE_WORDS = [
80
- "fuck", "shit", "asshole", "bitch", "dick", "pussy", "cunt",
81
- "whore", "slut", "bastard", "damn", "hell", "piss", "nigger",
82
- "faggot", "retard", "crap", "porn", "sex", "penis", "vagina",
83
- # Tagalog inappropriate words
84
- "puta", "putangina", "gago", "bobo", "tanga", "tarantado",
85
- "inutil", "ulol", "kantot", "jakol", "tite", "pekpek",
86
- # Add more as needed
87
- ]
88
 
89
  # Function to save PCM data as a WAV file
90
  def save_pcm_to_wav(pcm_data: list, sample_rate: int, output_path: str):
@@ -99,6 +98,7 @@ def save_pcm_to_wav(pcm_data: list, sample_rate: int, output_path: str):
99
  # Write the 16-bit PCM data as bytes (little-endian)
100
  wav_file.writeframes(pcm_array.tobytes())
101
 
 
102
  # Function to detect speech using an energy-based approach
103
  def detect_speech(waveform: torch.Tensor, sample_rate: int, threshold: float = 0.01, min_speech_duration: float = 0.5) -> bool:
104
  """
@@ -123,52 +123,6 @@ def detect_speech(waveform: torch.Tensor, sample_rate: int, threshold: float = 0
123
  # For now, we assume if RMS is above threshold, there is speech
124
  return True
125
 
126
- # Function to check for inappropriate content
127
- def check_inappropriate_content(text: str) -> bool:
128
- """
129
- Checks if the text contains inappropriate content.
130
- Returns True if inappropriate content is detected, False otherwise.
131
- """
132
- # Convert text to lowercase for case-insensitive matching
133
- text_lower = text.lower()
134
-
135
- # Check if any inappropriate word is in the text
136
- for word in INAPPROPRIATE_WORDS:
137
- # Use word boundary regex to match whole words only
138
- pattern = r'\b' + re.escape(word) + r'\b'
139
- if re.search(pattern, text_lower):
140
- logger.warning(f"Inappropriate content detected: '{word}'")
141
- return True
142
-
143
- return False
144
-
145
- # Function to perform text-to-speech conversion
146
- def text_to_speech(text: str, language_code: str) -> Tuple[Optional[np.ndarray], Optional[int], Optional[str]]:
147
- """
148
- Convert text to speech using the appropriate TTS model.
149
- Returns the speech waveform, sample rate, and any error message.
150
- """
151
- if language_code not in tts_models or tts_models[language_code] is None:
152
- error_msg = f"TTS model for {language_code} not loaded"
153
- logger.error(error_msg)
154
- return None, None, error_msg
155
-
156
- try:
157
- device = "cuda" if torch.cuda.is_available() else "cpu"
158
- inputs = tts_tokenizers[language_code](text, return_tensors="pt").to(device)
159
-
160
- with torch.no_grad():
161
- output = tts_models[language_code](**inputs)
162
-
163
- speech = output.waveform.cpu().numpy().squeeze()
164
- speech = (speech * 32767).astype(np.int16)
165
- sample_rate = tts_models[language_code].config.sampling_rate
166
-
167
- return speech, sample_rate, None
168
- except Exception as e:
169
- error_msg = f"Error during TTS conversion: {str(e)}"
170
- logger.error(error_msg)
171
- return None, None, error_msg
172
 
173
  # Function to clean up old audio files
174
  def cleanup_old_audio_files():
@@ -185,142 +139,157 @@ def cleanup_old_audio_files():
185
  except Exception as e:
186
  logger.error(f"Error deleting file {file_path}: {str(e)}")
187
 
 
188
  # Background task to periodically clean up audio files
189
  def schedule_cleanup():
190
  while True:
191
  cleanup_old_audio_files()
192
  time.sleep(300) # Run every 5 minutes (300 seconds)
193
 
194
- # Function to load models in background
195
- def load_models_task():
196
- global models_loaded, loading_in_progress, model_status, error_message
197
- global stt_models, mt_model, mt_tokenizer, tts_models, tts_tokenizers
 
 
 
 
 
 
198
 
199
  try:
200
- loading_in_progress = True
 
 
 
201
  device = "cuda" if torch.cuda.is_available() else "cpu"
202
 
203
- # Load STT models (both MMS and Whisper)
204
- logger.info("Starting to load STT models...")
 
205
 
206
- # Load MMS STT model
207
- try:
208
- logger.info("Loading MMS STT model...")
209
- model_status["stt_mms"] = "loading"
210
- from transformers import AutoProcessor, AutoModelForCTC
211
-
212
- stt_models["mms_processor"] = AutoProcessor.from_pretrained("facebook/mms-1b-all")
213
- stt_models["mms"] = AutoModelForCTC.from_pretrained("facebook/mms-1b-all")
214
- stt_models["mms"].to(device)
215
- logger.info("MMS STT model loaded successfully")
216
- model_status["stt_mms"] = "loaded"
217
- except Exception as mms_error:
218
- logger.error(f"Failed to load MMS STT model: {str(mms_error)}")
219
- model_status["stt_mms"] = "failed"
220
- error_message = f"MMS STT model loading failed: {str(mms_error)}"
 
 
 
 
 
 
 
 
221
 
222
- # Load Whisper Small STT model
223
- try:
224
- logger.info("Loading Whisper Small STT model...")
225
- model_status["stt_whisper_small"] = "loading"
226
- from transformers import WhisperProcessor, WhisperForConditionalGeneration
227
-
228
- stt_models["whisper_small_processor"] = WhisperProcessor.from_pretrained("openai/whisper-small")
229
- stt_models["whisper_small"] = WhisperForConditionalGeneration.from_pretrained("openai/whisper-small")
230
- stt_models["whisper_small"].to(device)
231
- logger.info("Whisper Small STT model loaded successfully")
232
- model_status["stt_whisper_small"] = "loaded"
233
- except Exception as whisper_error:
234
- logger.error(f"Failed to load Whisper Small STT model: {str(whisper_error)}")
235
- model_status["stt_whisper_small"] = "failed"
236
- error_message = f"Whisper Small STT model loading failed: {str(whisper_error)}"
237
-
238
- # Load MT model
239
- logger.info("Starting to load MT model...")
240
- from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
241
 
242
- try:
243
- logger.info("Loading NLLB-200-distilled-600M model...")
244
- model_status["mt"] = "loading"
245
- mt_model = AutoModelForSeq2SeqLM.from_pretrained("facebook/nllb-200-distilled-600M")
246
- mt_tokenizer = AutoTokenizer.from_pretrained("facebook/nllb-200-distilled-600M")
247
- mt_model.to(device)
248
- logger.info("MT model loaded successfully")
249
- model_status["mt"] = "loaded"
250
- except Exception as e:
251
- logger.error(f"Failed to load MT model: {str(e)}")
252
- model_status["mt"] = "failed"
253
- error_message = f"MT model loading failed: {str(e)}"
254
 
255
- # Load TTS models for all supported languages
256
- logger.info("Starting to load TTS models for all languages...")
257
- from transformers import VitsModel, AutoTokenizer
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
258
 
259
- for lang_name, lang_code in LANGUAGE_MAPPING.items():
260
- try:
261
- logger.info(f"Loading MMS-TTS model for {lang_name} ({lang_code})...")
262
- model_status["tts"][lang_code] = "loading"
263
-
264
- # Load the model and tokenizer
265
- tts_models[lang_code] = VitsModel.from_pretrained(f"facebook/mms-tts-{lang_code}")
266
- tts_tokenizers[lang_code] = AutoTokenizer.from_pretrained(f"facebook/mms-tts-{lang_code}")
267
-
268
- # Move to GPU if available
269
- tts_models[lang_code].to(device)
270
-
271
- logger.info(f"TTS model for {lang_name} loaded successfully")
272
- model_status["tts"][lang_code] = "loaded"
273
- except Exception as e:
274
- logger.error(f"Failed to load TTS model for {lang_name}: {str(e)}")
275
- model_status["tts"][lang_code] = "failed"
276
-
277
- # Try to load English as fallback if this is not English
278
- if lang_code != "eng":
279
- try:
280
- logger.info(f"Trying to load English TTS model as fallback for {lang_name}...")
281
- # Only load English model once if not already loaded
282
- if "eng" not in tts_models or tts_models["eng"] is None:
283
- tts_models["eng"] = VitsModel.from_pretrained("facebook/mms-tts-eng")
284
- tts_tokenizers["eng"] = AutoTokenizer.from_pretrained("facebook/mms-tts-eng")
285
- tts_models["eng"].to(device)
286
- model_status["tts"]["eng"] = "loaded"
287
-
288
- # Point this language to use English model
289
- tts_models[lang_code] = tts_models["eng"]
290
- tts_tokenizers[lang_code] = tts_tokenizers["eng"]
291
- model_status["tts"][lang_code] = "loaded (fallback to eng)"
292
- except Exception as e2:
293
- logger.error(f"Failed to load English fallback TTS model: {str(e2)}")
294
- model_status["tts"][lang_code] = "failed (with fallback)"
295
 
296
- # Set models_loaded flag based on which critical models are loaded
297
- # Consider the system usable if we have at least one STT model, the MT model, and at least one TTS model
298
- stt_loaded = model_status["stt_mms"] == "loaded" or model_status["stt_whisper_small"] == "loaded"
299
- mt_loaded = model_status["mt"] == "loaded"
300
- any_tts_loaded = any(status == "loaded" or status.startswith("loaded (fallback")
301
- for status in model_status["tts"].values())
302
 
303
- models_loaded = stt_loaded and mt_loaded and any_tts_loaded
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
304
 
305
- if models_loaded:
306
- logger.info("Critical models loaded successfully - system is ready")
307
- else:
308
- logger.warning("Some critical models failed to load - system may have limited functionality")
309
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
310
  except Exception as e:
311
- error_message = str(e)
312
- logger.error(f"Error in model loading task: {str(e)}")
 
313
  finally:
314
- loading_in_progress = False
315
 
316
- # Start loading models in background
317
- def start_model_loading():
318
- global loading_thread, loading_in_progress
319
- if not loading_in_progress:
320
- loading_in_progress = True
321
- loading_thread = threading.Thread(target=load_models_task)
322
- loading_thread.daemon = True
323
- loading_thread.start()
324
 
325
  # Start the background cleanup task
326
  def start_cleanup_task():
@@ -328,88 +297,130 @@ def start_cleanup_task():
328
  cleanup_thread.daemon = True
329
  cleanup_thread.start()
330
 
 
331
  # Start the background processes when the app starts
332
  @app.on_event("startup")
333
  async def startup_event():
334
  logger.info("Application starting up...")
335
- start_model_loading()
336
  start_cleanup_task()
337
 
 
338
  @app.get("/")
339
  async def root():
340
  """Root endpoint for default health check"""
341
  logger.info("Root endpoint requested")
342
  return {"status": "healthy"}
343
 
 
344
  @app.get("/health")
345
  async def health_check():
346
  """Health check endpoint that always returns successfully"""
347
- global models_loaded, loading_in_progress, model_status, error_message
348
  logger.info("Health check requested")
349
  return {
350
  "status": "healthy",
351
- "models_loaded": models_loaded,
352
- "loading_in_progress": loading_in_progress,
353
- "model_status": model_status,
354
- "error": error_message
 
 
 
355
  }
356
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
357
  @app.post("/synthesize-speech")
358
  async def synthesize_speech(text: str = Form(...), language: str = Form(...)):
359
  """Endpoint to synthesize speech from text without translation"""
360
  if language not in LANGUAGE_MAPPING:
361
  raise HTTPException(status_code=400, detail="Invalid language selected")
362
 
363
- logger.info(f"Speech synthesis requested for text in {language}")
364
- request_id = str(uuid.uuid4())
365
  language_code = LANGUAGE_MAPPING[language]
 
366
 
367
- # Check if the TTS model is loaded
368
- if language_code not in tts_models or tts_models[language_code] is None:
369
  return {
370
  "request_id": request_id,
371
- "status": "failed",
372
- "message": f"TTS model for {language} not loaded yet",
373
- "output_audio": None,
374
- "is_inappropriate": False
375
  }
376
 
377
- # Check for inappropriate content
378
- is_inappropriate = check_inappropriate_content(text)
379
-
380
- # Generate speech
381
- speech, sample_rate, error = text_to_speech(text, language_code)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
382
 
383
- if error:
 
384
  return {
385
  "request_id": request_id,
386
  "status": "failed",
387
- "message": error,
388
- "output_audio": None,
389
- "is_inappropriate": is_inappropriate
390
  }
391
-
392
- # Save the synthesized audio
393
- output_filename = f"{request_id}.wav"
394
- output_path = os.path.join(AUDIO_DIR, output_filename)
395
- save_pcm_to_wav(speech.tolist(), sample_rate, output_path)
396
-
397
- # Generate URL to the WAV file
398
- output_audio_url = f"https://jerich-talklasapp2.hf.space/audio_output/{output_filename}"
399
-
400
- return {
401
- "request_id": request_id,
402
- "status": "completed",
403
- "message": "Speech synthesis completed",
404
- "output_audio": output_audio_url,
405
- "is_inappropriate": is_inappropriate
406
- }
407
 
408
  @app.post("/translate-text")
409
  async def translate_text(text: str = Form(...), source_lang: str = Form(...), target_lang: str = Form(...)):
410
  """Endpoint to translate text and convert to speech"""
411
- global mt_model, mt_tokenizer
412
-
413
  if not text:
414
  raise HTTPException(status_code=400, detail="No text provided")
415
  if source_lang not in LANGUAGE_MAPPING or target_lang not in LANGUAGE_MAPPING:
@@ -418,64 +429,107 @@ async def translate_text(text: str = Form(...), source_lang: str = Form(...), ta
418
  logger.info(f"Translate-text requested: {text} from {source_lang} to {target_lang}")
419
  request_id = str(uuid.uuid4())
420
 
 
 
 
 
 
 
 
 
 
 
 
 
421
  # Translate the text
422
  source_code = LANGUAGE_MAPPING[source_lang]
423
  target_code = LANGUAGE_MAPPING[target_lang]
424
  translated_text = "Translation not available"
 
425
 
426
- if model_status["mt"] == "loaded" and mt_model is not None and mt_tokenizer is not None:
427
- try:
428
- source_nllb_code = NLLB_LANGUAGE_CODES[source_code]
429
- target_nllb_code = NLLB_LANGUAGE_CODES[target_code]
430
- mt_tokenizer.src_lang = source_nllb_code
431
- device = "cuda" if torch.cuda.is_available() else "cpu"
432
- inputs = mt_tokenizer(text, return_tensors="pt").to(device)
433
- with torch.no_grad():
434
- generated_tokens = mt_model.generate(
435
- **inputs,
436
- forced_bos_token_id=mt_tokenizer.convert_tokens_to_ids(target_nllb_code),
437
- max_length=448
438
- )
439
- translated_text = mt_tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0]
440
- logger.info(f"Translation completed: {translated_text}")
441
- except Exception as e:
442
- logger.error(f"Error during translation: {str(e)}")
443
- translated_text = f"Translation failed: {str(e)}"
444
- else:
445
- logger.warning("MT model not loaded, skipping translation")
446
-
447
- # Check for inappropriate content in the translation
448
- is_inappropriate = check_inappropriate_content(translated_text)
449
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
450
  # Convert translated text to speech
451
- speech, sample_rate, error = text_to_speech(translated_text, target_code)
452
-
453
  output_audio_url = None
454
- if speech is not None and sample_rate is not None:
 
 
 
 
 
 
 
455
  # Save the audio as a WAV file
456
  output_filename = f"{request_id}.wav"
457
  output_path = os.path.join(AUDIO_DIR, output_filename)
458
  save_pcm_to_wav(speech.tolist(), sample_rate, output_path)
459
-
 
460
  # Generate a URL to the WAV file
461
- output_audio_url = f"https://jerich-talklasapp2.hf.space/audio_output/{output_filename}"
462
  logger.info("TTS conversion completed")
 
 
 
463
 
464
  return {
465
  "request_id": request_id,
466
- "status": "completed",
467
- "message": "Translation and TTS completed (or partially completed).",
 
468
  "source_text": text,
469
  "translated_text": translated_text,
470
  "output_audio": output_audio_url,
471
- "is_inappropriate": is_inappropriate
472
  }
473
 
 
474
  @app.post("/translate-audio")
475
  async def translate_audio(audio: UploadFile = File(...), source_lang: str = Form(...), target_lang: str = Form(...)):
476
  """Endpoint to transcribe, translate, and convert audio to speech"""
477
- global stt_models, mt_model, mt_tokenizer
478
-
479
  if not audio:
480
  raise HTTPException(status_code=400, detail="No audio file provided")
481
  if source_lang not in LANGUAGE_MAPPING or target_lang not in LANGUAGE_MAPPING:
@@ -484,38 +538,35 @@ async def translate_audio(audio: UploadFile = File(...), source_lang: str = Form
484
  logger.info(f"Translate-audio requested: {audio.filename} from {source_lang} to {target_lang}")
485
  request_id = str(uuid.uuid4())
486
 
487
- # Check if appropriate STT model is loaded
488
  source_code = LANGUAGE_MAPPING[source_lang]
489
- use_whisper = source_code in ["eng", "tgl"] # Use Whisper for English or Tagalog
 
 
 
490
 
491
- if use_whisper and (model_status["stt_whisper_small"] != "loaded" or stt_models["whisper_small"] is None):
492
- logger.warning("Whisper Small STT model not loaded for English/Tagalog, checking MMS")
493
- if model_status["stt_mms"] != "loaded" or stt_models["mms"] is None:
494
- logger.warning("MMS STT model not loaded either, returning placeholder response")
495
  return {
496
  "request_id": request_id,
497
- "status": "processing",
498
- "message": "STT models not loaded yet. Please try again later.",
499
- "source_text": "Transcription not available",
500
- "translated_text": "Translation not available",
501
  "output_audio": None,
502
- "is_inappropriate": False
503
  }
504
- use_whisper = False # Fall back to MMS
505
- elif not use_whisper and (model_status["stt_mms"] != "loaded" or stt_models["mms"] is None):
506
- logger.warning("MMS STT model not loaded for non-English/Tagalog, checking Whisper")
507
- if model_status["stt_whisper_small"] != "loaded" or stt_models["whisper_small"] is None:
508
- logger.warning("Whisper Small STT model not loaded either, returning placeholder response")
509
  return {
510
  "request_id": request_id,
511
- "status": "processing",
512
- "message": "STT models not loaded yet. Please try again later.",
513
- "source_text": "Transcription not available",
514
- "translated_text": "Translation not available",
515
  "output_audio": None,
516
- "is_inappropriate": False
517
  }
518
- use_whisper = True # Fall back to Whisper
519
 
520
  # Save the uploaded audio to a temporary file
521
  with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as temp_file:
@@ -525,7 +576,7 @@ async def translate_audio(audio: UploadFile = File(...), source_lang: str = Form
525
  transcription = "Transcription not available"
526
  translated_text = "Translation not available"
527
  output_audio_url = None
528
- is_inappropriate = False
529
 
530
  try:
531
  # Step 1: Load and resample the audio using torchaudio
@@ -549,94 +600,132 @@ async def translate_audio(audio: UploadFile = File(...), source_lang: str = Form
549
  "source_text": "No speech detected",
550
  "translated_text": "No translation available",
551
  "output_audio": None,
552
- "is_inappropriate": False
553
  }
554
 
555
  # Step 3: Transcribe the audio (STT)
556
  device = "cuda" if torch.cuda.is_available() else "cpu"
557
- logger.info(f"Using device: {device} with {'Whisper' if use_whisper else 'MMS'} model")
558
 
559
  if use_whisper:
560
- # Use Whisper Small for English or Tagalog
561
- logger.info("Using Whisper Small for transcription")
562
- processor = stt_models["whisper_small_processor"]
563
- model = stt_models["whisper_small"]
 
 
564
 
565
- inputs = processor(waveform.numpy()[0], sampling_rate=16000, return_tensors="pt").to(device)
566
  with torch.no_grad():
567
- # Use the language code for forced decoding if source is English or Tagalog
568
- language = "en" if source_code == "eng" else "tl" if source_code == "tgl" else None
569
- generated_ids = model.generate(
570
- **inputs,
571
- language=language,
572
- task="transcribe"
573
- )
574
- transcription = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
575
  else:
576
- # Use MMS for other languages
577
- logger.info("Using MMS for transcription")
578
- processor = stt_models["mms_processor"]
579
- model = stt_models["mms"]
 
 
 
 
580
 
581
- if source_code in processor.tokenizer.vocab.keys():
582
- processor.tokenizer.set_target_lang(source_code)
583
- model.load_adapter(source_code)
584
 
585
- inputs = processor(waveform.numpy(), sampling_rate=16000, return_tensors="pt").to(device)
586
  with torch.no_grad():
587
- logits = model(**inputs).logits
588
  predicted_ids = torch.argmax(logits, dim=-1)
589
- transcription = processor.batch_decode(predicted_ids)[0]
590
-
591
  logger.info(f"Transcription completed: {transcription}")
592
 
593
- # Step 4: Translate the transcribed text (MT)
594
- target_code = LANGUAGE_MAPPING[target_lang]
595
-
596
- if model_status["mt"] == "loaded" and mt_model is not None and mt_tokenizer is not None:
597
- try:
598
- source_nllb_code = NLLB_LANGUAGE_CODES[source_code]
599
- target_nllb_code = NLLB_LANGUAGE_CODES[target_code]
600
- mt_tokenizer.src_lang = source_nllb_code
601
- inputs = mt_tokenizer(transcription, return_tensors="pt").to(device)
602
- with torch.no_grad():
603
- generated_tokens = mt_model.generate(
604
- **inputs,
605
- forced_bos_token_id=mt_tokenizer.convert_tokens_to_ids(target_nllb_code),
606
- max_length=448
607
- )
608
- translated_text = mt_tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0]
609
- logger.info(f"Translation completed: {translated_text}")
610
- except Exception as e:
611
- logger.error(f"Error during translation: {str(e)}")
612
- translated_text = f"Translation failed: {str(e)}"
613
- else:
614
- logger.warning("MT model not loaded, skipping translation")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
615
 
616
- # Step 5: Check for inappropriate content in the translation
617
- is_inappropriate = check_inappropriate_content(translated_text)
 
 
 
 
 
 
 
 
 
618
 
619
- # Step 6: Convert translated text to speech (TTS)
620
- speech, sample_rate, error = text_to_speech(translated_text, target_code)
621
-
622
- if speech is not None and sample_rate is not None:
 
 
 
 
623
  # Save the audio as a WAV file
624
  output_filename = f"{request_id}.wav"
625
  output_path = os.path.join(AUDIO_DIR, output_filename)
626
  save_pcm_to_wav(speech.tolist(), sample_rate, output_path)
627
-
 
628
  # Generate a URL to the WAV file
629
- output_audio_url = f"https://jerich-talklasapp2.hf.space/audio_output/{output_filename}"
630
  logger.info("TTS conversion completed")
 
 
 
631
 
632
  return {
633
  "request_id": request_id,
634
- "status": "completed",
635
- "message": "Transcription, translation, and TTS completed (or partially completed).",
 
636
  "source_text": transcription,
637
  "translated_text": translated_text,
638
  "output_audio": output_audio_url,
639
- "is_inappropriate": is_inappropriate
640
  }
641
  except Exception as e:
642
  logger.error(f"Error during processing: {str(e)}")
@@ -647,11 +736,28 @@ async def translate_audio(audio: UploadFile = File(...), source_lang: str = Form
647
  "source_text": transcription,
648
  "translated_text": translated_text,
649
  "output_audio": output_audio_url,
650
- "is_inappropriate": is_inappropriate
651
  }
652
  finally:
653
  logger.info(f"Cleaning up temporary file: {temp_path}")
654
- os.unlink(temp_path)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
655
 
656
  if __name__ == "__main__":
657
  import uvicorn
 
31
  app.mount("/audio_output", StaticFiles(directory=AUDIO_DIR), name="audio_output")
32
 
33
  # Global variables to track application state
34
+ model_cache = {
35
+ "stt_whisper": {"model": None, "processor": None, "status": "not_loaded"},
36
+ "stt_mms": {"model": None, "processor": None, "status": "not_loaded"},
37
+ "mt": {"model": None, "tokenizer": None, "status": "not_loaded"},
38
+ "tts": {"model": None, "tokenizer": None, "status": "not_loaded", "language": None}
39
+ }
40
+
41
+ # Track loading status
42
+ loading_locks = {
43
+ "stt_whisper": threading.Lock(),
44
+ "stt_mms": threading.Lock(),
45
+ "mt": threading.Lock(),
46
+ "tts": threading.Lock()
47
  }
 
48
 
49
  # Define the valid languages and mappings
50
  LANGUAGE_MAPPING = {
 
65
  "pag": "pag_Latn"
66
  }
67
 
68
# Inappropriate words list - this is a basic implementation.
# In a production environment, you would use a more comprehensive solution.
# Entries are treated as word stems by detect_inappropriate_content below.
INAPPROPRIATE_WORDS = [
    "putang", "tang ina", "gago", "puta", "bobo", "ulol", "pakyu", "tae",
    "obscenity", "profanity", "explicit", "nsfw", "offensive"
]


# Function to detect inappropriate content
def detect_inappropriate_content(text: str) -> bool:
    """
    Report whether ``text`` contains a word beginning with a flagged term.

    Matching is case-insensitive. Each entry of INAPPROPRIATE_WORDS must
    start at a word boundary, so inflected forms such as "putangina" are
    still caught (via the stem "putang"), while unrelated words that merely
    embed a term (e.g. "reputation", which contains "puta") are not flagged.

    Args:
        text: The text to inspect. Empty or missing text is never flagged.

    Returns:
        True if any flagged term is found at the start of a word,
        False otherwise.
    """
    # Local import: the file's top-level import block is not visible from
    # here, and the file already uses function-scope imports elsewhere.
    import re

    if not text:
        return False

    text_lower = text.lower()
    # re.escape keeps multi-word entries such as "tang ina" literal.
    return any(
        re.search(r"\b" + re.escape(word), text_lower)
        for word in INAPPROPRIATE_WORDS
    )
86
 
 
 
 
 
 
 
 
 
 
 
87
 
88
  # Function to save PCM data as a WAV file
89
  def save_pcm_to_wav(pcm_data: list, sample_rate: int, output_path: str):
 
98
  # Write the 16-bit PCM data as bytes (little-endian)
99
  wav_file.writeframes(pcm_array.tobytes())
100
 
101
+
102
  # Function to detect speech using an energy-based approach
103
  def detect_speech(waveform: torch.Tensor, sample_rate: int, threshold: float = 0.01, min_speech_duration: float = 0.5) -> bool:
104
  """
 
123
  # For now, we assume if RMS is above threshold, there is speech
124
  return True
125
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
126
 
127
  # Function to clean up old audio files
128
  def cleanup_old_audio_files():
 
139
  except Exception as e:
140
  logger.error(f"Error deleting file {file_path}: {str(e)}")
141
 
142
# Background task to periodically clean up audio files
def schedule_cleanup():
    """Call cleanup_old_audio_files in an endless loop, sleeping between passes."""
    pause_seconds = 300  # Run every 5 minutes (300 seconds)
    while True:
        cleanup_old_audio_files()
        time.sleep(pause_seconds)
148
 
149
+
150
# Function to load the Whisper STT model on demand
def load_whisper_model():
    """
    Load the "openai/whisper-small" model and processor into model_cache.

    Returns:
        True if the model is available in model_cache["stt_whisper"] when the
        call returns (already loaded, or loaded by this call).
        False if another thread currently holds the loading lock, or if
        loading raised an exception (status is then set to "failed").
    """
    entry = model_cache["stt_whisper"]
    lock = loading_locks["stt_whisper"]

    if entry["status"] == "loaded":
        return True

    # Use lock to prevent multiple threads from loading the model simultaneously
    if not lock.acquire(blocking=False):
        logger.info("Whisper model loading already in progress")
        return False

    try:
        # Re-check now that we own the lock: another thread may have finished
        # loading between the unlocked check above and the acquire, in which
        # case loading again would only duplicate the model in memory.
        if entry["status"] == "loaded":
            return True

        logger.info("Loading Whisper small model...")
        entry["status"] = "loading"

        from transformers import WhisperProcessor, WhisperForConditionalGeneration
        device = "cuda" if torch.cuda.is_available() else "cpu"

        # Build everything in locals first so the cache never exposes a
        # processor without its model (or a model not yet moved to device).
        processor = WhisperProcessor.from_pretrained("openai/whisper-small")
        model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-small")
        model.to(device)

        entry["processor"] = processor
        entry["model"] = model
        entry["status"] = "loaded"
        logger.info("Whisper small model loaded successfully")
        return True
    except Exception as e:
        entry["status"] = "failed"
        logger.error(f"Failed to load Whisper model: {str(e)}")
        return False
    finally:
        lock.release()
180
+
181
+
182
# Function to load the MMS STT model on demand
def load_mms_stt_model():
    """
    Load the "facebook/mms-1b-all" CTC model and processor into model_cache.

    Returns:
        True if the model is available in model_cache["stt_mms"] when the
        call returns (already loaded, or loaded by this call).
        False if another thread currently holds the loading lock, or if
        loading raised an exception (status is then set to "failed").
    """
    entry = model_cache["stt_mms"]
    lock = loading_locks["stt_mms"]

    if entry["status"] == "loaded":
        return True

    # Non-blocking acquire: a caller never waits on a load in progress.
    if not lock.acquire(blocking=False):
        logger.info("MMS STT model loading already in progress")
        return False

    try:
        # Re-check now that we own the lock: another thread may have finished
        # loading between the unlocked check above and the acquire, in which
        # case loading again would only duplicate the model in memory.
        if entry["status"] == "loaded":
            return True

        logger.info("Loading MMS STT model...")
        entry["status"] = "loading"

        from transformers import AutoProcessor, AutoModelForCTC
        device = "cuda" if torch.cuda.is_available() else "cpu"

        # Build everything in locals first so the cache never exposes a
        # processor without its model (or a model not yet moved to device).
        processor = AutoProcessor.from_pretrained("facebook/mms-1b-all")
        model = AutoModelForCTC.from_pretrained("facebook/mms-1b-all")
        model.to(device)

        entry["processor"] = processor
        entry["model"] = model
        entry["status"] = "loaded"
        logger.info("MMS STT model loaded successfully")
        return True
    except Exception as e:
        entry["status"] = "failed"
        logger.error(f"Failed to load MMS STT model: {str(e)}")
        return False
    finally:
        lock.release()
211
+
212
+
213
# Function to load the MT model on demand
def load_mt_model():
    """
    Load the "facebook/nllb-200-distilled-600M" model and tokenizer into model_cache.

    Returns:
        True if the model is available in model_cache["mt"] when the call
        returns (already loaded, or loaded by this call).
        False if another thread currently holds the loading lock, or if
        loading raised an exception (status is then set to "failed").
    """
    entry = model_cache["mt"]
    lock = loading_locks["mt"]

    if entry["status"] == "loaded":
        return True

    # Non-blocking acquire: a caller never waits on a load in progress.
    if not lock.acquire(blocking=False):
        logger.info("MT model loading already in progress")
        return False

    try:
        # Re-check now that we own the lock: another thread may have finished
        # loading between the unlocked check above and the acquire, in which
        # case loading again would only duplicate the model in memory.
        if entry["status"] == "loaded":
            return True

        logger.info("Loading NLLB-200-distilled-600M model...")
        entry["status"] = "loading"

        from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
        device = "cuda" if torch.cuda.is_available() else "cpu"

        # Build everything in locals first so the cache never exposes a
        # model without its tokenizer (or a model not yet moved to device).
        model = AutoModelForSeq2SeqLM.from_pretrained("facebook/nllb-200-distilled-600M")
        tokenizer = AutoTokenizer.from_pretrained("facebook/nllb-200-distilled-600M")
        model.to(device)

        entry["model"] = model
        entry["tokenizer"] = tokenizer
        entry["status"] = "loaded"
        logger.info("MT model loaded successfully")
        return True
    except Exception as e:
        entry["status"] = "failed"
        logger.error(f"Failed to load MT model: {str(e)}")
        return False
    finally:
        lock.release()
242
+
243
+
244
# Function to load the TTS model for a specific language on demand
def load_tts_model(language_code: str):
    """
    Load the MMS-TTS model for ``language_code`` into the single TTS cache slot.

    Only one TTS model is cached at a time; requesting a different language
    replaces it. If "facebook/mms-tts-<language_code>" cannot be loaded, the
    English model "facebook/mms-tts-eng" is loaded instead and the status is
    recorded as "loaded (fallback)" with language "eng".

    Args:
        language_code: Suffix of the MMS-TTS checkpoint name to load.

    Returns:
        True if a usable model (requested language or English fallback) is in
        model_cache["tts"] when the call returns.
        False if another thread currently holds the loading lock, or if both
        the requested and the fallback model failed to load.
    """
    entry = model_cache["tts"]
    lock = loading_locks["tts"]

    def _cached_for_requested_language():
        # Accept both "loaded" and "loaded (fallback)": a cached English
        # fallback is a perfectly good answer to a request for "eng".
        return (entry["status"].startswith("loaded")
                and entry["language"] == language_code)

    # If the model is already loaded for this language, return immediately
    if _cached_for_requested_language():
        return True

    if not lock.acquire(blocking=False):
        logger.info("TTS model loading already in progress")
        return False

    try:
        # Re-check now that we own the lock: another thread may have just
        # finished loading this same language.
        if _cached_for_requested_language():
            return True

        logger.info(f"Loading MMS-TTS model for {language_code}...")
        entry["status"] = "loading"

        from transformers import VitsModel, AutoTokenizer
        device = "cuda" if torch.cuda.is_available() else "cpu"

        def _load_pair(code):
            # Build model and tokenizer in locals so the cache is only ever
            # updated with a matching, fully prepared pair.
            model = VitsModel.from_pretrained(f"facebook/mms-tts-{code}")
            tokenizer = AutoTokenizer.from_pretrained(f"facebook/mms-tts-{code}")
            model.to(device)
            return model, tokenizer

        try:
            model, tokenizer = _load_pair(language_code)
            entry["model"] = model
            entry["tokenizer"] = tokenizer
            entry["language"] = language_code
            entry["status"] = "loaded"
            logger.info(f"TTS model for {language_code} loaded successfully")
            return True
        except Exception as e:
            logger.error(f"Failed to load TTS model for {language_code}: {str(e)}")
            # Fallback to English TTS if the target language fails
            try:
                logger.info("Falling back to MMS-TTS English model...")
                model, tokenizer = _load_pair("eng")
                entry["model"] = model
                entry["tokenizer"] = tokenizer
                entry["language"] = "eng"
                entry["status"] = "loaded (fallback)"
                logger.info("Fallback TTS model loaded successfully")
                return True
            except Exception as e2:
                entry["status"] = "failed"
                logger.error(f"Failed to load fallback TTS model: {str(e2)}")
                return False
    except Exception as e:
        entry["status"] = "failed"
        logger.error(f"Failed to setup TTS model: {str(e)}")
        return False
    finally:
        lock.release()
292
 
 
 
 
 
 
 
 
 
293
 
294
  # Start the background cleanup task
295
  def start_cleanup_task():
 
297
  cleanup_thread.daemon = True
298
  cleanup_thread.start()
299
 
300
+
301
# Kick off background work as soon as the application boots
@app.on_event("startup")
async def startup_event():
    """Startup hook: log the boot and launch the background cleanup task."""
    startup_message = "Application starting up..."
    logger.info(startup_message)
    start_cleanup_task()
306
 
307
+
308
@app.get("/")
async def root():
    """Root endpoint for default health check"""
    logger.info("Root endpoint requested")
    payload = {"status": "healthy"}
    return payload
313
 
314
+
315
@app.get("/health")
async def health_check():
    """Health check endpoint that always returns successfully.

    Besides the fixed "healthy" marker, the response reports the current
    load status of every cached model and the language of the cached TTS
    model.
    """
    logger.info("Health check requested")

    # One status entry per cached model, in the order clients already see.
    model_status = {
        name: model_cache[name]["status"]
        for name in ("stt_whisper", "stt_mms", "mt", "tts")
    }
    model_status["tts_language"] = model_cache["tts"]["language"]

    return {"status": "healthy", "model_status": model_status}
329
 
330
+
331
@app.post("/update-languages")
async def update_languages(source_lang: str = Form(...), target_lang: str = Form(...)):
    """
    Update the language settings for translation services.

    Triggers loading of the STT, MT and TTS models needed for the selected
    language pair. As soon as one loader reports that its model is not
    available, a "pending" response naming that model is returned.
    """
    for selected in (source_lang, target_lang):
        if selected not in LANGUAGE_MAPPING:
            raise HTTPException(status_code=400, detail="Invalid language selected")

    source_code = LANGUAGE_MAPPING[source_lang]
    target_code = LANGUAGE_MAPPING[target_lang]

    # Whisper handles English and Tagalog; MMS covers the other languages.
    if source_code in ("eng", "tgl"):
        stt_step = (load_whisper_model, "Whisper model loading in progress")
    else:
        stt_step = (load_mms_stt_model, "MMS STT model loading in progress")

    # Loaders run lazily and in order: STT, then MT, then TTS for the target.
    steps = (
        stt_step,
        (load_mt_model, "MT model loading in progress"),
        (lambda: load_tts_model(target_code), "TTS model loading in progress"),
    )
    for ensure_loaded, pending_message in steps:
        if not ensure_loaded():
            return {"status": "pending", "message": pending_message}

    summary = f"Languages updated to {source_lang} → {target_lang}"
    logger.info(summary)
    return {"status": "success", "message": summary}
363
+
364
+
365
@app.post("/synthesize-speech")
async def synthesize_speech(text: str = Form(...), language: str = Form(...)):
    """Endpoint to synthesize speech from text without translation.

    Args:
        text: Text to speak; must contain at least one non-blank character.
        language: Key of ``LANGUAGE_MAPPING`` selecting the TTS language.

    Returns:
        A dict with ``request_id``, ``status`` and ``message``. ``status`` is
        "pending" while the TTS model is still loading, "failed" when the
        model could not be loaded or synthesis raised, and "completed" with
        an ``output_audio`` URL on success.

    Raises:
        HTTPException: 400 for an unknown language or blank text.
    """
    if language not in LANGUAGE_MAPPING:
        raise HTTPException(status_code=400, detail="Invalid language selected")
    # Same guard as /translate-text: there is nothing to synthesize otherwise.
    if not text or not text.strip():
        raise HTTPException(status_code=400, detail="No text provided")

    language_code = LANGUAGE_MAPPING[language]
    request_id = str(uuid.uuid4())

    # Load the TTS model for the requested language
    if not load_tts_model(language_code):
        # The loader returns False both while a load is in progress and after
        # a failed load; the cache status tells the two apart.
        if model_cache["tts"]["status"] == "failed":
            return {
                "request_id": request_id,
                "status": "failed",
                "message": "TTS model failed to load. Please try again later.",
                "text": text,
                "output_audio": None
            }
        return {
            "request_id": request_id,
            "status": "pending",
            "message": "TTS model loading in progress. Please try again in a moment."
        }

    try:
        tts_model = model_cache["tts"]["model"]
        tts_tokenizer = model_cache["tts"]["tokenizer"]
        device = "cuda" if torch.cuda.is_available() else "cpu"
        inputs = tts_tokenizer(text, return_tensors="pt").to(device)

        with torch.no_grad():
            output = tts_model(**inputs)

        # Convert the float waveform to 16-bit PCM samples for the WAV writer.
        speech = output.waveform.cpu().numpy().squeeze()
        speech = (speech * 32767).astype(np.int16)
        sample_rate = tts_model.config.sampling_rate

        # Save the audio as a WAV file
        output_filename = f"{request_id}.wav"
        output_path = os.path.join(AUDIO_DIR, output_filename)
        save_pcm_to_wav(speech.tolist(), sample_rate, output_path)
        logger.info(f"Saved synthesized audio to {output_path}")

        # Generate a URL to the WAV file
        output_audio_url = f"https://jerich-talklasapp.hf.space/audio_output/{output_filename}"

        return {
            "request_id": request_id,
            "status": "completed",
            "message": "Speech synthesis completed successfully",
            "text": text,
            "output_audio": output_audio_url
        }

    except Exception as e:
        # logger.exception keeps the traceback, which logger.error dropped.
        logger.exception(f"Error during speech synthesis: {str(e)}")
        return {
            "request_id": request_id,
            "status": "failed",
            "message": f"Speech synthesis failed: {str(e)}",
            "text": text,
            "output_audio": None
        }
419
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
420
 
421
  @app.post("/translate-text")
422
  async def translate_text(text: str = Form(...), source_lang: str = Form(...), target_lang: str = Form(...)):
423
  """Endpoint to translate text and convert to speech"""
 
 
424
  if not text:
425
  raise HTTPException(status_code=400, detail="No text provided")
426
  if source_lang not in LANGUAGE_MAPPING or target_lang not in LANGUAGE_MAPPING:
 
429
  logger.info(f"Translate-text requested: {text} from {source_lang} to {target_lang}")
430
  request_id = str(uuid.uuid4())
431
 
432
+ # Load the MT model if not already loaded
433
+ if not load_mt_model():
434
+ return {
435
+ "request_id": request_id,
436
+ "status": "pending",
437
+ "message": "MT model loading in progress. Please try again in a moment.",
438
+ "source_text": text,
439
+ "translated_text": "Translation not available yet",
440
+ "output_audio": None,
441
+ "contains_inappropriate_content": False
442
+ }
443
+
444
  # Translate the text
445
  source_code = LANGUAGE_MAPPING[source_lang]
446
  target_code = LANGUAGE_MAPPING[target_lang]
447
  translated_text = "Translation not available"
448
+ contains_inappropriate = False
449
 
450
+ try:
451
+ source_nllb_code = NLLB_LANGUAGE_CODES[source_code]
452
+ target_nllb_code = NLLB_LANGUAGE_CODES[target_code]
453
+ model_cache["mt"]["tokenizer"].src_lang = source_nllb_code
454
+ device = "cuda" if torch.cuda.is_available() else "cpu"
455
+ inputs = model_cache["mt"]["tokenizer"](text, return_tensors="pt").to(device)
456
+ with torch.no_grad():
457
+ generated_tokens = model_cache["mt"]["model"].generate(
458
+ **inputs,
459
+ forced_bos_token_id=model_cache["mt"]["tokenizer"].convert_tokens_to_ids(target_nllb_code),
460
+ max_length=448
461
+ )
462
+ translated_text = model_cache["mt"]["tokenizer"].batch_decode(generated_tokens, skip_special_tokens=True)[0]
463
+ logger.info(f"Translation completed: {translated_text}")
464
+
465
+ # Check for inappropriate content
466
+ contains_inappropriate = detect_inappropriate_content(translated_text)
467
+ if contains_inappropriate:
468
+ logger.warning(f"Inappropriate content detected in translation")
469
+
470
+ except Exception as e:
471
+ logger.error(f"Error during translation: {str(e)}")
472
+ translated_text = f"Translation failed: {str(e)}"
473
+ return {
474
+ "request_id": request_id,
475
+ "status": "failed",
476
+ "message": f"Translation failed: {str(e)}",
477
+ "source_text": text,
478
+ "translated_text": translated_text,
479
+ "output_audio": None,
480
+ "contains_inappropriate_content": contains_inappropriate
481
+ }
482
+
483
+ # Load the TTS model for the target language
484
+ if not load_tts_model(target_code):
485
+ return {
486
+ "request_id": request_id,
487
+ "status": "partial",
488
+ "message": "Translation completed, but TTS model is loading. Please try again for audio.",
489
+ "source_text": text,
490
+ "translated_text": translated_text,
491
+ "output_audio": None,
492
+ "contains_inappropriate_content": contains_inappropriate
493
+ }
494
+
495
  # Convert translated text to speech
 
 
496
  output_audio_url = None
497
+ try:
498
+ inputs = model_cache["tts"]["tokenizer"](translated_text, return_tensors="pt").to(device)
499
+ with torch.no_grad():
500
+ output = model_cache["tts"]["model"](**inputs)
501
+ speech = output.waveform.cpu().numpy().squeeze()
502
+ speech = (speech * 32767).astype(np.int16)
503
+ sample_rate = model_cache["tts"]["model"].config.sampling_rate
504
+
505
  # Save the audio as a WAV file
506
  output_filename = f"{request_id}.wav"
507
  output_path = os.path.join(AUDIO_DIR, output_filename)
508
  save_pcm_to_wav(speech.tolist(), sample_rate, output_path)
509
+ logger.info(f"Saved synthesized audio to {output_path}")
510
+
511
  # Generate a URL to the WAV file
512
+ output_audio_url = f"https://jerich-talklasapp.hf.space/audio_output/{output_filename}"
513
  logger.info("TTS conversion completed")
514
+ except Exception as e:
515
+ logger.error(f"Error during TTS conversion: {str(e)}")
516
+ output_audio_url = None
517
 
518
  return {
519
  "request_id": request_id,
520
+ "status": "completed" if output_audio_url else "partial",
521
+ "message": "Translation and TTS completed" if output_audio_url else
522
+ "Translation completed but TTS failed",
523
  "source_text": text,
524
  "translated_text": translated_text,
525
  "output_audio": output_audio_url,
526
+ "contains_inappropriate_content": contains_inappropriate
527
  }
528
 
529
+
530
  @app.post("/translate-audio")
531
  async def translate_audio(audio: UploadFile = File(...), source_lang: str = Form(...), target_lang: str = Form(...)):
532
  """Endpoint to transcribe, translate, and convert audio to speech"""
 
 
533
  if not audio:
534
  raise HTTPException(status_code=400, detail="No audio file provided")
535
  if source_lang not in LANGUAGE_MAPPING or target_lang not in LANGUAGE_MAPPING:
 
538
  logger.info(f"Translate-audio requested: {audio.filename} from {source_lang} to {target_lang}")
539
  request_id = str(uuid.uuid4())
540
 
 
541
  source_code = LANGUAGE_MAPPING[source_lang]
542
+ target_code = LANGUAGE_MAPPING[target_lang]
543
+
544
+ # Determine which STT model to use based on source language
545
+ use_whisper = source_code in ["eng", "tgl"]
546
 
547
+ # Ensure the appropriate STT model is loaded
548
+ if use_whisper:
549
+ if not load_whisper_model():
 
550
  return {
551
  "request_id": request_id,
552
+ "status": "pending",
553
+ "message": "Whisper STT model loading in progress. Please try again in a moment.",
554
+ "source_text": "Transcription not available yet",
555
+ "translated_text": "Translation not available yet",
556
  "output_audio": None,
557
+ "contains_inappropriate_content": False
558
  }
559
+ else:
560
+ if not load_mms_stt_model():
 
 
 
561
  return {
562
  "request_id": request_id,
563
+ "status": "pending",
564
+ "message": "MMS STT model loading in progress. Please try again in a moment.",
565
+ "source_text": "Transcription not available yet",
566
+ "translated_text": "Translation not available yet",
567
  "output_audio": None,
568
+ "contains_inappropriate_content": False
569
  }
 
570
 
571
  # Save the uploaded audio to a temporary file
572
  with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as temp_file:
 
576
  transcription = "Transcription not available"
577
  translated_text = "Translation not available"
578
  output_audio_url = None
579
+ contains_inappropriate = False
580
 
581
  try:
582
  # Step 1: Load and resample the audio using torchaudio
 
600
  "source_text": "No speech detected",
601
  "translated_text": "No translation available",
602
  "output_audio": None,
603
+ "contains_inappropriate_content": False
604
  }
605
 
606
  # Step 3: Transcribe the audio (STT)
607
  device = "cuda" if torch.cuda.is_available() else "cpu"
608
+ logger.info(f"Using device: {device}")
609
 
610
  if use_whisper:
611
+ # Use Whisper for English/Tagalog
612
+ stt_processor = model_cache["stt_whisper"]["processor"]
613
+ stt_model = model_cache["stt_whisper"]["model"]
614
+
615
+ inputs = stt_processor(waveform.numpy(), sampling_rate=16000, return_tensors="pt").to(device)
616
+ logger.info("Audio processed with Whisper, generating transcription...")
617
 
 
618
  with torch.no_grad():
619
+ generated_ids = stt_model.generate(**inputs, language="en" if source_code == "eng" else "tl")
620
+ transcription = stt_processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
 
 
 
 
 
 
621
  else:
622
+ # Use MMS for other Philippine languages
623
+ stt_processor = model_cache["stt_mms"]["processor"]
624
+ stt_model = model_cache["stt_mms"]["model"]
625
+
626
+ # Set the target language for MMS if supported
627
+ if source_code in stt_processor.tokenizer.vocab.keys():
628
+ stt_processor.tokenizer.set_target_lang(source_code)
629
+ stt_model.load_adapter(source_code)
630
 
631
+ inputs = stt_processor(waveform.numpy(), sampling_rate=16000, return_tensors="pt").to(device)
632
+ logger.info("Audio processed with MMS, generating transcription...")
 
633
 
 
634
  with torch.no_grad():
635
+ logits = stt_model(**inputs).logits
636
  predicted_ids = torch.argmax(logits, dim=-1)
637
+ transcription = stt_processor.batch_decode(predicted_ids)[0]
638
+
639
  logger.info(f"Transcription completed: {transcription}")
640
 
641
+ # Step 4: Load the MT model if not already loaded
642
+ if not load_mt_model():
643
+ return {
644
+ "request_id": request_id,
645
+ "status": "partial",
646
+ "message": "Transcription completed, but MT model is loading. Please try again for translation.",
647
+ "source_text": transcription,
648
+ "translated_text": "Translation not available yet",
649
+ "output_audio": None,
650
+ "contains_inappropriate_content": False
651
+ }
652
+
653
+ # Step 5: Translate the transcribed text (MT)
654
+ try:
655
+ source_nllb_code = NLLB_LANGUAGE_CODES[source_code]
656
+ target_nllb_code = NLLB_LANGUAGE_CODES[target_code]
657
+ model_cache["mt"]["tokenizer"].src_lang = source_nllb_code
658
+
659
+ inputs = model_cache["mt"]["tokenizer"](transcription, return_tensors="pt").to(device)
660
+ with torch.no_grad():
661
+ generated_tokens = model_cache["mt"]["model"].generate(
662
+ **inputs,
663
+ forced_bos_token_id=model_cache["mt"]["tokenizer"].convert_tokens_to_ids(target_nllb_code),
664
+ max_length=448
665
+ )
666
+ translated_text = model_cache["mt"]["tokenizer"].batch_decode(generated_tokens, skip_special_tokens=True)[0]
667
+ logger.info(f"Translation completed: {translated_text}")
668
+
669
+ # Check for inappropriate content
670
+ contains_inappropriate = detect_inappropriate_content(translated_text)
671
+ if contains_inappropriate:
672
+ logger.warning(f"Inappropriate content detected in translation")
673
+
674
+ except Exception as e:
675
+ logger.error(f"Error during translation: {str(e)}")
676
+ translated_text = f"Translation failed: {str(e)}"
677
+ return {
678
+ "request_id": request_id,
679
+ "status": "partial",
680
+ "message": f"Transcription completed but translation failed: {str(e)}",
681
+ "source_text": transcription,
682
+ "translated_text": translated_text,
683
+ "output_audio": None,
684
+ "contains_inappropriate_content": False
685
+ }
686
 
687
+ # Step 6: Load the TTS model for the target language
688
+ if not load_tts_model(target_code):
689
+ return {
690
+ "request_id": request_id,
691
+ "status": "partial",
692
+ "message": "Transcription and translation completed, but TTS model is loading.",
693
+ "source_text": transcription,
694
+ "translated_text": translated_text,
695
+ "output_audio": None,
696
+ "contains_inappropriate_content": contains_inappropriate
697
+ }
698
 
699
+ # Step 7: Convert translated text to speech (TTS)
700
+ try:
701
+ inputs = model_cache["tts"]["tokenizer"](translated_text, return_tensors="pt").to(device)
702
+ with torch.no_grad():
703
+ output = model_cache["tts"]["model"](**inputs)
704
+ speech = output.waveform.cpu().numpy().squeeze()
705
+ speech = (speech * 32767).astype(np.int16)
706
+ sample_rate = model_cache["tts"]["model"].config.sampling_rate
707
  # Save the audio as a WAV file
708
  output_filename = f"{request_id}.wav"
709
  output_path = os.path.join(AUDIO_DIR, output_filename)
710
  save_pcm_to_wav(speech.tolist(), sample_rate, output_path)
711
+ logger.info(f"Saved synthesized audio to {output_path}")
712
+
713
  # Generate a URL to the WAV file
714
+ output_audio_url = f"https://jerich-talklasapp.hf.space/audio_output/{output_filename}"
715
  logger.info("TTS conversion completed")
716
+ except Exception as e:
717
+ logger.error(f"Error during TTS conversion: {str(e)}")
718
+ output_audio_url = None
719
 
720
  return {
721
  "request_id": request_id,
722
+ "status": "completed" if output_audio_url else "partial",
723
+ "message": "Transcription, translation, and TTS completed" if output_audio_url else
724
+ "Transcription and translation completed but TTS failed",
725
  "source_text": transcription,
726
  "translated_text": translated_text,
727
  "output_audio": output_audio_url,
728
+ "contains_inappropriate_content": contains_inappropriate
729
  }
730
  except Exception as e:
731
  logger.error(f"Error during processing: {str(e)}")
 
736
  "source_text": transcription,
737
  "translated_text": translated_text,
738
  "output_audio": output_audio_url,
739
+ "contains_inappropriate_content": contains_inappropriate
740
  }
741
  finally:
742
  logger.info(f"Cleaning up temporary file: {temp_path}")
743
+ try:
744
+ os.unlink(temp_path)
745
+ except Exception as e:
746
+ logger.error(f"Error deleting temporary file: {str(e)}")
747
+
748
+
749
# Endpoint exposing the inappropriate-content check on its own
@app.post("/check-content")
async def check_content(text: str = Form(...)):
    """
    Report whether the submitted text is flagged by
    detect_inappropriate_content, echoing the text back alongside the flag.
    """
    flagged = detect_inappropriate_content(text)
    result = {"text": text}
    result["contains_inappropriate_content"] = flagged
    return result
760
+
761
 
762
  if __name__ == "__main__":
763
  import uvicorn