Krishna086 committed on
Commit
83c5c51
·
verified ·
1 Parent(s): 145ef9b

Update translation.py

Browse files
Files changed (1) hide show
  1. translation.py +10 -3
translation.py CHANGED
@@ -8,6 +8,7 @@ LANGUAGES = {
8
  "ar": ("العربية", "Arabic"), "ru": ("Русский", "Russian"), "ja": ("日本語", "Japanese")
9
  }
10
 
 
11
  @st.cache_resource
12
  def _load_model_pair(source_lang, target_lang):
13
  try:
@@ -18,6 +19,7 @@ def _load_model_pair(source_lang, target_lang):
18
  except Exception:
19
  return None, None
20
 
 
21
  @st.cache_resource
22
  def _load_all_models():
23
  models = {}
@@ -29,6 +31,12 @@ def _load_all_models():
29
 
30
  all_models = _load_all_models()
31
 
 
 
 
 
 
 
32
  def load_model(source_lang, target_lang):
33
  if source_lang == target_lang:
34
  return _load_default_model()
@@ -47,12 +55,10 @@ def load_model(source_lang, target_lang):
47
  en_to_tgt_tokenizer, en_to_tgt_model = all_models.get(("en", target_lang), _load_model_pair("en", target_lang)) or _load_default_model()
48
  return en_to_tgt_tokenizer.decode(en_to_tgt_model.generate(**en_to_tgt_tokenizer(en_text, return_tensors="pt", padding=True, truncation=True, max_length=500))[0], skip_special_tokens=True)
49
  return en_text
50
- class CombinedModel:
51
- def generate(self, **kwargs):
52
- return torch.tensor([combined_translate(tokenizer.decode(x, skip_special_tokens=True)) for x in kwargs['input_ids']])
53
  tokenizer, _ = _load_default_model()
54
  return tokenizer, CombinedModel()
55
 
 
56
  @st.cache_resource
57
  def _load_default_model():
58
  model_name = "Helsinki-NLP/opus-mt-en-hi"
@@ -60,6 +66,7 @@ def _load_default_model():
60
  model = MarianMTModel.from_pretrained(model_name)
61
  return tokenizer, model
62
 
 
63
  def translate(text, source_lang, target_lang):
64
  if not text:
65
  return ""
 
8
  "ar": ("العربية", "Arabic"), "ru": ("Русский", "Russian"), "ja": ("日本語", "Japanese")
9
  }
10
 
11
+ # Cache resource to load a specific translation model pair
12
  @st.cache_resource
13
  def _load_model_pair(source_lang, target_lang):
14
  try:
 
19
  except Exception:
20
  return None, None
21
 
22
+ # Cache resource to load all possible model combinations
23
  @st.cache_resource
24
  def _load_all_models():
25
  models = {}
 
31
 
32
  all_models = _load_all_models()
33
 
34
+ # Class to handle combined translation through English pivot
35
+ class CombinedModel:
36
+ def generate(self, **kwargs):
37
+ return torch.tensor([combined_translate(tokenizer.decode(x, skip_special_tokens=True)) for x in kwargs['input_ids']])
38
+
39
+ # Function to load appropriate translation model
40
  def load_model(source_lang, target_lang):
41
  if source_lang == target_lang:
42
  return _load_default_model()
 
55
  en_to_tgt_tokenizer, en_to_tgt_model = all_models.get(("en", target_lang), _load_model_pair("en", target_lang)) or _load_default_model()
56
  return en_to_tgt_tokenizer.decode(en_to_tgt_model.generate(**en_to_tgt_tokenizer(en_text, return_tensors="pt", padding=True, truncation=True, max_length=500))[0], skip_special_tokens=True)
57
  return en_text
 
 
 
58
  tokenizer, _ = _load_default_model()
59
  return tokenizer, CombinedModel()
60
 
61
+ # Cache resource to load default translation model
62
  @st.cache_resource
63
  def _load_default_model():
64
  model_name = "Helsinki-NLP/opus-mt-en-hi"
 
66
  model = MarianMTModel.from_pretrained(model_name)
67
  return tokenizer, model
68
 
69
+ # Function to perform the translation
70
  def translate(text, source_lang, target_lang):
71
  if not text:
72
  return ""