Jerich committed on
Commit 41fc714 · verified · 1 Parent(s): a5434d9

Create app.py

Files changed (1)
  1. app.py +611 -0
app.py ADDED
@@ -0,0 +1,611 @@
+ import os
+ os.environ["HOME"] = "/root"
+ os.environ["HF_HOME"] = "/tmp/hf_cache"
+
+ import logging
+ import threading
+ import tempfile
+ import uuid
+ import torch
+ import numpy as np
+ import torchaudio
+ import wave
+ import time
+ from fastapi import FastAPI, HTTPException, UploadFile, File, Form
+ from fastapi.staticfiles import StaticFiles
+ from datetime import datetime, timedelta
+
+ # Configure logging
+ logging.basicConfig(level=logging.INFO)
+ logger = logging.getLogger("talklas-api")
+
+ app = FastAPI(title="Talklas API")
+
+ # Mount a directory to serve audio files
+ AUDIO_DIR = "/tmp/audio_output"  # Use /tmp for temporary files
+ os.makedirs(AUDIO_DIR, exist_ok=True)
+ app.mount("/audio_output", StaticFiles(directory=AUDIO_DIR), name="audio_output")
+
+ # Global variables to track application state
+ models_loaded = False
+ loading_in_progress = False
+ loading_thread = None
+ model_status = {
+     "stt": "not_loaded",
+     "mt": "not_loaded",
+     "tts": "not_loaded"
+ }
+ error_message = None
+ current_tts_language = "tgl"  # Track the current TTS language
+
+ # Model instances
+ stt_processor = None
+ stt_model = None
+ mt_model = None
+ mt_tokenizer = None
+ tts_model = None
+ tts_tokenizer = None
+
+ # Define the valid languages and mappings
+ LANGUAGE_MAPPING = {
+     "English": "eng",
+     "Tagalog": "tgl",
+     "Cebuano": "ceb",
+     "Ilocano": "ilo",
+     "Waray": "war",
+     "Pangasinan": "pag"
+ }
+
+ NLLB_LANGUAGE_CODES = {
+     "eng": "eng_Latn",
+     "tgl": "tgl_Latn",
+     "ceb": "ceb_Latn",
+     "ilo": "ilo_Latn",
+     "war": "war_Latn",
+     "pag": "pag_Latn"
+ }
+
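+ # Note: NLLB-200 uses FLORES-200 style codes (ISO 639-3 plus a script tag,
+ # e.g. "tgl_Latn"). The tokenizer's src_lang and the forced BOS token passed
+ # to generate() further below both rely on these codes to select the
+ # translation direction.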
+ # Function to save PCM data as a WAV file
+ def save_pcm_to_wav(pcm_data: list, sample_rate: int, output_path: str):
+     # Convert pcm_data to a NumPy array of 16-bit integers
+     pcm_array = np.array(pcm_data, dtype=np.int16)
+
+     with wave.open(output_path, 'wb') as wav_file:
+         # Set WAV parameters: 1 channel (mono), 2 bytes per sample (16-bit), sample rate
+         wav_file.setnchannels(1)
+         wav_file.setsampwidth(2)  # 16-bit audio
+         wav_file.setframerate(sample_rate)
+         # Write the 16-bit PCM data as bytes (little-endian)
+         wav_file.writeframes(pcm_array.tobytes())
+
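+ # The TTS endpoints below scale the model's float waveform (roughly in
+ # [-1.0, 1.0]) by 32767 before calling save_pcm_to_wav, so a float sample of
+ # 0.5 becomes the int16 value 16383. Samples outside [-1.0, 1.0] can wrap when
+ # cast to int16, so clipping first (e.g. np.clip(speech, -1.0, 1.0)) would be
+ # a safe hardening step.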
+ # Function to detect speech using an energy-based approach
+ def detect_speech(waveform: torch.Tensor, sample_rate: int, threshold: float = 0.01, min_speech_duration: float = 0.5) -> bool:
+     """
+     Detects if the audio contains speech using an energy-based approach.
+     Returns True if speech is detected, False otherwise.
+     """
+     # Convert waveform to numpy array
+     waveform_np = waveform.numpy()
+     if waveform_np.ndim > 1:
+         waveform_np = waveform_np.mean(axis=0)  # Convert stereo to mono
+
+     # Compute RMS energy
+     rms = np.sqrt(np.mean(waveform_np**2))
+     logger.info(f"RMS energy: {rms}")
+
+     # Check if RMS energy exceeds the threshold
+     if rms < threshold:
+         logger.info("No speech detected: RMS energy below threshold")
+         return False
+
+     # Optionally, check for minimum speech duration (requires more sophisticated VAD)
+     # For now, we assume if RMS is above threshold, there is speech
+     return True
+
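+ # Worked example of the energy check above: for a sine wave of amplitude a,
+ # RMS = a / sqrt(2), so a 0.1-amplitude tone has RMS ≈ 0.0707 and passes the
+ # default 0.01 threshold, while near-silence (a ≈ 0.01) gives RMS ≈ 0.007 and
+ # is rejected. The min_speech_duration parameter is currently unused; honoring
+ # it would require a frame-level VAD rather than a single global RMS value.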
+ # Function to clean up old audio files
+ def cleanup_old_audio_files():
+     logger.info("Starting cleanup of old audio files...")
+     expiration_time = datetime.now() - timedelta(minutes=10)  # Files older than 10 minutes
+     for filename in os.listdir(AUDIO_DIR):
+         file_path = os.path.join(AUDIO_DIR, filename)
+         if os.path.isfile(file_path):
+             file_mtime = datetime.fromtimestamp(os.path.getmtime(file_path))
+             if file_mtime < expiration_time:
+                 try:
+                     os.unlink(file_path)
+                     logger.info(f"Deleted old audio file: {file_path}")
+                 except Exception as e:
+                     logger.error(f"Error deleting file {file_path}: {str(e)}")
+
+ # Background task to periodically clean up audio files
+ def schedule_cleanup():
+     while True:
+         cleanup_old_audio_files()
+         time.sleep(300)  # Run every 5 minutes (300 seconds)
+
+ # Function to load models in background
+ def load_models_task():
+     global models_loaded, loading_in_progress, model_status, error_message
+     global stt_processor, stt_model, mt_model, mt_tokenizer, tts_model, tts_tokenizer
+     global current_tts_language
+
+     try:
+         loading_in_progress = True
+         # Pick the device up front so the fallback paths below can also use it
+         device = "cuda" if torch.cuda.is_available() else "cpu"
+
+         # Load STT model (MMS with fallback to Whisper)
+         logger.info("Starting to load STT model...")
+         from transformers import AutoProcessor, AutoModelForCTC, WhisperProcessor, WhisperForConditionalGeneration
+
+         try:
+             logger.info("Loading MMS STT model...")
+             model_status["stt"] = "loading"
+             stt_processor = AutoProcessor.from_pretrained("facebook/mms-1b-all")
+             stt_model = AutoModelForCTC.from_pretrained("facebook/mms-1b-all")
+             stt_model.to(device)
+             logger.info("MMS STT model loaded successfully")
+             model_status["stt"] = "loaded_mms"
+         except Exception as mms_error:
+             logger.error(f"Failed to load MMS STT model: {str(mms_error)}")
+             logger.info("Falling back to Whisper STT model...")
+             try:
+                 stt_processor = WhisperProcessor.from_pretrained("openai/whisper-tiny")
+                 stt_model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")
+                 stt_model.to(device)
+                 logger.info("Whisper STT model loaded successfully as fallback")
+                 model_status["stt"] = "loaded_whisper"
+             except Exception as whisper_error:
+                 logger.error(f"Failed to load Whisper STT model: {str(whisper_error)}")
+                 model_status["stt"] = "failed"
+                 error_message = f"STT model loading failed: MMS error: {str(mms_error)}, Whisper error: {str(whisper_error)}"
+                 return
+
+         # Load MT model
+         logger.info("Starting to load MT model...")
+         from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
+
+         try:
+             logger.info("Loading NLLB-200-distilled-600M model...")
+             model_status["mt"] = "loading"
+             mt_model = AutoModelForSeq2SeqLM.from_pretrained("facebook/nllb-200-distilled-600M")
+             mt_tokenizer = AutoTokenizer.from_pretrained("facebook/nllb-200-distilled-600M")
+             mt_model.to(device)
+             logger.info("MT model loaded successfully")
+             model_status["mt"] = "loaded"
+         except Exception as e:
+             logger.error(f"Failed to load MT model: {str(e)}")
+             model_status["mt"] = "failed"
+             error_message = f"MT model loading failed: {str(e)}"
+             return
+
+         # Load TTS model (default to Tagalog, will be updated dynamically)
+         logger.info("Starting to load TTS model...")
+         from transformers import VitsModel, AutoTokenizer
+
+         try:
+             logger.info("Loading MMS-TTS model for Tagalog...")
+             model_status["tts"] = "loading"
+             tts_model = VitsModel.from_pretrained("facebook/mms-tts-tgl")
+             tts_tokenizer = AutoTokenizer.from_pretrained("facebook/mms-tts-tgl")
+             tts_model.to(device)
+             logger.info("TTS model loaded successfully")
+             model_status["tts"] = "loaded"
+         except Exception as e:
+             logger.error(f"Failed to load TTS model for Tagalog: {str(e)}")
+             # Fallback to English TTS if the target language fails
+             try:
+                 logger.info("Falling back to MMS-TTS English model...")
+                 tts_model = VitsModel.from_pretrained("facebook/mms-tts-eng")
+                 tts_tokenizer = AutoTokenizer.from_pretrained("facebook/mms-tts-eng")
+                 tts_model.to(device)
+                 logger.info("Fallback TTS model loaded successfully")
+                 model_status["tts"] = "loaded (fallback)"
+                 current_tts_language = "eng"
+             except Exception as e2:
+                 logger.error(f"Failed to load fallback TTS model: {str(e2)}")
+                 model_status["tts"] = "failed"
+                 error_message = f"TTS model loading failed: {str(e)} (fallback also failed: {str(e2)})"
+                 return
+
+         models_loaded = True
+         logger.info("Model loading completed successfully")
+
+     except Exception as e:
+         error_message = str(e)
+         logger.error(f"Error in model loading task: {str(e)}")
+     finally:
+         loading_in_progress = False
+
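+ # Loading is sequential (STT, then MT, then TTS) and each stage returns early
+ # on failure, so models_loaded only becomes True once every stage succeeds;
+ # the /health endpoint exposes the per-stage status and any captured error.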
+ # Start loading models in background
+ def start_model_loading():
+     global loading_thread, loading_in_progress
+     if not loading_in_progress and not models_loaded:
+         loading_in_progress = True
+         loading_thread = threading.Thread(target=load_models_task)
+         loading_thread.daemon = True
+         loading_thread.start()
+
+ # Start the background cleanup task
+ def start_cleanup_task():
+     cleanup_thread = threading.Thread(target=schedule_cleanup)
+     cleanup_thread.daemon = True
+     cleanup_thread.start()
+
+ # Start the background processes when the app starts
+ @app.on_event("startup")
+ async def startup_event():
+     logger.info("Application starting up...")
+     start_model_loading()
+     start_cleanup_task()
+
+ @app.get("/")
+ async def root():
+     """Root endpoint for default health check"""
+     logger.info("Root endpoint requested")
+     return {"status": "healthy"}
+
+ @app.get("/health")
+ async def health_check():
+     """Health check endpoint that always returns successfully"""
+     logger.info("Health check requested")
+     return {
+         "status": "healthy",
+         "models_loaded": models_loaded,
+         "loading_in_progress": loading_in_progress,
+         "model_status": model_status,
+         "error": error_message
+     }
+
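+ # Illustrative /health response while models are still loading (example
+ # values, not captured output):
+ # {"status": "healthy", "models_loaded": false, "loading_in_progress": true,
+ #  "model_status": {"stt": "loading", "mt": "not_loaded", "tts": "not_loaded"},
+ #  "error": null}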
+ @app.post("/update-languages")
+ async def update_languages(source_lang: str = Form(...), target_lang: str = Form(...)):
+     global stt_processor, stt_model, tts_model, tts_tokenizer, current_tts_language
+     global error_message
+
+     if source_lang not in LANGUAGE_MAPPING or target_lang not in LANGUAGE_MAPPING:
+         raise HTTPException(status_code=400, detail="Invalid language selected")
+
+     source_code = LANGUAGE_MAPPING[source_lang]
+     target_code = LANGUAGE_MAPPING[target_lang]
+
+     # Update the STT model based on the source language (MMS or Whisper)
+     try:
+         logger.info("Updating STT model for source language...")
+         from transformers import AutoProcessor, AutoModelForCTC, WhisperProcessor, WhisperForConditionalGeneration
+         device = "cuda" if torch.cuda.is_available() else "cpu"
+
+         try:
+             logger.info(f"Loading MMS STT model for {source_code}...")
+             stt_processor = AutoProcessor.from_pretrained("facebook/mms-1b-all")
+             stt_model = AutoModelForCTC.from_pretrained("facebook/mms-1b-all")
+             stt_model.to(device)
+             # Set the target language for MMS
+             if source_code in stt_processor.tokenizer.vocab.keys():
+                 stt_processor.tokenizer.set_target_lang(source_code)
+                 stt_model.load_adapter(source_code)
+                 logger.info(f"MMS STT model updated to {source_code}")
+                 model_status["stt"] = "loaded_mms"
+             else:
+                 logger.warning(f"Language {source_code} not supported by MMS, using default")
+                 model_status["stt"] = "loaded_mms_default"
+         except Exception as mms_error:
+             logger.error(f"Failed to load MMS STT model for {source_code}: {str(mms_error)}")
+             logger.info("Falling back to Whisper STT model...")
+             try:
+                 stt_processor = WhisperProcessor.from_pretrained("openai/whisper-tiny")
+                 stt_model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")
+                 stt_model.to(device)
+                 logger.info("Whisper STT model loaded successfully as fallback")
+                 model_status["stt"] = "loaded_whisper"
+             except Exception as whisper_error:
+                 logger.error(f"Failed to load Whisper STT model: {str(whisper_error)}")
+                 model_status["stt"] = "failed"
+                 error_message = f"STT model update failed: MMS error: {str(mms_error)}, Whisper error: {str(whisper_error)}"
+                 return {"status": "failed", "error": error_message}
+     except Exception as e:
+         logger.error(f"Error updating STT model: {str(e)}")
+         model_status["stt"] = "failed"
+         error_message = f"STT model update failed: {str(e)}"
+         return {"status": "failed", "error": error_message}
+
+     # Update the TTS model based on the target language
+     try:
+         logger.info(f"Loading MMS-TTS model for {target_code}...")
+         from transformers import VitsModel, AutoTokenizer
+         tts_model = VitsModel.from_pretrained(f"facebook/mms-tts-{target_code}")
+         tts_tokenizer = AutoTokenizer.from_pretrained(f"facebook/mms-tts-{target_code}")
+         tts_model.to(device)
+         current_tts_language = target_code
+         logger.info(f"TTS model updated to {target_code}")
+         model_status["tts"] = "loaded"
+     except Exception as e:
+         logger.error(f"Failed to load TTS model for {target_code}: {str(e)}")
+         try:
+             logger.info("Falling back to MMS-TTS English model...")
+             tts_model = VitsModel.from_pretrained("facebook/mms-tts-eng")
+             tts_tokenizer = AutoTokenizer.from_pretrained("facebook/mms-tts-eng")
+             tts_model.to(device)
+             current_tts_language = "eng"
+             logger.info("Fallback TTS model loaded successfully")
+             model_status["tts"] = "loaded (fallback)"
+         except Exception as e2:
+             logger.error(f"Failed to load fallback TTS model: {str(e2)}")
+             model_status["tts"] = "failed"
+             error_message = f"TTS model loading failed: {str(e)} (fallback also failed: {str(e2)})"
+             return {"status": "failed", "error": error_message}
+
+     logger.info(f"Updating languages: {source_lang} → {target_lang}")
+     return {"status": f"Languages updated to {source_lang} → {target_lang}"}
+
+ @app.post("/translate-text")
+ async def translate_text(text: str = Form(...), source_lang: str = Form(...), target_lang: str = Form(...)):
+     """Endpoint to translate text and convert to speech"""
+     global mt_model, mt_tokenizer, tts_model, tts_tokenizer, current_tts_language
+
+     if not text:
+         raise HTTPException(status_code=400, detail="No text provided")
+     if source_lang not in LANGUAGE_MAPPING or target_lang not in LANGUAGE_MAPPING:
+         raise HTTPException(status_code=400, detail="Invalid language selected")
+
+     logger.info(f"Translate-text requested: {text} from {source_lang} to {target_lang}")
+     request_id = str(uuid.uuid4())
+     # Pick the device once so the TTS steps work even when translation is skipped
+     device = "cuda" if torch.cuda.is_available() else "cpu"
+
+     # Translate the text
+     source_code = LANGUAGE_MAPPING[source_lang]
+     target_code = LANGUAGE_MAPPING[target_lang]
+     translated_text = "Translation not available"
+
+     if model_status["mt"] == "loaded" and mt_model is not None and mt_tokenizer is not None:
+         try:
+             source_nllb_code = NLLB_LANGUAGE_CODES[source_code]
+             target_nllb_code = NLLB_LANGUAGE_CODES[target_code]
+             mt_tokenizer.src_lang = source_nllb_code
+             inputs = mt_tokenizer(text, return_tensors="pt").to(device)
+             with torch.no_grad():
+                 generated_tokens = mt_model.generate(
+                     **inputs,
+                     forced_bos_token_id=mt_tokenizer.convert_tokens_to_ids(target_nllb_code),
+                     max_length=448
+                 )
+             translated_text = mt_tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0]
+             logger.info(f"Translation completed: {translated_text}")
+         except Exception as e:
+             logger.error(f"Error during translation: {str(e)}")
+             translated_text = f"Translation failed: {str(e)}"
+     else:
+         logger.warning("MT model not loaded, skipping translation")
+
+     # Update TTS model if the target language doesn't match the current TTS language
+     if current_tts_language != target_code:
+         try:
+             logger.info(f"Updating TTS model for {target_code}...")
+             from transformers import VitsModel, AutoTokenizer
+             tts_model = VitsModel.from_pretrained(f"facebook/mms-tts-{target_code}")
+             tts_tokenizer = AutoTokenizer.from_pretrained(f"facebook/mms-tts-{target_code}")
+             tts_model.to(device)
+             current_tts_language = target_code
+             logger.info(f"TTS model updated to {target_code}")
+             model_status["tts"] = "loaded"
+         except Exception as e:
+             logger.error(f"Failed to load TTS model for {target_code}: {str(e)}")
+             try:
+                 logger.info("Falling back to MMS-TTS English model...")
+                 tts_model = VitsModel.from_pretrained("facebook/mms-tts-eng")
+                 tts_tokenizer = AutoTokenizer.from_pretrained("facebook/mms-tts-eng")
+                 tts_model.to(device)
+                 current_tts_language = "eng"
+                 logger.info("Fallback TTS model loaded successfully")
+                 model_status["tts"] = "loaded (fallback)"
+             except Exception as e2:
+                 logger.error(f"Failed to load fallback TTS model: {str(e2)}")
+                 model_status["tts"] = "failed"
+
+     # Convert translated text to speech
+     output_audio_url = None
+     if model_status["tts"].startswith("loaded") and tts_model is not None and tts_tokenizer is not None:
+         try:
+             inputs = tts_tokenizer(translated_text, return_tensors="pt").to(device)
+             with torch.no_grad():
+                 output = tts_model(**inputs)
+             speech = output.waveform.cpu().numpy().squeeze()
+             speech = (speech * 32767).astype(np.int16)
+             sample_rate = tts_model.config.sampling_rate
+
+             # Save the audio as a WAV file
+             output_filename = f"{request_id}.wav"
+             output_path = os.path.join(AUDIO_DIR, output_filename)
+             save_pcm_to_wav(speech.tolist(), sample_rate, output_path)
+             logger.info(f"Saved synthesized audio to {output_path}")
+
+             # Generate a URL to the WAV file
+             output_audio_url = f"https://jerich-talklasapp.hf.space/audio_output/{output_filename}"
+             logger.info("TTS conversion completed")
+         except Exception as e:
+             logger.error(f"Error during TTS conversion: {str(e)}")
+             output_audio_url = None
+
+     return {
+         "request_id": request_id,
+         "status": "completed",
+         "message": "Translation and TTS completed (or partially completed).",
+         "source_text": text,
+         "translated_text": translated_text,
+         "output_audio": output_audio_url
+     }
+
+ @app.post("/translate-audio")
+ async def translate_audio(audio: UploadFile = File(...), source_lang: str = Form(...), target_lang: str = Form(...)):
+     """Endpoint to transcribe, translate, and convert audio to speech"""
+     global stt_processor, stt_model, mt_model, mt_tokenizer, tts_model, tts_tokenizer, current_tts_language
+
+     if not audio:
+         raise HTTPException(status_code=400, detail="No audio file provided")
+     if source_lang not in LANGUAGE_MAPPING or target_lang not in LANGUAGE_MAPPING:
+         raise HTTPException(status_code=400, detail="Invalid language selected")
+
+     logger.info(f"Translate-audio requested: {audio.filename} from {source_lang} to {target_lang}")
+     request_id = str(uuid.uuid4())
+
+     # Check if STT model is loaded
+     if model_status["stt"] not in ["loaded_mms", "loaded_mms_default", "loaded_whisper"] or stt_processor is None or stt_model is None:
+         logger.warning("STT model not loaded, returning placeholder response")
+         return {
+             "request_id": request_id,
+             "status": "processing",
+             "message": "STT model not loaded yet. Please try again later.",
+             "source_text": "Transcription not available",
+             "translated_text": "Translation not available",
+             "output_audio": None
+         }
+
+     # Save the uploaded audio to a temporary file
+     with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as temp_file:
+         temp_file.write(await audio.read())
+         temp_path = temp_file.name
+
+     transcription = "Transcription not available"
+     translated_text = "Translation not available"
+     output_audio_url = None
+
+     try:
+         # Step 1: Load and resample the audio using torchaudio
+         logger.info(f"Reading audio file: {temp_path}")
+         waveform, sample_rate = torchaudio.load(temp_path)
+         logger.info(f"Audio loaded: sample_rate={sample_rate}, waveform_shape={waveform.shape}")
+
+         # Resample to 16 kHz if needed (required by Whisper and MMS models)
+         if sample_rate != 16000:
+             logger.info(f"Resampling audio from {sample_rate} Hz to 16000 Hz")
+             resampler = torchaudio.transforms.Resample(sample_rate, 16000)
+             waveform = resampler(waveform)
+             sample_rate = 16000
+
+         # Downmix multi-channel audio to mono, since the STT models expect a single channel
+         if waveform.shape[0] > 1:
+             waveform = waveform.mean(dim=0, keepdim=True)
+
+         # Step 2: Detect speech
+         if not detect_speech(waveform, sample_rate):
+             return {
+                 "request_id": request_id,
+                 "status": "failed",
+                 "message": "No speech detected in the audio.",
+                 "source_text": "No speech detected",
+                 "translated_text": "No translation available",
+                 "output_audio": None
+             }
+
+         # Step 3: Transcribe the audio (STT)
+         device = "cuda" if torch.cuda.is_available() else "cpu"
+         logger.info(f"Using device: {device}")
+         inputs = stt_processor(waveform.squeeze(0).numpy(), sampling_rate=16000, return_tensors="pt").to(device)
+         logger.info("Audio processed, generating transcription...")
+
+         with torch.no_grad():
+             if model_status["stt"] == "loaded_whisper":
+                 # Whisper model (note: the transcription language is hardcoded to
+                 # English here, regardless of the selected source language)
+                 generated_ids = stt_model.generate(**inputs, language="en")
+                 transcription = stt_processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
+             else:
+                 # MMS model: CTC decoding via argmax over the logits
+                 logits = stt_model(**inputs).logits
+                 predicted_ids = torch.argmax(logits, dim=-1)
+                 transcription = stt_processor.batch_decode(predicted_ids)[0]
+         logger.info(f"Transcription completed: {transcription}")
+
+         # Step 4: Translate the transcribed text (MT)
+         source_code = LANGUAGE_MAPPING[source_lang]
+         target_code = LANGUAGE_MAPPING[target_lang]
+
+         if model_status["mt"] == "loaded" and mt_model is not None and mt_tokenizer is not None:
+             try:
+                 source_nllb_code = NLLB_LANGUAGE_CODES[source_code]
+                 target_nllb_code = NLLB_LANGUAGE_CODES[target_code]
+                 mt_tokenizer.src_lang = source_nllb_code
+                 inputs = mt_tokenizer(transcription, return_tensors="pt").to(device)
+                 with torch.no_grad():
+                     generated_tokens = mt_model.generate(
+                         **inputs,
+                         forced_bos_token_id=mt_tokenizer.convert_tokens_to_ids(target_nllb_code),
+                         max_length=448
+                     )
+                 translated_text = mt_tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0]
+                 logger.info(f"Translation completed: {translated_text}")
+             except Exception as e:
+                 logger.error(f"Error during translation: {str(e)}")
+                 translated_text = f"Translation failed: {str(e)}"
+         else:
+             logger.warning("MT model not loaded, skipping translation")
+
+         # Step 5: Update TTS model if the target language doesn't match the current TTS language
+         if current_tts_language != target_code:
+             try:
+                 logger.info(f"Updating TTS model for {target_code}...")
+                 from transformers import VitsModel, AutoTokenizer
+                 tts_model = VitsModel.from_pretrained(f"facebook/mms-tts-{target_code}")
+                 tts_tokenizer = AutoTokenizer.from_pretrained(f"facebook/mms-tts-{target_code}")
+                 tts_model.to(device)
+                 current_tts_language = target_code
+                 logger.info(f"TTS model updated to {target_code}")
+                 model_status["tts"] = "loaded"
+             except Exception as e:
+                 logger.error(f"Failed to load TTS model for {target_code}: {str(e)}")
+                 try:
+                     logger.info("Falling back to MMS-TTS English model...")
+                     tts_model = VitsModel.from_pretrained("facebook/mms-tts-eng")
+                     tts_tokenizer = AutoTokenizer.from_pretrained("facebook/mms-tts-eng")
+                     tts_model.to(device)
+                     current_tts_language = "eng"
+                     logger.info("Fallback TTS model loaded successfully")
+                     model_status["tts"] = "loaded (fallback)"
+                 except Exception as e2:
+                     logger.error(f"Failed to load fallback TTS model: {str(e2)}")
+                     model_status["tts"] = "failed"
+
+         # Step 6: Convert translated text to speech (TTS)
+         if model_status["tts"].startswith("loaded") and tts_model is not None and tts_tokenizer is not None:
+             try:
+                 inputs = tts_tokenizer(translated_text, return_tensors="pt").to(device)
+                 with torch.no_grad():
+                     output = tts_model(**inputs)
+                 speech = output.waveform.cpu().numpy().squeeze()
+                 speech = (speech * 32767).astype(np.int16)
+                 sample_rate = tts_model.config.sampling_rate
+
+                 # Save the audio as a WAV file
+                 output_filename = f"{request_id}.wav"
+                 output_path = os.path.join(AUDIO_DIR, output_filename)
+                 save_pcm_to_wav(speech.tolist(), sample_rate, output_path)
+                 logger.info(f"Saved synthesized audio to {output_path}")
+
+                 # Generate a URL to the WAV file
+                 output_audio_url = f"https://jerich-talklasapp.hf.space/audio_output/{output_filename}"
+                 logger.info("TTS conversion completed")
+             except Exception as e:
+                 logger.error(f"Error during TTS conversion: {str(e)}")
+                 output_audio_url = None
+
+         return {
+             "request_id": request_id,
+             "status": "completed",
+             "message": "Transcription, translation, and TTS completed (or partially completed).",
+             "source_text": transcription,
+             "translated_text": translated_text,
+             "output_audio": output_audio_url
+         }
+     except Exception as e:
+         logger.error(f"Error during processing: {str(e)}")
+         return {
+             "request_id": request_id,
+             "status": "failed",
+             "message": f"Processing failed: {str(e)}",
+             "source_text": transcription,
+             "translated_text": translated_text,
+             "output_audio": output_audio_url
+         }
+     finally:
+         logger.info(f"Cleaning up temporary file: {temp_path}")
+         os.unlink(temp_path)
+
+ if __name__ == "__main__":
+     import uvicorn
+     logger.info("Starting Uvicorn server...")
+     uvicorn.run(app, host="0.0.0.0", port=8000, workers=1)
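+
+ # Illustrative client calls (assuming a local run on port 8000; the deployed
+ # Space would use the https://jerich-talklasapp.hf.space base URL instead):
+ #   curl -X POST -F "source_lang=English" -F "target_lang=Tagalog" \
+ #        http://localhost:8000/update-languages
+ #   curl -X POST -F "text=Good morning" -F "source_lang=English" \
+ #        -F "target_lang=Cebuano" http://localhost:8000/translate-text
+ #   curl -X POST -F "audio=@sample.wav" -F "source_lang=Tagalog" \
+ #        -F "target_lang=English" http://localhost:8000/translate-audio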