Jerich committed on
Commit
6a1bf6c
·
verified ·
1 Parent(s): e868e42

Fix MT model loading: Revert to nllb-200-distilled-600M and add fallback

Browse files

- Reverted translation model to facebook/nllb-200-distilled-600M (facebook/nllb-200-distilled-200M does not exist)
- Added fallback mechanism in _initialize_mt_model to handle model loading failures
- Modified translate_text to return source text if the translation model is not loaded
- Added initialization of mt_model and mt_tokenizer as None in __init__

Files changed (1) hide show
  1. app.py +12 -4
app.py CHANGED
@@ -55,6 +55,8 @@ class TalklasTranslator:
55
  self.source_lang = source_lang
56
  self.target_lang = target_lang
57
  self.sample_rate = 16000
 
 
58
  self._initialize_stt_model()
59
  self._initialize_mt_model()
60
  self._initialize_tts_model()
@@ -72,16 +74,19 @@ class TalklasTranslator:
72
 
73
  def _initialize_mt_model(self):
74
  try:
75
- print("Trying to load facebook/nllb-200-distilled-200M...")
76
- self.mt_model = AutoModelForSeq2SeqLM.from_pretrained("facebook/nllb-200-distilled-200M")
77
  self.mt_tokenizer = AutoTokenizer.from_pretrained(
78
- "facebook/nllb-200-distilled-200M",
79
  clean_up_tokenization_spaces=True
80
  )
81
  self.mt_model.to(self.device)
82
  print("Loaded NLLB translation model successfully")
83
  except Exception as e:
84
- raise RuntimeError(f"MT model initialization failed: {e}")
 
 
 
85
 
86
  def _initialize_tts_model(self):
87
  try:
@@ -124,6 +129,9 @@ class TalklasTranslator:
124
  return transcription
125
 
126
  def translate_text(self, text: str) -> str:
 
 
 
127
  source_code = self.NLLB_LANGUAGE_CODES[self.source_lang]
128
  target_code = self.NLLB_LANGUAGE_CODES[self.target_lang]
129
  self.mt_tokenizer.src_lang = source_code
 
55
  self.source_lang = source_lang
56
  self.target_lang = target_lang
57
  self.sample_rate = 16000
58
+ self.mt_model = None # Initialize as None
59
+ self.mt_tokenizer = None # Initialize as None
60
  self._initialize_stt_model()
61
  self._initialize_mt_model()
62
  self._initialize_tts_model()
 
74
 
75
  def _initialize_mt_model(self):
76
  try:
77
+ print("Trying to load facebook/nllb-200-distilled-600M...")
78
+ self.mt_model = AutoModelForSeq2SeqLM.from_pretrained("facebook/nllb-200-distilled-600M")
79
  self.mt_tokenizer = AutoTokenizer.from_pretrained(
80
+ "facebook/nllb-200-distilled-600M",
81
  clean_up_tokenization_spaces=True
82
  )
83
  self.mt_model.to(self.device)
84
  print("Loaded NLLB translation model successfully")
85
  except Exception as e:
86
+ print(f"Failed to load facebook/nllb-200-distilled-600M: {e}")
87
+ print("Translation model not loaded, translation will return source text as a fallback")
88
+ self.mt_model = None
89
+ self.mt_tokenizer = None
90
 
91
  def _initialize_tts_model(self):
92
  try:
 
129
  return transcription
130
 
131
  def translate_text(self, text: str) -> str:
132
+ if self.mt_model is None or self.mt_tokenizer is None:
133
+ print("Translation model not loaded, returning source text as fallback")
134
+ return text
135
  source_code = self.NLLB_LANGUAGE_CODES[self.source_lang]
136
  target_code = self.NLLB_LANGUAGE_CODES[self.target_lang]
137
  self.mt_tokenizer.src_lang = source_code