Jerich commited on
Commit
eacfbe2
·
verified ·
1 Parent(s): 98d2781

Update audio handling: Save synthesized audio as WAV, return URL, set 10-min expiration with 5-min cleanup interval

Browse files
Files changed (1) hide show
  1. app.py +76 -11
app.py CHANGED
@@ -10,9 +10,13 @@ import torch
10
  import numpy as np
11
  import soundfile as sf
12
  import librosa
13
- from fastapi import FastAPI, HTTPException, UploadFile, File, Form
 
 
14
  from fastapi.responses import JSONResponse
 
15
  from typing import Dict, Any, Optional, Tuple
 
16
 
17
  # Configure logging
18
  logging.basicConfig(level=logging.INFO)
@@ -20,6 +24,11 @@ logger = logging.getLogger("talklas-api")
20
 
21
  app = FastAPI(title="Talklas API")
22
 
 
 
 
 
 
23
  # Global variables to track application state
24
  models_loaded = False
25
  loading_in_progress = False
@@ -58,6 +67,37 @@ NLLB_LANGUAGE_CODES = {
58
  "pag": "pag_Latn"
59
  }
60
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
61
  # Function to load models in background
62
  def load_models_task():
63
  global models_loaded, loading_in_progress, model_status, error_message
@@ -158,11 +198,18 @@ def start_model_loading():
158
  loading_thread.daemon = True
159
  loading_thread.start()
160
 
161
- # Start the background process when the app starts
 
 
 
 
 
 
162
  @app.on_event("startup")
163
  async def startup_event():
164
  logger.info("Application starting up...")
165
  start_model_loading()
 
166
 
167
  @app.get("/")
168
  async def root():
@@ -300,7 +347,7 @@ async def translate_text(text: str = Form(...), source_lang: str = Form(...), ta
300
  logger.warning("MT model not loaded, skipping translation")
301
 
302
  # Convert translated text to speech
303
- output_audio = None
304
  if model_status["tts"].startswith("loaded") and tts_model is not None and tts_tokenizer is not None:
305
  try:
306
  inputs = tts_tokenizer(translated_text, return_tensors="pt").to(device)
@@ -308,11 +355,20 @@ async def translate_text(text: str = Form(...), source_lang: str = Form(...), ta
308
  output = tts_model(**inputs)
309
  speech = output.waveform.cpu().numpy().squeeze()
310
  speech = (speech * 32767).astype(np.int16)
311
- output_audio = (tts_model.config.sampling_rate, speech.tolist())
 
 
 
 
 
 
 
 
 
312
  logger.info("TTS conversion completed")
313
  except Exception as e:
314
  logger.error(f"Error during TTS conversion: {str(e)}")
315
- output_audio = None
316
 
317
  return {
318
  "request_id": request_id,
@@ -320,7 +376,7 @@ async def translate_text(text: str = Form(...), source_lang: str = Form(...), ta
320
  "message": "Translation and TTS completed (or partially completed).",
321
  "source_text": text,
322
  "translated_text": translated_text,
323
- "output_audio": output_audio
324
  }
325
 
326
  @app.post("/translate-audio")
@@ -355,7 +411,7 @@ async def translate_audio(audio: UploadFile = File(...), source_lang: str = Form
355
 
356
  transcription = "Transcription not available"
357
  translated_text = "Translation not available"
358
- output_audio = None
359
 
360
  try:
361
  # Step 1: Transcribe the audio (STT)
@@ -415,11 +471,20 @@ async def translate_audio(audio: UploadFile = File(...), source_lang: str = Form
415
  output = tts_model(**inputs)
416
  speech = output.waveform.cpu().numpy().squeeze()
417
  speech = (speech * 32767).astype(np.int16)
418
- output_audio = (tts_model.config.sampling_rate, speech.tolist())
 
 
 
 
 
 
 
 
 
419
  logger.info("TTS conversion completed")
420
  except Exception as e:
421
  logger.error(f"Error during TTS conversion: {str(e)}")
422
- output_audio = None
423
 
424
  return {
425
  "request_id": request_id,
@@ -427,7 +492,7 @@ async def translate_audio(audio: UploadFile = File(...), source_lang: str = Form
427
  "message": "Transcription, translation, and TTS completed (or partially completed).",
428
  "source_text": transcription,
429
  "translated_text": translated_text,
430
- "output_audio": output_audio
431
  }
432
  except Exception as e:
433
  logger.error(f"Error during processing: {str(e)}")
@@ -437,7 +502,7 @@ async def translate_audio(audio: UploadFile = File(...), source_lang: str = Form
437
  "message": f"Processing failed: {str(e)}",
438
  "source_text": transcription,
439
  "translated_text": translated_text,
440
- "output_audio": output_audio
441
  }
442
  finally:
443
  logger.info(f"Cleaning up temporary file: {temp_path}")
 
10
  import numpy as np
11
  import soundfile as sf
12
  import librosa
13
+ import wave
14
+ import time
15
+ from fastapi import FastAPI, HTTPException, UploadFile, File, Form, BackgroundTasks
16
  from fastapi.responses import JSONResponse
17
+ from fastapi.staticfiles import StaticFiles
18
  from typing import Dict, Any, Optional, Tuple
19
+ from datetime import datetime, timedelta
20
 
21
  # Configure logging
22
  logging.basicConfig(level=logging.INFO)
 
24
 
25
  app = FastAPI(title="Talklas API")
26
 
27
+ # Mount a directory to serve audio files
28
+ AUDIO_DIR = "audio_output"
29
+ os.makedirs(AUDIO_DIR, exist_ok=True)
30
+ app.mount("/audio_output", StaticFiles(directory=AUDIO_DIR), name="audio_output")
31
+
32
  # Global variables to track application state
33
  models_loaded = False
34
  loading_in_progress = False
 
67
  "pag": "pag_Latn"
68
  }
69
 
70
+ # Function to save PCM data as a WAV file
71
+ def save_pcm_to_wav(pcm_data: list, sample_rate: int, output_path: str):
72
+ with wave.open(output_path, 'wb') as wav_file:
73
+ # Set WAV parameters: 1 channel (mono), 2 bytes per sample (16-bit), sample rate
74
+ wav_file.setnchannels(1)
75
+ wav_file.setsampwidth(2) # 16-bit audio
76
+ wav_file.setframerate(sample_rate)
77
+ # Convert PCM data (list of integers) to bytes
78
+ wav_file.writeframes(bytes(pcm_data))
79
+
80
+ # Function to clean up old audio files
81
+ def cleanup_old_audio_files():
82
+ logger.info("Starting cleanup of old audio files...")
83
+ expiration_time = datetime.now() - timedelta(minutes=10) # Files older than 10 minutes
84
+ for filename in os.listdir(AUDIO_DIR):
85
+ file_path = os.path.join(AUDIO_DIR, filename)
86
+ if os.path.isfile(file_path):
87
+ file_mtime = datetime.fromtimestamp(os.path.getmtime(file_path))
88
+ if file_mtime < expiration_time:
89
+ try:
90
+ os.unlink(file_path)
91
+ logger.info(f"Deleted old audio file: {file_path}")
92
+ except Exception as e:
93
+ logger.error(f"Error deleting file {file_path}: {str(e)}")
94
+
95
+ # Background task to periodically clean up audio files
96
+ def schedule_cleanup():
97
+ while True:
98
+ cleanup_old_audio_files()
99
+ time.sleep(300) # Run every 5 minutes (300 seconds)
100
+
101
  # Function to load models in background
102
  def load_models_task():
103
  global models_loaded, loading_in_progress, model_status, error_message
 
198
  loading_thread.daemon = True
199
  loading_thread.start()
200
 
201
+ # Start the background cleanup task
202
+ def start_cleanup_task():
203
+ cleanup_thread = threading.Thread(target=schedule_cleanup)
204
+ cleanup_thread.daemon = True
205
+ cleanup_thread.start()
206
+
207
+ # Start the background processes when the app starts
208
  @app.on_event("startup")
209
  async def startup_event():
210
  logger.info("Application starting up...")
211
  start_model_loading()
212
+ start_cleanup_task()
213
 
214
  @app.get("/")
215
  async def root():
 
347
  logger.warning("MT model not loaded, skipping translation")
348
 
349
  # Convert translated text to speech
350
+ output_audio_url = None
351
  if model_status["tts"].startswith("loaded") and tts_model is not None and tts_tokenizer is not None:
352
  try:
353
  inputs = tts_tokenizer(translated_text, return_tensors="pt").to(device)
 
355
  output = tts_model(**inputs)
356
  speech = output.waveform.cpu().numpy().squeeze()
357
  speech = (speech * 32767).astype(np.int16)
358
+ sample_rate = tts_model.config.sampling_rate
359
+
360
+ # Save the audio as a WAV file
361
+ output_filename = f"{request_id}.wav"
362
+ output_path = os.path.join(AUDIO_DIR, output_filename)
363
+ save_pcm_to_wav(speech.tolist(), sample_rate, output_path)
364
+ logger.info(f"Saved synthesized audio to {output_path}")
365
+
366
+ # Generate a URL to the WAV file
367
+ output_audio_url = f"https://jerich-talklasapp.hf.space/audio_output/{output_filename}"
368
  logger.info("TTS conversion completed")
369
  except Exception as e:
370
  logger.error(f"Error during TTS conversion: {str(e)}")
371
+ output_audio_url = None
372
 
373
  return {
374
  "request_id": request_id,
 
376
  "message": "Translation and TTS completed (or partially completed).",
377
  "source_text": text,
378
  "translated_text": translated_text,
379
+ "output_audio": output_audio_url
380
  }
381
 
382
  @app.post("/translate-audio")
 
411
 
412
  transcription = "Transcription not available"
413
  translated_text = "Translation not available"
414
+ output_audio_url = None
415
 
416
  try:
417
  # Step 1: Transcribe the audio (STT)
 
471
  output = tts_model(**inputs)
472
  speech = output.waveform.cpu().numpy().squeeze()
473
  speech = (speech * 32767).astype(np.int16)
474
+ sample_rate = tts_model.config.sampling_rate
475
+
476
+ # Save the audio as a WAV file
477
+ output_filename = f"{request_id}.wav"
478
+ output_path = os.path.join(AUDIO_DIR, output_filename)
479
+ save_pcm_to_wav(speech.tolist(), sample_rate, output_path)
480
+ logger.info(f"Saved synthesized audio to {output_path}")
481
+
482
+ # Generate a URL to the WAV file
483
+ output_audio_url = f"https://jerich-talklasapp.hf.space/audio_output/{output_filename}"
484
  logger.info("TTS conversion completed")
485
  except Exception as e:
486
  logger.error(f"Error during TTS conversion: {str(e)}")
487
+ output_audio_url = None
488
 
489
  return {
490
  "request_id": request_id,
 
492
  "message": "Transcription, translation, and TTS completed (or partially completed).",
493
  "source_text": transcription,
494
  "translated_text": translated_text,
495
+ "output_audio": output_audio_url
496
  }
497
  except Exception as e:
498
  logger.error(f"Error during processing: {str(e)}")
 
502
  "message": f"Processing failed: {str(e)}",
503
  "source_text": transcription,
504
  "translated_text": translated_text,
505
+ "output_audio": output_audio_url
506
  }
507
  finally:
508
  logger.info(f"Cleaning up temporary file: {temp_path}")