Spaces:
Paused
Paused
Modified the tokenization step to include clean_up_tokenization_spaces=True; Added clean_up_tokenization_spaces=True in the text_to_speech method; Added a print statement to confirm the TTS model is loaded
Browse files
app.py
CHANGED
@@ -87,10 +87,13 @@ class TalklasTranslator:
|
|
87 |
self.tts_model = VitsModel.from_pretrained(f"facebook/mms-tts-{self.target_lang}")
|
88 |
self.tts_tokenizer = AutoTokenizer.from_pretrained(f"facebook/mms-tts-{self.target_lang}")
|
89 |
self.tts_model.to(self.device)
|
|
|
90 |
except Exception:
|
|
|
91 |
self.tts_model = VitsModel.from_pretrained("facebook/mms-tts-eng")
|
92 |
self.tts_tokenizer = AutoTokenizer.from_pretrained("facebook/mms-tts-eng")
|
93 |
self.tts_model.to(self.device)
|
|
|
94 |
|
95 |
def update_languages(self, source_lang: str, target_lang: str):
|
96 |
self.source_lang = source_lang
|
@@ -110,21 +113,21 @@ class TalklasTranslator:
|
|
110 |
transcription = self.stt_processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
|
111 |
return transcription
|
112 |
|
113 |
-
|
114 |
-
|
115 |
-
|
116 |
-
|
117 |
-
|
118 |
-
|
119 |
-
|
120 |
-
|
121 |
-
|
122 |
-
|
123 |
-
|
124 |
-
|
125 |
|
126 |
def text_to_speech(self, text: str) -> Tuple[int, np.ndarray]:
|
127 |
-
inputs = self.tts_tokenizer(text, return_tensors="pt").to(self.device)
|
128 |
with torch.no_grad():
|
129 |
output = self.tts_model(**inputs)
|
130 |
speech = output.waveform.cpu().numpy().squeeze()
|
|
|
87 |
self.tts_model = VitsModel.from_pretrained(f"facebook/mms-tts-{self.target_lang}")
|
88 |
self.tts_tokenizer = AutoTokenizer.from_pretrained(f"facebook/mms-tts-{self.target_lang}")
|
89 |
self.tts_model.to(self.device)
|
90 |
+
print(f"Loaded TTS model facebook/mms-tts-{self.target_lang} successfully")
|
91 |
except Exception:
|
92 |
+
print(f"Failed to load facebook/mms-tts-{self.target_lang}, falling back to English TTS")
|
93 |
self.tts_model = VitsModel.from_pretrained("facebook/mms-tts-eng")
|
94 |
self.tts_tokenizer = AutoTokenizer.from_pretrained("facebook/mms-tts-eng")
|
95 |
self.tts_model.to(self.device)
|
96 |
+
print("Loaded fallback TTS model facebook/mms-tts-eng successfully")
|
97 |
|
98 |
def update_languages(self, source_lang: str, target_lang: str):
|
99 |
self.source_lang = source_lang
|
|
|
113 |
transcription = self.stt_processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
|
114 |
return transcription
|
115 |
|
116 |
+
def translate_text(self, text: str) -> str:
    """Translate *text* from ``self.source_lang`` to ``self.target_lang``.

    Uses the NLLB machine-translation model/tokenizer already loaded on
    ``self.mt_model`` / ``self.mt_tokenizer``.

    Args:
        text: Source-language input string.

    Returns:
        The decoded translation as a single string.

    Raises:
        KeyError: if either configured language code is missing from
            ``NLLB_LANGUAGE_CODES``.
    """
    source_code = self.NLLB_LANGUAGE_CODES[self.source_lang]
    target_code = self.NLLB_LANGUAGE_CODES[self.target_lang]
    # NLLB tokenizers need the source language set before encoding.
    self.mt_tokenizer.src_lang = source_code
    # NOTE(review): clean_up_tokenization_spaces is a *decoding* option in
    # transformers; passing it to the tokenizer's __call__ (encoding) is an
    # unrecognized kwarg and has no effect. It is applied in batch_decode
    # below instead, which is what the commit actually intended.
    inputs = self.mt_tokenizer(text, return_tensors="pt").to(self.device)
    with torch.no_grad():
        generated_tokens = self.mt_model.generate(
            **inputs,
            # Force generation to begin in the target language.
            forced_bos_token_id=self.mt_tokenizer.convert_tokens_to_ids(target_code),
            max_length=448,
        )
    return self.mt_tokenizer.batch_decode(
        generated_tokens,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=True,
    )[0]
|
128 |
|
129 |
def text_to_speech(self, text: str) -> Tuple[int, np.ndarray]:
|
130 |
+
inputs = self.tts_tokenizer(text, return_tensors="pt", clean_up_tokenization_spaces=True).to(self.device)
|
131 |
with torch.no_grad():
|
132 |
output = self.tts_model(**inputs)
|
133 |
speech = output.waveform.cpu().numpy().squeeze()
|