Jerich committed on
Commit
e656c37
·
verified ·
1 Parent(s): 989a3f5

Fix MT error by removing clean_up_tokenization_spaces and address STT warnings

Browse files
Files changed (1) hide show
  1. app.py +14 -10
app.py CHANGED
@@ -93,7 +93,7 @@ def load_models_task():
93
  logger.info("Loading NLLB-200-distilled-600M model...")
94
  model_status["mt"] = "loading"
95
  mt_model = AutoModelForSeq2SeqLM.from_pretrained("facebook/nllb-200-distilled-600M")
96
- mt_tokenizer = AutoTokenizer.from_pretrained("facebook/nllb-200-distilled-600M", clean_up_tokenization_spaces=True)
97
  mt_model.to(device)
98
  logger.info("MT model loaded successfully")
99
  model_status["mt"] = "loaded"
@@ -111,7 +111,7 @@ def load_models_task():
111
  logger.info("Loading MMS-TTS model for Tagalog...")
112
  model_status["tts"] = "loading"
113
  tts_model = VitsModel.from_pretrained("facebook/mms-tts-tgl")
114
- tts_tokenizer = AutoTokenizer.from_pretrained("facebook/mms-tts-tgl", clean_up_tokenization_spaces=True)
115
  tts_model.to(device)
116
  logger.info("TTS model loaded successfully")
117
  model_status["tts"] = "loaded"
@@ -121,7 +121,7 @@ def load_models_task():
121
  try:
122
  logger.info("Falling back to MMS-TTS English model...")
123
  tts_model = VitsModel.from_pretrained("facebook/mms-tts-eng")
124
- tts_tokenizer = AutoTokenizer.from_pretrained("facebook/mms-tts-eng", clean_up_tokenization_spaces=True)
125
  tts_model.to(device)
126
  logger.info("Fallback TTS model loaded successfully")
127
  model_status["tts"] = "loaded (fallback)"
@@ -189,7 +189,7 @@ async def update_languages(source_lang: str = Form(...), target_lang: str = Form
189
  logger.info(f"Loading MMS-TTS model for {target_code}...")
190
  from transformers import VitsModel, AutoTokenizer
191
  tts_model = VitsModel.from_pretrained(f"facebook/mms-tts-{target_code}")
192
- tts_tokenizer = AutoTokenizer.from_pretrained(f"facebook/mms-tts-{target_code}", clean_up_tokenization_spaces=True)
193
  device = "cuda" if torch.cuda.is_available() else "cpu"
194
  tts_model.to(device)
195
  logger.info(f"TTS model updated to {target_code}")
@@ -199,7 +199,7 @@ async def update_languages(source_lang: str = Form(...), target_lang: str = Form
199
  try:
200
  logger.info("Falling back to MMS-TTS English model...")
201
  tts_model = VitsModel.from_pretrained("facebook/mms-tts-eng")
202
- tts_tokenizer = AutoTokenizer.from_pretrained("facebook/mms-tts-eng", clean_up_tokenization_spaces=True)
203
  tts_model.to(device)
204
  logger.info("Fallback TTS model loaded successfully")
205
  model_status["tts"] = "loaded (fallback)"
@@ -235,7 +235,7 @@ async def translate_text(text: str = Form(...), source_lang: str = Form(...), ta
235
  target_nllb_code = NLLB_LANGUAGE_CODES[target_code]
236
  mt_tokenizer.src_lang = source_nllb_code
237
  device = "cuda" if torch.cuda.is_available() else "cpu"
238
- inputs = mt_tokenizer(text, return_tensors="pt", clean_up_tokenization_spaces=True).to(device)
239
  with torch.no_grad():
240
  generated_tokens = mt_model.generate(
241
  **inputs,
@@ -254,7 +254,7 @@ async def translate_text(text: str = Form(...), source_lang: str = Form(...), ta
254
  output_audio = None
255
  if model_status["tts"].startswith("loaded") and tts_model is not None and tts_tokenizer is not None:
256
  try:
257
- inputs = tts_tokenizer(translated_text, return_tensors="pt", clean_up_tokenization_spaces=True).to(device)
258
  with torch.no_grad():
259
  output = tts_model(**inputs)
260
  speech = output.waveform.cpu().numpy().squeeze()
@@ -322,7 +322,11 @@ async def translate_audio(audio: UploadFile = File(...), source_lang: str = Form
322
  inputs = stt_processor(waveform, sampling_rate=16000, return_tensors="pt").to(device)
323
  logger.info("Audio processed, generating transcription...")
324
  with torch.no_grad():
325
- generated_ids = stt_model.generate(**inputs)
 
 
 
 
326
  transcription = stt_processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
327
  logger.info(f"Transcription completed: {transcription}")
328
 
@@ -335,7 +339,7 @@ async def translate_audio(audio: UploadFile = File(...), source_lang: str = Form
335
  source_nllb_code = NLLB_LANGUAGE_CODES[source_code]
336
  target_nllb_code = NLLB_LANGUAGE_CODES[target_code]
337
  mt_tokenizer.src_lang = source_nllb_code
338
- inputs = mt_tokenizer(transcription, return_tensors="pt", clean_up_tokenization_spaces=True).to(device)
339
  with torch.no_grad():
340
  generated_tokens = mt_model.generate(
341
  **inputs,
@@ -353,7 +357,7 @@ async def translate_audio(audio: UploadFile = File(...), source_lang: str = Form
353
  # Step 3: Convert translated text to speech (TTS)
354
  if model_status["tts"].startswith("loaded") and tts_model is not None and tts_tokenizer is not None:
355
  try:
356
- inputs = tts_tokenizer(translated_text, return_tensors="pt", clean_up_tokenization_spaces=True).to(device)
357
  with torch.no_grad():
358
  output = tts_model(**inputs)
359
  speech = output.waveform.cpu().numpy().squeeze()
 
93
  logger.info("Loading NLLB-200-distilled-600M model...")
94
  model_status["mt"] = "loading"
95
  mt_model = AutoModelForSeq2SeqLM.from_pretrained("facebook/nllb-200-distilled-600M")
96
+ mt_tokenizer = AutoTokenizer.from_pretrained("facebook/nllb-200-distilled-600M")
97
  mt_model.to(device)
98
  logger.info("MT model loaded successfully")
99
  model_status["mt"] = "loaded"
 
111
  logger.info("Loading MMS-TTS model for Tagalog...")
112
  model_status["tts"] = "loading"
113
  tts_model = VitsModel.from_pretrained("facebook/mms-tts-tgl")
114
+ tts_tokenizer = AutoTokenizer.from_pretrained("facebook/mms-tts-tgl")
115
  tts_model.to(device)
116
  logger.info("TTS model loaded successfully")
117
  model_status["tts"] = "loaded"
 
121
  try:
122
  logger.info("Falling back to MMS-TTS English model...")
123
  tts_model = VitsModel.from_pretrained("facebook/mms-tts-eng")
124
+ tts_tokenizer = AutoTokenizer.from_pretrained("facebook/mms-tts-eng")
125
  tts_model.to(device)
126
  logger.info("Fallback TTS model loaded successfully")
127
  model_status["tts"] = "loaded (fallback)"
 
189
  logger.info(f"Loading MMS-TTS model for {target_code}...")
190
  from transformers import VitsModel, AutoTokenizer
191
  tts_model = VitsModel.from_pretrained(f"facebook/mms-tts-{target_code}")
192
+ tts_tokenizer = AutoTokenizer.from_pretrained(f"facebook/mms-tts-{target_code}")
193
  device = "cuda" if torch.cuda.is_available() else "cpu"
194
  tts_model.to(device)
195
  logger.info(f"TTS model updated to {target_code}")
 
199
  try:
200
  logger.info("Falling back to MMS-TTS English model...")
201
  tts_model = VitsModel.from_pretrained("facebook/mms-tts-eng")
202
+ tts_tokenizer = AutoTokenizer.from_pretrained("facebook/mms-tts-eng")
203
  tts_model.to(device)
204
  logger.info("Fallback TTS model loaded successfully")
205
  model_status["tts"] = "loaded (fallback)"
 
235
  target_nllb_code = NLLB_LANGUAGE_CODES[target_code]
236
  mt_tokenizer.src_lang = source_nllb_code
237
  device = "cuda" if torch.cuda.is_available() else "cpu"
238
+ inputs = mt_tokenizer(text, return_tensors="pt").to(device)
239
  with torch.no_grad():
240
  generated_tokens = mt_model.generate(
241
  **inputs,
 
254
  output_audio = None
255
  if model_status["tts"].startswith("loaded") and tts_model is not None and tts_tokenizer is not None:
256
  try:
257
+ inputs = tts_tokenizer(translated_text, return_tensors="pt").to(device)
258
  with torch.no_grad():
259
  output = tts_model(**inputs)
260
  speech = output.waveform.cpu().numpy().squeeze()
 
322
  inputs = stt_processor(waveform, sampling_rate=16000, return_tensors="pt").to(device)
323
  logger.info("Audio processed, generating transcription...")
324
  with torch.no_grad():
325
+ generated_ids = stt_model.generate(
326
+ **inputs,
327
+ language="en", # Explicitly set language to English
328
+ return_attention_mask=True # Generate attention mask
329
+ )
330
  transcription = stt_processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
331
  logger.info(f"Transcription completed: {transcription}")
332
 
 
339
  source_nllb_code = NLLB_LANGUAGE_CODES[source_code]
340
  target_nllb_code = NLLB_LANGUAGE_CODES[target_code]
341
  mt_tokenizer.src_lang = source_nllb_code
342
+ inputs = mt_tokenizer(transcription, return_tensors="pt").to(device)
343
  with torch.no_grad():
344
  generated_tokens = mt_model.generate(
345
  **inputs,
 
357
  # Step 3: Convert translated text to speech (TTS)
358
  if model_status["tts"].startswith("loaded") and tts_model is not None and tts_tokenizer is not None:
359
  try:
360
+ inputs = tts_tokenizer(translated_text, return_tensors="pt").to(device)
361
  with torch.no_grad():
362
  output = tts_model(**inputs)
363
  speech = output.waveform.cpu().numpy().squeeze()