Jerich committed on
Commit
f12e9dc
·
verified ·
1 Parent(s): 5b5fc47

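Modified the _initialize_tts_model method to pass clean_up_tokenization_spaces=True when loading the TTS tokenizer (for both the target-language model and the English fallback); added logging configuration in app.py that sets the transformers logger to the ERROR level to reduce verbosity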

Browse files
Files changed (1) hide show
  1. app.py +25 -15
app.py CHANGED
@@ -24,10 +24,13 @@ from typing import Optional, Tuple, Dict
24
  from fastapi import FastAPI, UploadFile, File, Form, HTTPException
25
  from fastapi.responses import JSONResponse
26
  import tempfile
 
 
 
 
27
 
28
  app = FastAPI(title="Talklas API")
29
 
30
- # Rest of your code remains the same
31
  class TalklasTranslator:
32
  LANGUAGE_MAPPING = {
33
  "English": "eng",
@@ -79,19 +82,26 @@ class TalklasTranslator:
79
  self.mt_model = AutoModelForSeq2SeqLM.from_pretrained("facebook/nllb-200-distilled-600M")
80
  self.mt_tokenizer = AutoTokenizer.from_pretrained("facebook/nllb-200-distilled-600M")
81
  self.mt_model.to(self.device)
 
82
  except Exception as e:
83
  raise RuntimeError(f"MT model initialization failed: {e}")
84
 
85
  def _initialize_tts_model(self):
86
  try:
87
  self.tts_model = VitsModel.from_pretrained(f"facebook/mms-tts-{self.target_lang}")
88
- self.tts_tokenizer = AutoTokenizer.from_pretrained(f"facebook/mms-tts-{self.target_lang}")
 
 
 
89
  self.tts_model.to(self.device)
90
  print(f"Loaded TTS model facebook/mms-tts-{self.target_lang} successfully")
91
  except Exception:
92
  print(f"Failed to load facebook/mms-tts-{self.target_lang}, falling back to English TTS")
93
  self.tts_model = VitsModel.from_pretrained("facebook/mms-tts-eng")
94
- self.tts_tokenizer = AutoTokenizer.from_pretrained("facebook/mms-tts-eng")
 
 
 
95
  self.tts_model.to(self.device)
96
  print("Loaded fallback TTS model facebook/mms-tts-eng successfully")
97
 
@@ -113,18 +123,18 @@ class TalklasTranslator:
113
  transcription = self.stt_processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
114
  return transcription
115
 
116
- def translate_text(self, text: str) -> str:
117
- source_code = self.NLLB_LANGUAGE_CODES[self.source_lang]
118
- target_code = self.NLLB_LANGUAGE_CODES[self.target_lang]
119
- self.mt_tokenizer.src_lang = source_code
120
- inputs = self.mt_tokenizer(text, return_tensors="pt", clean_up_tokenization_spaces=True).to(self.device)
121
- with torch.no_grad():
122
- generated_tokens = self.mt_model.generate(
123
- **inputs,
124
- forced_bos_token_id=self.mt_tokenizer.convert_tokens_to_ids(target_code),
125
- max_length=448
126
- )
127
- return self.mt_tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0]
128
 
129
  def text_to_speech(self, text: str) -> Tuple[int, np.ndarray]:
130
  inputs = self.tts_tokenizer(text, return_tensors="pt", clean_up_tokenization_spaces=True).to(self.device)
 
24
  from fastapi import FastAPI, UploadFile, File, Form, HTTPException
25
  from fastapi.responses import JSONResponse
26
  import tempfile
27
+ import logging
28
+
29
+ # Configure transformers logging to reduce verbosity
30
+ logging.getLogger("transformers").setLevel(logging.ERROR)
31
 
32
  app = FastAPI(title="Talklas API")
33
 
 
34
  class TalklasTranslator:
35
  LANGUAGE_MAPPING = {
36
  "English": "eng",
 
82
  self.mt_model = AutoModelForSeq2SeqLM.from_pretrained("facebook/nllb-200-distilled-600M")
83
  self.mt_tokenizer = AutoTokenizer.from_pretrained("facebook/nllb-200-distilled-600M")
84
  self.mt_model.to(self.device)
85
+ print("Loaded NLLB translation model successfully")
86
  except Exception as e:
87
  raise RuntimeError(f"MT model initialization failed: {e}")
88
 
89
  def _initialize_tts_model(self):
90
  try:
91
  self.tts_model = VitsModel.from_pretrained(f"facebook/mms-tts-{self.target_lang}")
92
+ self.tts_tokenizer = AutoTokenizer.from_pretrained(
93
+ f"facebook/mms-tts-{self.target_lang}",
94
+ clean_up_tokenization_spaces=True
95
+ )
96
  self.tts_model.to(self.device)
97
  print(f"Loaded TTS model facebook/mms-tts-{self.target_lang} successfully")
98
  except Exception:
99
  print(f"Failed to load facebook/mms-tts-{self.target_lang}, falling back to English TTS")
100
  self.tts_model = VitsModel.from_pretrained("facebook/mms-tts-eng")
101
+ self.tts_tokenizer = AutoTokenizer.from_pretrained(
102
+ "facebook/mms-tts-eng",
103
+ clean_up_tokenization_spaces=True
104
+ )
105
  self.tts_model.to(self.device)
106
  print("Loaded fallback TTS model facebook/mms-tts-eng successfully")
107
 
 
123
  transcription = self.stt_processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
124
  return transcription
125
 
126
+ def translate_text(self, text: str) -> str:
127
+ source_code = self.NLLB_LANGUAGE_CODES[self.source_lang]
128
+ target_code = self.NLLB_LANGUAGE_CODES[self.target_lang]
129
+ self.mt_tokenizer.src_lang = source_code
130
+ inputs = self.mt_tokenizer(text, return_tensors="pt", clean_up_tokenization_spaces=True).to(self.device)
131
+ with torch.no_grad():
132
+ generated_tokens = self.mt_model.generate(
133
+ **inputs,
134
+ forced_bos_token_id=self.mt_tokenizer.convert_tokens_to_ids(target_code),
135
+ max_length=448
136
+ )
137
+ return self.mt_tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0]
138
 
139
  def text_to_speech(self, text: str) -> Tuple[int, np.ndarray]:
140
  inputs = self.tts_tokenizer(text, return_tensors="pt", clean_up_tokenization_spaces=True).to(self.device)