Spaces:
Paused
Paused
Fix MT error by removing clean_up_tokenization_spaces and address STT warnings
Browse files
app.py
CHANGED
@@ -93,7 +93,7 @@ def load_models_task():
|
|
93 |
logger.info("Loading NLLB-200-distilled-600M model...")
|
94 |
model_status["mt"] = "loading"
|
95 |
mt_model = AutoModelForSeq2SeqLM.from_pretrained("facebook/nllb-200-distilled-600M")
|
96 |
-
mt_tokenizer = AutoTokenizer.from_pretrained("facebook/nllb-200-distilled-600M", clean_up_tokenization_spaces=True)
|
97 |
mt_model.to(device)
|
98 |
logger.info("MT model loaded successfully")
|
99 |
model_status["mt"] = "loaded"
|
@@ -111,7 +111,7 @@ def load_models_task():
|
|
111 |
logger.info("Loading MMS-TTS model for Tagalog...")
|
112 |
model_status["tts"] = "loading"
|
113 |
tts_model = VitsModel.from_pretrained("facebook/mms-tts-tgl")
|
114 |
-
tts_tokenizer = AutoTokenizer.from_pretrained("facebook/mms-tts-tgl", clean_up_tokenization_spaces=True)
|
115 |
tts_model.to(device)
|
116 |
logger.info("TTS model loaded successfully")
|
117 |
model_status["tts"] = "loaded"
|
@@ -121,7 +121,7 @@ def load_models_task():
|
|
121 |
try:
|
122 |
logger.info("Falling back to MMS-TTS English model...")
|
123 |
tts_model = VitsModel.from_pretrained("facebook/mms-tts-eng")
|
124 |
-
tts_tokenizer = AutoTokenizer.from_pretrained("facebook/mms-tts-eng", clean_up_tokenization_spaces=True)
|
125 |
tts_model.to(device)
|
126 |
logger.info("Fallback TTS model loaded successfully")
|
127 |
model_status["tts"] = "loaded (fallback)"
|
@@ -189,7 +189,7 @@ async def update_languages(source_lang: str = Form(...), target_lang: str = Form
|
|
189 |
logger.info(f"Loading MMS-TTS model for {target_code}...")
|
190 |
from transformers import VitsModel, AutoTokenizer
|
191 |
tts_model = VitsModel.from_pretrained(f"facebook/mms-tts-{target_code}")
|
192 |
-
tts_tokenizer = AutoTokenizer.from_pretrained(f"facebook/mms-tts-{target_code}", clean_up_tokenization_spaces=True)
|
193 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
194 |
tts_model.to(device)
|
195 |
logger.info(f"TTS model updated to {target_code}")
|
@@ -199,7 +199,7 @@ async def update_languages(source_lang: str = Form(...), target_lang: str = Form
|
|
199 |
try:
|
200 |
logger.info("Falling back to MMS-TTS English model...")
|
201 |
tts_model = VitsModel.from_pretrained("facebook/mms-tts-eng")
|
202 |
-
tts_tokenizer = AutoTokenizer.from_pretrained("facebook/mms-tts-eng", clean_up_tokenization_spaces=True)
|
203 |
tts_model.to(device)
|
204 |
logger.info("Fallback TTS model loaded successfully")
|
205 |
model_status["tts"] = "loaded (fallback)"
|
@@ -235,7 +235,7 @@ async def translate_text(text: str = Form(...), source_lang: str = Form(...), ta
|
|
235 |
target_nllb_code = NLLB_LANGUAGE_CODES[target_code]
|
236 |
mt_tokenizer.src_lang = source_nllb_code
|
237 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
238 |
-
inputs = mt_tokenizer(text, return_tensors="pt", clean_up_tokenization_spaces=True).to(device)
|
239 |
with torch.no_grad():
|
240 |
generated_tokens = mt_model.generate(
|
241 |
**inputs,
|
@@ -254,7 +254,7 @@ async def translate_text(text: str = Form(...), source_lang: str = Form(...), ta
|
|
254 |
output_audio = None
|
255 |
if model_status["tts"].startswith("loaded") and tts_model is not None and tts_tokenizer is not None:
|
256 |
try:
|
257 |
-
inputs = tts_tokenizer(translated_text, return_tensors="pt", clean_up_tokenization_spaces=True).to(device)
|
258 |
with torch.no_grad():
|
259 |
output = tts_model(**inputs)
|
260 |
speech = output.waveform.cpu().numpy().squeeze()
|
@@ -322,7 +322,11 @@ async def translate_audio(audio: UploadFile = File(...), source_lang: str = Form
|
|
322 |
inputs = stt_processor(waveform, sampling_rate=16000, return_tensors="pt").to(device)
|
323 |
logger.info("Audio processed, generating transcription...")
|
324 |
with torch.no_grad():
|
325 |
-
generated_ids = stt_model.generate(**inputs)
|
|
|
|
|
|
|
|
|
326 |
transcription = stt_processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
|
327 |
logger.info(f"Transcription completed: {transcription}")
|
328 |
|
@@ -335,7 +339,7 @@ async def translate_audio(audio: UploadFile = File(...), source_lang: str = Form
|
|
335 |
source_nllb_code = NLLB_LANGUAGE_CODES[source_code]
|
336 |
target_nllb_code = NLLB_LANGUAGE_CODES[target_code]
|
337 |
mt_tokenizer.src_lang = source_nllb_code
|
338 |
-
inputs = mt_tokenizer(transcription, return_tensors="pt", clean_up_tokenization_spaces=True).to(device)
|
339 |
with torch.no_grad():
|
340 |
generated_tokens = mt_model.generate(
|
341 |
**inputs,
|
@@ -353,7 +357,7 @@ async def translate_audio(audio: UploadFile = File(...), source_lang: str = Form
|
|
353 |
# Step 3: Convert translated text to speech (TTS)
|
354 |
if model_status["tts"].startswith("loaded") and tts_model is not None and tts_tokenizer is not None:
|
355 |
try:
|
356 |
-
inputs = tts_tokenizer(translated_text, return_tensors="pt", clean_up_tokenization_spaces=True).to(device)
|
357 |
with torch.no_grad():
|
358 |
output = tts_model(**inputs)
|
359 |
speech = output.waveform.cpu().numpy().squeeze()
|
|
|
93 |
logger.info("Loading NLLB-200-distilled-600M model...")
|
94 |
model_status["mt"] = "loading"
|
95 |
mt_model = AutoModelForSeq2SeqLM.from_pretrained("facebook/nllb-200-distilled-600M")
|
96 |
+
mt_tokenizer = AutoTokenizer.from_pretrained("facebook/nllb-200-distilled-600M")
|
97 |
mt_model.to(device)
|
98 |
logger.info("MT model loaded successfully")
|
99 |
model_status["mt"] = "loaded"
|
|
|
111 |
logger.info("Loading MMS-TTS model for Tagalog...")
|
112 |
model_status["tts"] = "loading"
|
113 |
tts_model = VitsModel.from_pretrained("facebook/mms-tts-tgl")
|
114 |
+
tts_tokenizer = AutoTokenizer.from_pretrained("facebook/mms-tts-tgl")
|
115 |
tts_model.to(device)
|
116 |
logger.info("TTS model loaded successfully")
|
117 |
model_status["tts"] = "loaded"
|
|
|
121 |
try:
|
122 |
logger.info("Falling back to MMS-TTS English model...")
|
123 |
tts_model = VitsModel.from_pretrained("facebook/mms-tts-eng")
|
124 |
+
tts_tokenizer = AutoTokenizer.from_pretrained("facebook/mms-tts-eng")
|
125 |
tts_model.to(device)
|
126 |
logger.info("Fallback TTS model loaded successfully")
|
127 |
model_status["tts"] = "loaded (fallback)"
|
|
|
189 |
logger.info(f"Loading MMS-TTS model for {target_code}...")
|
190 |
from transformers import VitsModel, AutoTokenizer
|
191 |
tts_model = VitsModel.from_pretrained(f"facebook/mms-tts-{target_code}")
|
192 |
+
tts_tokenizer = AutoTokenizer.from_pretrained(f"facebook/mms-tts-{target_code}")
|
193 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
194 |
tts_model.to(device)
|
195 |
logger.info(f"TTS model updated to {target_code}")
|
|
|
199 |
try:
|
200 |
logger.info("Falling back to MMS-TTS English model...")
|
201 |
tts_model = VitsModel.from_pretrained("facebook/mms-tts-eng")
|
202 |
+
tts_tokenizer = AutoTokenizer.from_pretrained("facebook/mms-tts-eng")
|
203 |
tts_model.to(device)
|
204 |
logger.info("Fallback TTS model loaded successfully")
|
205 |
model_status["tts"] = "loaded (fallback)"
|
|
|
235 |
target_nllb_code = NLLB_LANGUAGE_CODES[target_code]
|
236 |
mt_tokenizer.src_lang = source_nllb_code
|
237 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
238 |
+
inputs = mt_tokenizer(text, return_tensors="pt").to(device)
|
239 |
with torch.no_grad():
|
240 |
generated_tokens = mt_model.generate(
|
241 |
**inputs,
|
|
|
254 |
output_audio = None
|
255 |
if model_status["tts"].startswith("loaded") and tts_model is not None and tts_tokenizer is not None:
|
256 |
try:
|
257 |
+
inputs = tts_tokenizer(translated_text, return_tensors="pt").to(device)
|
258 |
with torch.no_grad():
|
259 |
output = tts_model(**inputs)
|
260 |
speech = output.waveform.cpu().numpy().squeeze()
|
|
|
322 |
inputs = stt_processor(waveform, sampling_rate=16000, return_tensors="pt").to(device)
|
323 |
logger.info("Audio processed, generating transcription...")
|
324 |
with torch.no_grad():
|
325 |
+
generated_ids = stt_model.generate(
|
326 |
+
**inputs,
|
327 |
+
language="en", # Explicitly set language to English
|
328 |
+
return_attention_mask=True # Generate attention mask
|
329 |
+
)
|
330 |
transcription = stt_processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
|
331 |
logger.info(f"Transcription completed: {transcription}")
|
332 |
|
|
|
339 |
source_nllb_code = NLLB_LANGUAGE_CODES[source_code]
|
340 |
target_nllb_code = NLLB_LANGUAGE_CODES[target_code]
|
341 |
mt_tokenizer.src_lang = source_nllb_code
|
342 |
+
inputs = mt_tokenizer(transcription, return_tensors="pt").to(device)
|
343 |
with torch.no_grad():
|
344 |
generated_tokens = mt_model.generate(
|
345 |
**inputs,
|
|
|
357 |
# Step 3: Convert translated text to speech (TTS)
|
358 |
if model_status["tts"].startswith("loaded") and tts_model is not None and tts_tokenizer is not None:
|
359 |
try:
|
360 |
+
inputs = tts_tokenizer(translated_text, return_tensors="pt").to(device)
|
361 |
with torch.no_grad():
|
362 |
output = tts_model(**inputs)
|
363 |
speech = output.waveform.cpu().numpy().squeeze()
|