Krishna086 committed on
Commit
71d9b1a
·
verified ·
1 Parent(s): d2b936e

Update translation.py

Browse files
Files changed (1) hide show
  1. translation.py +45 -27
translation.py CHANGED
@@ -8,6 +8,51 @@ LANGUAGES = {
8
  "ar": ("العربية", "Arabic"), "ru": ("Русский", "Russian"), "ja": ("日本語", "Japanese")
9
  }
10
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
  @st.cache_resource
12
  def _load_default_model():
13
  model_name = "Helsinki-NLP/opus-mt-en-hi"
@@ -15,33 +60,6 @@ def _load_default_model():
15
  model = MarianMTModel.from_pretrained(model_name)
16
  return tokenizer, model
17
 
18
- @st.cache_resource
19
- def load_model(source_lang, target_lang):
20
- try:
21
- if source_lang == target_lang:
22
- return _load_default_model()
23
- # Try direct model
24
- model_name = f"Helsinki-NLP/opus-mt-{source_lang}-{target_lang}"
25
- try:
26
- tokenizer = MarianTokenizer.from_pretrained(model_name)
27
- model = MarianMTModel.from_pretrained(model_name)
28
- return tokenizer, model
29
- except Exception:
30
- # Pivot through English for non-English pairs
31
- if source_lang != "en" and target_lang != "en":
32
- def combined_translate(text):
33
- en_tokenizer, en_model = load_model(source_lang, "en")
34
- en_text = en_tokenizer.decode(en_model.generate(**en_tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=500))[0], skip_special_tokens=True)
35
- target_tokenizer, target_model = load_model("en", target_lang)
36
- return target_tokenizer.decode(target_model.generate(**target_tokenizer(en_text, return_tensors="pt", padding=True, truncation=True, max_length=500))[0], skip_special_tokens=True)
37
- class CombinedModel:
38
- def generate(self, **kwargs):
39
- return torch.tensor([combined_translate(tokenizer.decode(x, skip_special_tokens=True)) for x in kwargs['input_ids']])
40
- return MarianTokenizer.from_pretrained("Helsinki-NLP/opus-mt-en-hi"), CombinedModel()
41
- return _load_default_model()
42
- except Exception:
43
- return _load_default_model()
44
-
45
  def translate(text, source_lang, target_lang):
46
  if not text:
47
  return ""
 
8
  "ar": ("العربية", "Arabic"), "ru": ("Русский", "Russian"), "ja": ("日本語", "Japanese")
9
  }
10
 
11
@st.cache_resource
def _load_model_pair(source_lang, target_lang):
    """Fetch the Helsinki-NLP Marian (tokenizer, model) for one language pair.

    Returns (None, None) when no pretrained model exists for the pair so
    callers can fall back to pivoting through English.
    """
    repo = f"Helsinki-NLP/opus-mt-{source_lang}-{target_lang}"
    try:
        return MarianTokenizer.from_pretrained(repo), MarianMTModel.from_pretrained(repo)
    except Exception:
        # No published model for this direction (or download failed).
        return None, None
21
@st.cache_resource
def _load_all_models():
    """Build the (src, tgt) -> (tokenizer, model) table for every directed pair.

    Missing pairs are stored as (None, None) by _load_model_pair.
    NOTE(review): this eagerly downloads up to len(LANGUAGES)^2 models at
    startup — consider lazy loading if startup cost matters; confirm intent.
    """
    codes = list(LANGUAGES)
    return {
        (src, tgt): _load_model_pair(src, tgt)
        for src in codes
        for tgt in codes
        if src != tgt
    }
30
# Module-level cache: (src, tgt) -> (tokenizer, model); (None, None) marks pairs
# with no published Marian model. Populated once at import (cached by Streamlit).
all_models = _load_all_models()
31
+
32
def load_model(source_lang, target_lang):
    """Return a (tokenizer, model) pair translating source_lang -> target_lang.

    Resolution order:
      1. identity pair -> default en-hi model (mirrors previous behavior);
      2. direct Helsinki-NLP model from the all_models cache;
      3. pivot through English via a CombinedModel wrapper whose generate()
         returns token ids decodable by the returned tokenizer.
    """
    if source_lang == target_lang:
        return _load_default_model()

    direct = all_models.get((source_lang, target_lang))
    if direct is not None and direct[0] is not None and direct[1] is not None:
        return direct

    def _pair_or_default(src, tgt):
        # BUGFIX: a cached (None, None) is a truthy tuple, so the old
        # `dict.get(key, fallback) or default` never fell back and crashed
        # on None.decode. Check the tuple's contents explicitly instead.
        pair = all_models.get((src, tgt))
        if pair is None:
            pair = _load_model_pair(src, tgt)
        tok, mdl = pair
        if tok is None or mdl is None:
            return _load_default_model()
        return tok, mdl

    def combined_translate(text):
        # Pivot translation: source -> English -> target.
        if source_lang != "en":
            s_tok, s_mdl = _pair_or_default(source_lang, "en")
            en_text = s_tok.decode(
                s_mdl.generate(**s_tok(text, return_tensors="pt", padding=True,
                                       truncation=True, max_length=500))[0],
                skip_special_tokens=True)
        else:
            en_text = text
        if target_lang != "en":
            t_tok, t_mdl = _pair_or_default("en", target_lang)
            return t_tok.decode(
                t_mdl.generate(**t_tok(en_text, return_tensors="pt", padding=True,
                                       truncation=True, max_length=500))[0],
                skip_special_tokens=True)
        return en_text

    tokenizer, _ = _load_default_model()

    class CombinedModel:
        """Duck-types MarianMTModel.generate for the pivot path."""

        def generate(self, **kwargs):
            # BUGFIX: the old code built torch.tensor from a list of strings
            # (TypeError). Decode the incoming ids, pivot-translate the text,
            # then re-encode so callers can tokenizer.decode(result[0]).
            texts = [tokenizer.decode(ids, skip_special_tokens=True)
                     for ids in kwargs["input_ids"]]
            translated = [combined_translate(t) for t in texts]
            return tokenizer(translated, return_tensors="pt",
                             padding=True).input_ids

    return tokenizer, CombinedModel()
56
@st.cache_resource
def _load_default_model():
    """Load the fallback English->Hindi Marian tokenizer and model (cached)."""
    name = "Helsinki-NLP/opus-mt-en-hi"
    tok = MarianTokenizer.from_pretrained(name)
    mdl = MarianMTModel.from_pretrained(name)
    return tok, mdl
62
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
63
  def translate(text, source_lang, target_lang):
64
  if not text:
65
  return ""