Replaced librosa with torchaudio for audio loading and resampling. Added speech detection (energy-based for now, with webrtcvad as the more accurate option). Improved the /translate-audio endpoint to handle silent audio gracefully.
app.py (CHANGED)
```diff
@@ -9,7 +9,7 @@ import uuid
 import torch
 import numpy as np
 import soundfile as sf
-import librosa
+import torchaudio
 import wave
 import time
 from fastapi import FastAPI, HTTPException, UploadFile, File, Form, BackgroundTasks
```
```diff
@@ -80,6 +80,30 @@ def save_pcm_to_wav(pcm_data: list, sample_rate: int, output_path: str):
         # Write the 16-bit PCM data as bytes (little-endian)
         wav_file.writeframes(pcm_array.tobytes())
 
+# Function to detect speech using an energy-based approach
+def detect_speech(waveform: torch.Tensor, sample_rate: int, threshold: float = 0.01, min_speech_duration: float = 0.5) -> bool:
+    """
+    Detects if the audio contains speech using an energy-based approach.
+    Returns True if speech is detected, False otherwise.
+    """
+    # Convert waveform to numpy array
+    waveform_np = waveform.numpy()
+    if waveform_np.ndim > 1:
+        waveform_np = waveform_np.mean(axis=0)  # Convert stereo to mono
+
+    # Compute RMS energy
+    rms = np.sqrt(np.mean(waveform_np**2))
+    logger.info(f"RMS energy: {rms}")
+
+    # Check if RMS energy exceeds the threshold
+    if rms < threshold:
+        logger.info("No speech detected: RMS energy below threshold")
+        return False
+
+    # Optionally, check for minimum speech duration (requires more sophisticated VAD)
+    # For now, we assume if RMS is above threshold, there is speech
+    return True
+
 # Function to clean up old audio files
 def cleanup_old_audio_files():
     logger.info("Starting cleanup of old audio files...")
```
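The `detect_speech` added above is a single clip-wide RMS check, and its `min_speech_duration` parameter is accepted but not yet used. The commit message names webrtcvad as the higher-accuracy option; a minimal sketch of what that alternative could look like, assuming 16 kHz mono float input (the helper name, frame size, aggressiveness level, and voting ratio are illustrative, not part of this commit):

```python
import numpy as np
import webrtcvad

def detect_speech_webrtc(waveform_np: np.ndarray, sample_rate: int = 16000,
                         aggressiveness: int = 2, min_speech_ratio: float = 0.1) -> bool:
    """Hypothetical webrtcvad-based alternative to the energy check above.

    Expects mono float audio in [-1, 1]; webrtcvad only accepts 10/20/30 ms
    frames of 16-bit little-endian PCM at 8/16/32/48 kHz.
    """
    vad = webrtcvad.Vad(aggressiveness)  # 0 = least aggressive, 3 = most
    pcm16 = (np.clip(waveform_np, -1.0, 1.0) * 32767).astype(np.int16).tobytes()

    frame_ms = 30
    frame_bytes = int(sample_rate * frame_ms / 1000) * 2  # 2 bytes per int16 sample
    frames = [pcm16[i:i + frame_bytes] for i in range(0, len(pcm16), frame_bytes)]
    frames = [f for f in frames if len(f) == frame_bytes]  # drop trailing partial frame
    if not frames:
        return False

    # Vote per frame instead of thresholding the whole clip at once
    voiced = sum(vad.is_speech(f, sample_rate) for f in frames)
    return voiced / len(frames) >= min_speech_ratio
```

Frame-level voting like this tolerates quiet passages and short pauses far better than one global RMS threshold, at the cost of an extra dependency.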
```diff
@@ -417,17 +441,33 @@ async def translate_audio(audio: UploadFile = File(...), source_lang: str = Form
     output_audio_url = None
 
     try:
-        # Step 1:
+        # Step 1: Load and resample the audio using torchaudio
         logger.info(f"Reading audio file: {temp_path}")
-        waveform, sample_rate =
+        waveform, sample_rate = torchaudio.load(temp_path)
         logger.info(f"Audio loaded: sample_rate={sample_rate}, waveform_shape={waveform.shape}")
+
+        # Resample to 16 kHz if needed (required by Whisper and MMS models)
         if sample_rate != 16000:
             logger.info(f"Resampling audio from {sample_rate} Hz to 16000 Hz")
-
-
+            resampler = torchaudio.transforms.Resample(sample_rate, 16000)
+            waveform = resampler(waveform)
+            sample_rate = 16000
+
+        # Step 2: Detect speech
+        if not detect_speech(waveform, sample_rate):
+            return {
+                "request_id": request_id,
+                "status": "failed",
+                "message": "No speech detected in the audio.",
+                "source_text": "No speech detected",
+                "translated_text": "No translation available",
+                "output_audio": None
+            }
+
+        # Step 3: Transcribe the audio (STT)
         device = "cuda" if torch.cuda.is_available() else "cpu"
         logger.info(f"Using device: {device}")
-        inputs = stt_processor(waveform, sampling_rate=16000, return_tensors="pt").to(device)
+        inputs = stt_processor(waveform.numpy(), sampling_rate=16000, return_tensors="pt").to(device)
         logger.info("Audio processed, generating transcription...")
 
         with torch.no_grad():
```
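Note that `torchaudio.load` returns a `(channels, num_samples)` float tensor, so a stereo upload reaches `stt_processor` as two rows rather than one mono clip. A standalone sketch of the load/resample path with an explicit mono downmix (the downmix is an addition here, not something this commit performs):

```python
import torch
import torchaudio

def load_audio_16k_mono(path: str) -> torch.Tensor:
    """Load an audio file and return a 1-D mono waveform at 16 kHz."""
    waveform, sample_rate = torchaudio.load(path)   # (channels, num_samples), float32 in [-1, 1]
    if sample_rate != 16000:
        waveform = torchaudio.transforms.Resample(sample_rate, 16000)(waveform)
    if waveform.shape[0] > 1:                       # stereo or multichannel upload
        waveform = waveform.mean(dim=0, keepdim=True)
    return waveform.squeeze(0)                      # (num_samples,)
```

Feeding the squeezed 1-D array to the processor keeps the batch dimension unambiguous; `detect_speech` above already downmixes internally, so only the STT call is affected.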
```diff
@@ -442,7 +482,7 @@ async def translate_audio(audio: UploadFile = File(...), source_lang: str = Form
         transcription = stt_processor.batch_decode(predicted_ids)[0]
         logger.info(f"Transcription completed: {transcription}")
 
-        # Step
+        # Step 4: Translate the transcribed text (MT)
         source_code = LANGUAGE_MAPPING[source_lang]
         target_code = LANGUAGE_MAPPING[target_lang]
 
```
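The diff does not show which MT model consumes `source_code` and `target_code`. If the Space follows the common NLLB-style flow in transformers, the step between these codes and `translated_text` looks roughly like the sketch below; the model id, language codes, and the `LANGUAGE_MAPPING` entries are all assumptions for illustration:

```python
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

# Hypothetical mapping consistent with LANGUAGE_MAPPING[source_lang] above
LANGUAGE_MAPPING = {"English": "eng_Latn", "Tagalog": "tgl_Latn"}

mt_name = "facebook/nllb-200-distilled-600M"  # assumed model choice
mt_tokenizer = AutoTokenizer.from_pretrained(mt_name, src_lang=LANGUAGE_MAPPING["English"])
mt_model = AutoModelForSeq2SeqLM.from_pretrained(mt_name)

inputs = mt_tokenizer("Good morning", return_tensors="pt")
# Force the decoder to start in the target language
generated = mt_model.generate(
    **inputs,
    forced_bos_token_id=mt_tokenizer.convert_tokens_to_ids(LANGUAGE_MAPPING["Tagalog"]),
)
translated_text = mt_tokenizer.batch_decode(generated, skip_special_tokens=True)[0]
```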
```diff
@@ -466,7 +506,7 @@ async def translate_audio(audio: UploadFile = File(...), source_lang: str = Form
         else:
             logger.warning("MT model not loaded, skipping translation")
 
-        # Step
+        # Step 5: Convert translated text to speech (TTS)
         if model_status["tts"].startswith("loaded") and tts_model is not None and tts_tokenizer is not None:
             try:
                 inputs = tts_tokenizer(translated_text, return_tensors="pt").to(device)
```
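The resampling comment above names Whisper and MMS, and the `tts_tokenizer`/`tts_model` pair matches the MMS-TTS interface in transformers. A minimal sketch of that interface, assuming an MMS-TTS checkpoint (the specific model id is an assumption):

```python
import torch
from transformers import VitsModel, AutoTokenizer

tts_name = "facebook/mms-tts-eng"  # assumed checkpoint
tts_model = VitsModel.from_pretrained(tts_name)
tts_tokenizer = AutoTokenizer.from_pretrained(tts_name)

inputs = tts_tokenizer("Good morning", return_tensors="pt")
with torch.no_grad():
    waveform = tts_model(**inputs).waveform  # (1, num_samples)

# The output rate comes from the model config, not from the input audio
sample_rate = tts_model.config.sampling_rate
```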
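End to end, the endpoint signature above takes a file upload plus two form fields. A client call against a locally running Space might look like the following sketch; the host, port, filename, and language values are placeholders, and valid languages depend on `LANGUAGE_MAPPING`:

```python
import requests

with open("recording.wav", "rb") as f:
    resp = requests.post(
        "http://localhost:7860/translate-audio",
        files={"audio": ("recording.wav", f, "audio/wav")},
        data={"source_lang": "English", "target_lang": "Tagalog"},
    )

# Per the diff, silent input now returns status="failed" with
# message="No speech detected in the audio." instead of erroring out.
print(resp.json())
```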