Spaces:
Paused
Paused
Update app.py
Browse files
app.py
CHANGED
@@ -116,7 +116,7 @@ def load_models_task():
|
|
116 |
logger.info("TTS model loaded successfully")
|
117 |
model_status["tts"] = "loaded"
|
118 |
except Exception as e:
|
119 |
-
logger.error(f"Failed to load TTS model: {str(e)}")
|
120 |
# Fallback to English TTS if the target language fails
|
121 |
try:
|
122 |
logger.info("Falling back to MMS-TTS English model...")
|
@@ -304,8 +304,12 @@ async def translate_audio(audio: UploadFile = File(...), source_lang: str = Form
|
|
304 |
temp_file.write(await audio.read())
|
305 |
temp_path = temp_file.name
|
306 |
|
|
|
|
|
|
|
|
|
307 |
try:
|
308 |
-
#
|
309 |
logger.info(f"Reading audio file: {temp_path}")
|
310 |
waveform, sample_rate = sf.read(temp_path)
|
311 |
logger.info(f"Audio loaded: sample_rate={sample_rate}, waveform_shape={waveform.shape}")
|
@@ -313,7 +317,6 @@ async def translate_audio(audio: UploadFile = File(...), source_lang: str = Form
|
|
313 |
logger.info(f"Resampling audio from {sample_rate} Hz to 16000 Hz")
|
314 |
waveform = librosa.resample(waveform, orig_sr=sample_rate, target_sr=16000)
|
315 |
|
316 |
-
# Process the audio with Whisper (STT)
|
317 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
318 |
logger.info(f"Using device: {device}")
|
319 |
inputs = stt_processor(waveform, sampling_rate=16000, return_tensors="pt").to(device)
|
@@ -323,10 +326,9 @@ async def translate_audio(audio: UploadFile = File(...), source_lang: str = Form
|
|
323 |
transcription = stt_processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
|
324 |
logger.info(f"Transcription completed: {transcription}")
|
325 |
|
326 |
-
# Translate the transcribed text
|
327 |
source_code = LANGUAGE_MAPPING[source_lang]
|
328 |
target_code = LANGUAGE_MAPPING[target_lang]
|
329 |
-
translated_text = "Translation not available"
|
330 |
|
331 |
if model_status["mt"] == "loaded" and mt_model is not None and mt_tokenizer is not None:
|
332 |
try:
|
@@ -348,11 +350,43 @@ async def translate_audio(audio: UploadFile = File(...), source_lang: str = Form
|
|
348 |
else:
|
349 |
logger.warning("MT model not loaded, skipping translation")
|
350 |
|
351 |
-
# Convert translated text to speech
|
352 |
-
output_audio = None
|
353 |
if model_status["tts"].startswith("loaded") and tts_model is not None and tts_tokenizer is not None:
|
354 |
try:
|
355 |
inputs = tts_tokenizer(translated_text, return_tensors="pt", clean_up_tokenization_spaces=True).to(device)
|
356 |
with torch.no_grad():
|
357 |
output = tts_model(**inputs)
|
358 |
-
speech = output.waveform.cpu().numpy().squeeze
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
116 |
logger.info("TTS model loaded successfully")
|
117 |
model_status["tts"] = "loaded"
|
118 |
except Exception as e:
|
119 |
+
logger.error(f"Failed to load TTS model for Tagalog: {str(e)}")
|
120 |
# Fallback to English TTS if the target language fails
|
121 |
try:
|
122 |
logger.info("Falling back to MMS-TTS English model...")
|
|
|
304 |
temp_file.write(await audio.read())
|
305 |
temp_path = temp_file.name
|
306 |
|
307 |
+
transcription = "Transcription not available"
|
308 |
+
translated_text = "Translation not available"
|
309 |
+
output_audio = None
|
310 |
+
|
311 |
try:
|
312 |
+
# Step 1: Transcribe the audio (STT)
|
313 |
logger.info(f"Reading audio file: {temp_path}")
|
314 |
waveform, sample_rate = sf.read(temp_path)
|
315 |
logger.info(f"Audio loaded: sample_rate={sample_rate}, waveform_shape={waveform.shape}")
|
|
|
317 |
logger.info(f"Resampling audio from {sample_rate} Hz to 16000 Hz")
|
318 |
waveform = librosa.resample(waveform, orig_sr=sample_rate, target_sr=16000)
|
319 |
|
|
|
320 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
321 |
logger.info(f"Using device: {device}")
|
322 |
inputs = stt_processor(waveform, sampling_rate=16000, return_tensors="pt").to(device)
|
|
|
326 |
transcription = stt_processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
|
327 |
logger.info(f"Transcription completed: {transcription}")
|
328 |
|
329 |
+
# Step 2: Translate the transcribed text (MT)
|
330 |
source_code = LANGUAGE_MAPPING[source_lang]
|
331 |
target_code = LANGUAGE_MAPPING[target_lang]
|
|
|
332 |
|
333 |
if model_status["mt"] == "loaded" and mt_model is not None and mt_tokenizer is not None:
|
334 |
try:
|
|
|
350 |
else:
|
351 |
logger.warning("MT model not loaded, skipping translation")
|
352 |
|
353 |
+
# Step 3: Convert translated text to speech (TTS)
|
|
|
354 |
if model_status["tts"].startswith("loaded") and tts_model is not None and tts_tokenizer is not None:
|
355 |
try:
|
356 |
inputs = tts_tokenizer(translated_text, return_tensors="pt", clean_up_tokenization_spaces=True).to(device)
|
357 |
with torch.no_grad():
|
358 |
output = tts_model(**inputs)
|
359 |
+
speech = output.waveform.cpu().numpy().squeeze()
|
360 |
+
speech = (speech * 32767).astype(np.int16)
|
361 |
+
output_audio = (tts_model.config.sampling_rate, speech.tolist())
|
362 |
+
logger.info("TTS conversion completed")
|
363 |
+
except Exception as e:
|
364 |
+
logger.error(f"Error during TTS conversion: {str(e)}")
|
365 |
+
output_audio = None
|
366 |
+
|
367 |
+
return {
|
368 |
+
"request_id": request_id,
|
369 |
+
"status": "completed",
|
370 |
+
"message": "Transcription, translation, and TTS completed (or partially completed).",
|
371 |
+
"source_text": transcription,
|
372 |
+
"translated_text": translated_text,
|
373 |
+
"output_audio": output_audio
|
374 |
+
}
|
375 |
+
except Exception as e:
|
376 |
+
logger.error(f"Error during processing: {str(e)}")
|
377 |
+
return {
|
378 |
+
"request_id": request_id,
|
379 |
+
"status": "failed",
|
380 |
+
"message": f"Processing failed: {str(e)}",
|
381 |
+
"source_text": transcription,
|
382 |
+
"translated_text": translated_text,
|
383 |
+
"output_audio": output_audio
|
384 |
+
}
|
385 |
+
finally:
|
386 |
+
logger.info(f"Cleaning up temporary file: {temp_path}")
|
387 |
+
os.unlink(temp_path)
|
388 |
+
|
389 |
+
if __name__ == "__main__":
|
390 |
+
import uvicorn
|
391 |
+
logger.info("Starting Uvicorn server...")
|
392 |
+
uvicorn.run(app, host="0.0.0.0", port=8000, workers=1)
|