Jerich commited on
Commit
b0c2331
·
verified ·
1 Parent(s): 8157595

Replaced librosa with torchaudio for audio loading and resampling. Added speech detection (energy-based or webrtcvad for accuracy). Improved /translate-audio endpoint to handle silent audio gracefully.

Browse files
Files changed (1) hide show
  1. app.py +48 -8
app.py CHANGED
@@ -9,7 +9,7 @@ import uuid
9
  import torch
10
  import numpy as np
11
  import soundfile as sf
12
- import librosa
13
  import wave
14
  import time
15
  from fastapi import FastAPI, HTTPException, UploadFile, File, Form, BackgroundTasks
@@ -80,6 +80,30 @@ def save_pcm_to_wav(pcm_data: list, sample_rate: int, output_path: str):
80
  # Write the 16-bit PCM data as bytes (little-endian)
81
  wav_file.writeframes(pcm_array.tobytes())
82
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
83
  # Function to clean up old audio files
84
  def cleanup_old_audio_files():
85
  logger.info("Starting cleanup of old audio files...")
@@ -417,17 +441,33 @@ async def translate_audio(audio: UploadFile = File(...), source_lang: str = Form
417
  output_audio_url = None
418
 
419
  try:
420
- # Step 1: Transcribe the audio (STT)
421
  logger.info(f"Reading audio file: {temp_path}")
422
- waveform, sample_rate = sf.read(temp_path)
423
  logger.info(f"Audio loaded: sample_rate={sample_rate}, waveform_shape={waveform.shape}")
 
 
424
  if sample_rate != 16000:
425
  logger.info(f"Resampling audio from {sample_rate} Hz to 16000 Hz")
426
- waveform = librosa.resample(waveform, orig_sr=sample_rate, target_sr=16000)
427
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
428
  device = "cuda" if torch.cuda.is_available() else "cpu"
429
  logger.info(f"Using device: {device}")
430
- inputs = stt_processor(waveform, sampling_rate=16000, return_tensors="pt").to(device)
431
  logger.info("Audio processed, generating transcription...")
432
 
433
  with torch.no_grad():
@@ -442,7 +482,7 @@ async def translate_audio(audio: UploadFile = File(...), source_lang: str = Form
442
  transcription = stt_processor.batch_decode(predicted_ids)[0]
443
  logger.info(f"Transcription completed: {transcription}")
444
 
445
- # Step 2: Translate the transcribed text (MT)
446
  source_code = LANGUAGE_MAPPING[source_lang]
447
  target_code = LANGUAGE_MAPPING[target_lang]
448
 
@@ -466,7 +506,7 @@ async def translate_audio(audio: UploadFile = File(...), source_lang: str = Form
466
  else:
467
  logger.warning("MT model not loaded, skipping translation")
468
 
469
- # Step 3: Convert translated text to speech (TTS)
470
  if model_status["tts"].startswith("loaded") and tts_model is not None and tts_tokenizer is not None:
471
  try:
472
  inputs = tts_tokenizer(translated_text, return_tensors="pt").to(device)
 
9
  import torch
10
  import numpy as np
11
  import soundfile as sf
12
+ import torchaudio
13
  import wave
14
  import time
15
  from fastapi import FastAPI, HTTPException, UploadFile, File, Form, BackgroundTasks
 
80
  # Write the 16-bit PCM data as bytes (little-endian)
81
  wav_file.writeframes(pcm_array.tobytes())
82
 
83
+ # Function to detect speech using an energy-based approach
84
+ def detect_speech(waveform: torch.Tensor, sample_rate: int, threshold: float = 0.01, min_speech_duration: float = 0.5) -> bool:
85
+ """
86
+ Detects if the audio contains speech using an energy-based approach.
87
+ Returns True if speech is detected, False otherwise.
88
+ """
89
+ # Convert waveform to numpy array
90
+ waveform_np = waveform.numpy()
91
+ if waveform_np.ndim > 1:
92
+ waveform_np = waveform_np.mean(axis=0) # Convert stereo to mono
93
+
94
+ # Compute RMS energy
95
+ rms = np.sqrt(np.mean(waveform_np**2))
96
+ logger.info(f"RMS energy: {rms}")
97
+
98
+ # Check if RMS energy exceeds the threshold
99
+ if rms < threshold:
100
+ logger.info("No speech detected: RMS energy below threshold")
101
+ return False
102
+
103
+ # Optionally, check for minimum speech duration (requires more sophisticated VAD)
104
+ # For now, we assume if RMS is above threshold, there is speech
105
+ return True
106
+
107
  # Function to clean up old audio files
108
  def cleanup_old_audio_files():
109
  logger.info("Starting cleanup of old audio files...")
 
441
  output_audio_url = None
442
 
443
  try:
444
+ # Step 1: Load and resample the audio using torchaudio
445
  logger.info(f"Reading audio file: {temp_path}")
446
+ waveform, sample_rate = torchaudio.load(temp_path)
447
  logger.info(f"Audio loaded: sample_rate={sample_rate}, waveform_shape={waveform.shape}")
448
+
449
+ # Resample to 16 kHz if needed (required by Whisper and MMS models)
450
  if sample_rate != 16000:
451
  logger.info(f"Resampling audio from {sample_rate} Hz to 16000 Hz")
452
+ resampler = torchaudio.transforms.Resample(sample_rate, 16000)
453
+ waveform = resampler(waveform)
454
+ sample_rate = 16000
455
+
456
+ # Step 2: Detect speech
457
+ if not detect_speech(waveform, sample_rate):
458
+ return {
459
+ "request_id": request_id,
460
+ "status": "failed",
461
+ "message": "No speech detected in the audio.",
462
+ "source_text": "No speech detected",
463
+ "translated_text": "No translation available",
464
+ "output_audio": None
465
+ }
466
+
467
+ # Step 3: Transcribe the audio (STT)
468
  device = "cuda" if torch.cuda.is_available() else "cpu"
469
  logger.info(f"Using device: {device}")
470
+ inputs = stt_processor(waveform.numpy(), sampling_rate=16000, return_tensors="pt").to(device)
471
  logger.info("Audio processed, generating transcription...")
472
 
473
  with torch.no_grad():
 
482
  transcription = stt_processor.batch_decode(predicted_ids)[0]
483
  logger.info(f"Transcription completed: {transcription}")
484
 
485
+ # Step 4: Translate the transcribed text (MT)
486
  source_code = LANGUAGE_MAPPING[source_lang]
487
  target_code = LANGUAGE_MAPPING[target_lang]
488
 
 
506
  else:
507
  logger.warning("MT model not loaded, skipping translation")
508
 
509
+ # Step 5: Convert translated text to speech (TTS)
510
  if model_status["tts"].startswith("loaded") and tts_model is not None and tts_tokenizer is not None:
511
  try:
512
  inputs = tts_tokenizer(translated_text, return_tensors="pt").to(device)