Jerich committed
Commit f56283f · verified · 1 Parent(s): f8dca01

Add MT with NLLB-200-distilled-600M and TTS with MMS-TTS

Files changed (1)
  1. app.py +178 -42
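
Once the models finish loading, the new MT and TTS paths can be exercised over plain HTTP. A minimal client-side sketch, under assumptions that are not part of this commit (the API is reachable at http://localhost:8000; the requests, numpy, and soundfile packages are available on the client):

    import requests
    import numpy as np
    import soundfile as sf

    BASE_URL = "http://localhost:8000"  # assumed local deployment, not defined by this commit

    # Call the /translate-text endpoint added in this commit
    resp = requests.post(
        f"{BASE_URL}/translate-text",
        data={"text": "Kumusta ka?", "source_lang": "Tagalog", "target_lang": "English"},
    )
    resp.raise_for_status()
    payload = resp.json()
    print(payload["translated_text"])

    # output_audio is returned as (sampling_rate, samples), with int16 samples serialized as a list
    if payload.get("output_audio"):
        sampling_rate, samples = payload["output_audio"]
        sf.write("translated.wav", np.array(samples, dtype=np.int16), sampling_rate)

The form fields take the human-readable keys of LANGUAGE_MAPPING (e.g. "English", "Tagalog"); the server maps them to NLLB codes such as eng_Latn via NLLB_LANGUAGE_CODES.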
app.py CHANGED
@@ -12,7 +12,7 @@ import soundfile as sf
 import librosa
 from fastapi import FastAPI, HTTPException, UploadFile, File, Form
 from fastapi.responses import JSONResponse
-from typing import Dict, Any, Optional
+from typing import Dict, Any, Optional, Tuple
 
 # Configure logging
 logging.basicConfig(level=logging.INFO)
@@ -31,11 +31,15 @@ model_status = {
 }
 error_message = None
 
-# STT model and processor (will be loaded in background)
+# Model instances
 stt_processor = None
 stt_model = None
+mt_model = None
+mt_tokenizer = None
+tts_model = None
+tts_tokenizer = None
 
-# Define the valid languages
+# Define the valid languages and mappings
 LANGUAGE_MAPPING = {
     "English": "eng",
     "Tagalog": "tgl",
@@ -45,9 +49,19 @@ LANGUAGE_MAPPING = {
     "Pangasinan": "pag"
 }
 
+NLLB_LANGUAGE_CODES = {
+    "eng": "eng_Latn",
+    "tgl": "tgl_Latn",
+    "ceb": "ceb_Latn",
+    "ilo": "ilo_Latn",
+    "war": "war_Latn",
+    "pag": "pag_Latn"
+}
+
 # Function to load models in background
 def load_models_task():
-    global models_loaded, loading_in_progress, model_status, error_message, stt_processor, stt_model
+    global models_loaded, loading_in_progress, model_status, error_message
+    global stt_processor, stt_model, mt_model, mt_tokenizer, tts_model, tts_tokenizer
 
     try:
         loading_in_progress = True
@@ -71,10 +85,51 @@ def load_models_task():
             error_message = f"STT model loading failed: {str(e)}"
             return
 
-        # Skip MT and TTS models for now to save memory
-        model_status["mt"] = "skipped"
-        model_status["tts"] = "skipped"
-        logger.info("MT and TTS models skipped to save memory")
+        # Load MT model
+        logger.info("Starting to load MT model...")
+        from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
+
+        try:
+            logger.info("Loading NLLB-200-distilled-600M model...")
+            model_status["mt"] = "loading"
+            mt_model = AutoModelForSeq2SeqLM.from_pretrained("facebook/nllb-200-distilled-600M")
+            mt_tokenizer = AutoTokenizer.from_pretrained("facebook/nllb-200-distilled-600M", clean_up_tokenization_spaces=True)
+            mt_model.to(device)
+            logger.info("MT model loaded successfully")
+            model_status["mt"] = "loaded"
+        except Exception as e:
+            logger.error(f"Failed to load MT model: {str(e)}")
+            model_status["mt"] = "failed"
+            error_message = f"MT model loading failed: {str(e)}"
+            return
+
+        # Load TTS model (default to Tagalog, will be updated by /update-languages)
+        logger.info("Starting to load TTS model...")
+        from transformers import VitsModel, AutoTokenizer
+
+        try:
+            logger.info("Loading MMS-TTS model for Tagalog...")
+            model_status["tts"] = "loading"
+            tts_model = VitsModel.from_pretrained("facebook/mms-tts-tgl")
+            tts_tokenizer = AutoTokenizer.from_pretrained("facebook/mms-tts-tgl", clean_up_tokenization_spaces=True)
+            tts_model.to(device)
+            logger.info("TTS model loaded successfully")
+            model_status["tts"] = "loaded"
+        except Exception as e:
+            logger.error(f"Failed to load TTS model: {str(e)}")
+            # Fallback to English TTS if the target language fails
+            try:
+                logger.info("Falling back to MMS-TTS English model...")
+                tts_model = VitsModel.from_pretrained("facebook/mms-tts-eng")
+                tts_tokenizer = AutoTokenizer.from_pretrained("facebook/mms-tts-eng", clean_up_tokenization_spaces=True)
+                tts_model.to(device)
+                logger.info("Fallback TTS model loaded successfully")
+                model_status["tts"] = "loaded (fallback)"
+            except Exception as e2:
+                logger.error(f"Failed to load fallback TTS model: {str(e2)}")
+                model_status["tts"] = "failed"
+                error_message = f"TTS model loading failed: {str(e)} (fallback also failed: {str(e2)})"
+                return
 
         models_loaded = True
         logger.info("Model loading completed successfully")
@@ -121,14 +176,46 @@ async def health_check():
 
 @app.post("/update-languages")
 async def update_languages(source_lang: str = Form(...), target_lang: str = Form(...)):
+    global tts_model, tts_tokenizer
+
     if source_lang not in LANGUAGE_MAPPING or target_lang not in LANGUAGE_MAPPING:
         raise HTTPException(status_code=400, detail="Invalid language selected")
+
+    source_code = LANGUAGE_MAPPING[source_lang]
+    target_code = LANGUAGE_MAPPING[target_lang]
+
+    # Update the TTS model based on the target language
+    try:
+        logger.info(f"Loading MMS-TTS model for {target_code}...")
+        from transformers import VitsModel, AutoTokenizer
+        tts_model = VitsModel.from_pretrained(f"facebook/mms-tts-{target_code}")
+        tts_tokenizer = AutoTokenizer.from_pretrained(f"facebook/mms-tts-{target_code}", clean_up_tokenization_spaces=True)
+        device = "cuda" if torch.cuda.is_available() else "cpu"
+        tts_model.to(device)
+        logger.info(f"TTS model updated to {target_code}")
+        model_status["tts"] = "loaded"
+    except Exception as e:
+        logger.error(f"Failed to load TTS model for {target_code}: {str(e)}")
+        try:
+            logger.info("Falling back to MMS-TTS English model...")
+            tts_model = VitsModel.from_pretrained("facebook/mms-tts-eng")
+            tts_tokenizer = AutoTokenizer.from_pretrained("facebook/mms-tts-eng", clean_up_tokenization_spaces=True)
+            tts_model.to(device)
+            logger.info("Fallback TTS model loaded successfully")
+            model_status["tts"] = "loaded (fallback)"
+        except Exception as e2:
+            logger.error(f"Failed to load fallback TTS model: {str(e2)}")
+            model_status["tts"] = "failed"
+            error_message = f"TTS model loading failed: {str(e)} (fallback also failed: {str(e2)})"
+
     logger.info(f"Updating languages: {source_lang} → {target_lang}")
     return {"status": f"Languages updated to {source_lang} → {target_lang}"}
 
 @app.post("/translate-text")
 async def translate_text(text: str = Form(...), source_lang: str = Form(...), target_lang: str = Form(...)):
-    """Endpoint that creates a placeholder for text translation"""
+    """Endpoint to translate text and convert to speech"""
+    global mt_model, mt_tokenizer, tts_model, tts_tokenizer
+
     if not text:
         raise HTTPException(status_code=400, detail="No text provided")
     if source_lang not in LANGUAGE_MAPPING or target_lang not in LANGUAGE_MAPPING:
@@ -136,19 +223,61 @@ async def translate_text(text: str = Form(...), source_lang: str = Form(...), ta
 
     logger.info(f"Translate-text requested: {text} from {source_lang} to {target_lang}")
    request_id = str(uuid.uuid4())
+
+    # Translate the text
+    source_code = LANGUAGE_MAPPING[source_lang]
+    target_code = LANGUAGE_MAPPING[target_lang]
+    translated_text = "Translation not available"
+
+    if model_status["mt"] == "loaded" and mt_model is not None and mt_tokenizer is not None:
+        try:
+            source_nllb_code = NLLB_LANGUAGE_CODES[source_code]
+            target_nllb_code = NLLB_LANGUAGE_CODES[target_code]
+            mt_tokenizer.src_lang = source_nllb_code
+            device = "cuda" if torch.cuda.is_available() else "cpu"
+            inputs = mt_tokenizer(text, return_tensors="pt", clean_up_tokenization_spaces=True).to(device)
+            with torch.no_grad():
+                generated_tokens = mt_model.generate(
+                    **inputs,
+                    forced_bos_token_id=mt_tokenizer.convert_tokens_to_ids(target_nllb_code),
+                    max_length=448
+                )
+            translated_text = mt_tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0]
+            logger.info(f"Translation completed: {translated_text}")
+        except Exception as e:
+            logger.error(f"Error during translation: {str(e)}")
+            translated_text = f"Translation failed: {str(e)}"
+    else:
+        logger.warning("MT model not loaded, skipping translation")
+
+    # Convert translated text to speech
+    output_audio = None
+    if model_status["tts"].startswith("loaded") and tts_model is not None and tts_tokenizer is not None:
+        try:
+            inputs = tts_tokenizer(translated_text, return_tensors="pt", clean_up_tokenization_spaces=True).to(device)
+            with torch.no_grad():
+                output = tts_model(**inputs)
+            speech = output.waveform.cpu().numpy().squeeze()
+            speech = (speech * 32767).astype(np.int16)
+            output_audio = (tts_model.config.sampling_rate, speech.tolist())
+            logger.info("TTS conversion completed")
+        except Exception as e:
+            logger.error(f"Error during TTS conversion: {str(e)}")
+            output_audio = None
+
     return {
         "request_id": request_id,
-        "status": "processing",
-        "message": "Translation not implemented yet (MT model not loaded).",
+        "status": "completed",
+        "message": "Translation and TTS completed (or partially completed).",
         "source_text": text,
-        "translated_text": "Translation not available",
-        "output_audio": None
+        "translated_text": translated_text,
+        "output_audio": output_audio
     }
 
 @app.post("/translate-audio")
 async def translate_audio(audio: UploadFile = File(...), source_lang: str = Form(...), target_lang: str = Form(...)):
-    """Endpoint to transcribe audio using STT"""
-    global stt_processor, stt_model
+    """Endpoint to transcribe, translate, and convert audio to speech"""
+    global stt_processor, stt_model, mt_model, mt_tokenizer, tts_model, tts_tokenizer
 
     if not audio:
         raise HTTPException(status_code=400, detail="No audio file provided")
@@ -184,7 +313,7 @@ async def translate_audio(audio: UploadFile = File(...), source_lang: str = Form
         logger.info(f"Resampling audio from {sample_rate} Hz to 16000 Hz")
         waveform = librosa.resample(waveform, orig_sr=sample_rate, target_sr=16000)
 
-        # Process the audio with Whisper
+        # Process the audio with Whisper (STT)
         device = "cuda" if torch.cuda.is_available() else "cpu"
         logger.info(f"Using device: {device}")
         inputs = stt_processor(waveform, sampling_rate=16000, return_tensors="pt").to(device)
@@ -192,31 +321,38 @@ async def translate_audio(audio: UploadFile = File(...), source_lang: str = Form
         with torch.no_grad():
             generated_ids = stt_model.generate(**inputs)
         transcription = stt_processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
-
         logger.info(f"Transcription completed: {transcription}")
-        return {
-            "request_id": request_id,
-            "status": "completed",
-            "message": "Transcription completed successfully. Translation and TTS not implemented yet.",
-            "source_text": transcription,
-            "translated_text": "Translation not available",
-            "output_audio": None
-        }
-    except Exception as e:
-        logger.error(f"Error during transcription: {str(e)}")
-        return {
-            "request_id": request_id,
-            "status": "failed",
-            "message": f"Transcription failed: {str(e)}",
-            "source_text": "Transcription not available",
-            "translated_text": "Translation not available",
-            "output_audio": None
-        }
-    finally:
-        logger.info(f"Cleaning up temporary file: {temp_path}")
-        os.unlink(temp_path)
 
-if __name__ == "__main__":
-    import uvicorn
-    logger.info("Starting Uvicorn server...")
-    uvicorn.run(app, host="0.0.0.0", port=8000, workers=1)
+        # Translate the transcribed text
+        source_code = LANGUAGE_MAPPING[source_lang]
+        target_code = LANGUAGE_MAPPING[target_lang]
+        translated_text = "Translation not available"
+
+        if model_status["mt"] == "loaded" and mt_model is not None and mt_tokenizer is not None:
+            try:
+                source_nllb_code = NLLB_LANGUAGE_CODES[source_code]
+                target_nllb_code = NLLB_LANGUAGE_CODES[target_code]
+                mt_tokenizer.src_lang = source_nllb_code
+                inputs = mt_tokenizer(transcription, return_tensors="pt", clean_up_tokenization_spaces=True).to(device)
+                with torch.no_grad():
+                    generated_tokens = mt_model.generate(
+                        **inputs,
+                        forced_bos_token_id=mt_tokenizer.convert_tokens_to_ids(target_nllb_code),
+                        max_length=448
+                    )
+                translated_text = mt_tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0]
+                logger.info(f"Translation completed: {translated_text}")
+            except Exception as e:
+                logger.error(f"Error during translation: {str(e)}")
+                translated_text = f"Translation failed: {str(e)}"
+        else:
+            logger.warning("MT model not loaded, skipping translation")
+
+        # Convert translated text to speech
+        output_audio = None
+        if model_status["tts"].startswith("loaded") and tts_model is not None and tts_tokenizer is not None:
+            try:
+                inputs = tts_tokenizer(translated_text, return_tensors="pt", clean_up_tokenization_spaces=True).to(device)
+                with torch.no_grad():
+                    output = tts_model(**inputs)
+                speech = output.waveform.cpu().numpy().squeeze()
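
For completeness, the reworked /update-languages endpoint now reloads the MMS-TTS checkpoint (facebook/mms-tts-<target code>) to match the selected target language, falling back to facebook/mms-tts-eng if that checkpoint cannot be loaded. A sketch of calling it, under the same assumptions as the example above:

    import requests

    BASE_URL = "http://localhost:8000"  # assumed local deployment, not defined by this commit

    # Switching the target language makes the server load the matching MMS-TTS model
    resp = requests.post(
        f"{BASE_URL}/update-languages",
        data={"source_lang": "English", "target_lang": "Pangasinan"},
    )
    resp.raise_for_status()
    print(resp.json()["status"])  # "Languages updated to English → Pangasinan"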