Jerich commited on
Commit
5b5fc47
·
verified ·
1 Parent(s): 3cc49a2

Modified the tokenization step to include clean_up_tokenization_spaces=True; Added clean_up_tokenization_spaces=True in the text_to_speech method; Added a print statement to confirm the TTS model is loaded

Browse files
Files changed (1) hide show
  1. app.py +16 -13
app.py CHANGED
@@ -87,10 +87,13 @@ class TalklasTranslator:
87
  self.tts_model = VitsModel.from_pretrained(f"facebook/mms-tts-{self.target_lang}")
88
  self.tts_tokenizer = AutoTokenizer.from_pretrained(f"facebook/mms-tts-{self.target_lang}")
89
  self.tts_model.to(self.device)
 
90
  except Exception:
 
91
  self.tts_model = VitsModel.from_pretrained("facebook/mms-tts-eng")
92
  self.tts_tokenizer = AutoTokenizer.from_pretrained("facebook/mms-tts-eng")
93
  self.tts_model.to(self.device)
 
94
 
95
  def update_languages(self, source_lang: str, target_lang: str):
96
  self.source_lang = source_lang
@@ -110,21 +113,21 @@ class TalklasTranslator:
110
  transcription = self.stt_processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
111
  return transcription
112
 
113
- def translate_text(self, text: str) -> str:
114
- source_code = self.NLLB_LANGUAGE_CODES[self.source_lang]
115
- target_code = self.NLLB_LANGUAGE_CODES[self.target_lang]
116
- self.mt_tokenizer.src_lang = source_code
117
- inputs = self.mt_tokenizer(text, return_tensors="pt").to(self.device)
118
- with torch.no_grad():
119
- generated_tokens = self.mt_model.generate(
120
- **inputs,
121
- forced_bos_token_id=self.mt_tokenizer.convert_tokens_to_ids(target_code),
122
- max_length=448
123
- )
124
- return self.mt_tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0]
125
 
126
  def text_to_speech(self, text: str) -> Tuple[int, np.ndarray]:
127
- inputs = self.tts_tokenizer(text, return_tensors="pt").to(self.device)
128
  with torch.no_grad():
129
  output = self.tts_model(**inputs)
130
  speech = output.waveform.cpu().numpy().squeeze()
 
87
  self.tts_model = VitsModel.from_pretrained(f"facebook/mms-tts-{self.target_lang}")
88
  self.tts_tokenizer = AutoTokenizer.from_pretrained(f"facebook/mms-tts-{self.target_lang}")
89
  self.tts_model.to(self.device)
90
+ print(f"Loaded TTS model facebook/mms-tts-{self.target_lang} successfully")
91
  except Exception:
92
+ print(f"Failed to load facebook/mms-tts-{self.target_lang}, falling back to English TTS")
93
  self.tts_model = VitsModel.from_pretrained("facebook/mms-tts-eng")
94
  self.tts_tokenizer = AutoTokenizer.from_pretrained("facebook/mms-tts-eng")
95
  self.tts_model.to(self.device)
96
+ print("Loaded fallback TTS model facebook/mms-tts-eng successfully")
97
 
98
  def update_languages(self, source_lang: str, target_lang: str):
99
  self.source_lang = source_lang
 
113
  transcription = self.stt_processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
114
  return transcription
115
 
116
+ def translate_text(self, text: str) -> str:
117
+ source_code = self.NLLB_LANGUAGE_CODES[self.source_lang]
118
+ target_code = self.NLLB_LANGUAGE_CODES[self.target_lang]
119
+ self.mt_tokenizer.src_lang = source_code
120
+ inputs = self.mt_tokenizer(text, return_tensors="pt", clean_up_tokenization_spaces=True).to(self.device)
121
+ with torch.no_grad():
122
+ generated_tokens = self.mt_model.generate(
123
+ **inputs,
124
+ forced_bos_token_id=self.mt_tokenizer.convert_tokens_to_ids(target_code),
125
+ max_length=448
126
+ )
127
+ return self.mt_tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0]
128
 
129
  def text_to_speech(self, text: str) -> Tuple[int, np.ndarray]:
130
+ inputs = self.tts_tokenizer(text, return_tensors="pt", clean_up_tokenization_spaces=True).to(self.device)
131
  with torch.no_grad():
132
  output = self.tts_model(**inputs)
133
  speech = output.waveform.cpu().numpy().squeeze()