Update translation.py
Browse files- translation.py +45 -27
translation.py
CHANGED
@@ -8,6 +8,51 @@ LANGUAGES = {
|
|
8 |
"ar": ("العربية", "Arabic"), "ru": ("Русский", "Russian"), "ja": ("日本語", "Japanese")
|
9 |
}
|
10 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
11 |
@st.cache_resource
|
12 |
def _load_default_model():
|
13 |
model_name = "Helsinki-NLP/opus-mt-en-hi"
|
@@ -15,33 +60,6 @@ def _load_default_model():
|
|
15 |
model = MarianMTModel.from_pretrained(model_name)
|
16 |
return tokenizer, model
|
17 |
|
18 |
-
@st.cache_resource
def load_model(source_lang, target_lang):
    """Return a (tokenizer, model) pair translating source_lang -> target_lang.

    Tries the direct Helsinki-NLP OPUS-MT model first.  If the pair has no
    direct model, pivots through English via a wrapper that duck-types
    ``MarianMTModel.generate``.  Falls back to the default model pair when
    nothing else can be loaded.

    Bug fixes vs. the previous version:
    - ``CombinedModel.generate`` built ``torch.tensor`` out of raw *strings*,
      which raises TypeError; it now re-encodes the translated text.
    - The pivot path returned the hard-coded en-hi tokenizer, so callers
      encoded source-language text with the wrong vocabulary; it now returns
      the source->English tokenizer.
    """
    try:
        if source_lang == target_lang:
            # No translation needed; the default pair acts as a stand-in.
            return _load_default_model()

        # Try the direct model for this language pair first.
        model_name = f"Helsinki-NLP/opus-mt-{source_lang}-{target_lang}"
        try:
            tokenizer = MarianTokenizer.from_pretrained(model_name)
            model = MarianMTModel.from_pretrained(model_name)
            return tokenizer, model
        except Exception:
            # No direct model: pivot through English for non-English pairs.
            if source_lang != "en" and target_lang != "en":
                # Tokenizer handed back to the caller.  The wrapper re-encodes
                # its output with this same tokenizer so the caller's decode()
                # round-trips.
                pivot_tokenizer, _ = load_model(source_lang, "en")

                def combined_translate(text):
                    # source -> English
                    en_tokenizer, en_model = load_model(source_lang, "en")
                    en_batch = en_tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=500)
                    en_text = en_tokenizer.decode(en_model.generate(**en_batch)[0], skip_special_tokens=True)
                    # English -> target
                    target_tokenizer, target_model = load_model("en", target_lang)
                    tgt_batch = target_tokenizer(en_text, return_tensors="pt", padding=True, truncation=True, max_length=500)
                    return target_tokenizer.decode(target_model.generate(**tgt_batch)[0], skip_special_tokens=True)

                class CombinedModel:
                    """Duck-types MarianMTModel.generate for the pivot path."""

                    def generate(self, **kwargs):
                        texts = [
                            combined_translate(pivot_tokenizer.decode(ids, skip_special_tokens=True))
                            for ids in kwargs["input_ids"]
                        ]
                        # Re-encode so the ids the caller decodes are valid for
                        # the returned tokenizer.  NOTE(review): SentencePiece
                        # round-trips may be lossy for scripts outside the
                        # source pair's vocabulary — confirm with target langs.
                        return pivot_tokenizer(texts, return_tensors="pt", padding=True, truncation=True, max_length=500)["input_ids"]

                return pivot_tokenizer, CombinedModel()
            return _load_default_model()
    except Exception:
        # Last-ditch fallback keeps the app usable if downloads fail.
        return _load_default_model()
|
44 |
-
|
45 |
def translate(text, source_lang, target_lang):
|
46 |
if not text:
|
47 |
return ""
|
|
|
8 |
"ar": ("العربية", "Arabic"), "ru": ("Русский", "Russian"), "ja": ("日本語", "Japanese")
|
9 |
}
|
10 |
|
11 |
+
@st.cache_resource
def _load_model_pair(source_lang, target_lang):
    """Fetch the OPUS-MT tokenizer/model for one language pair.

    Returns (tokenizer, model) on success, or (None, None) when the pair
    has no published model or the download fails.
    """
    repo_id = f"Helsinki-NLP/opus-mt-{source_lang}-{target_lang}"
    try:
        pair = (
            MarianTokenizer.from_pretrained(repo_id),
            MarianMTModel.from_pretrained(repo_id),
        )
    except Exception:
        # Missing pair on the Hub (or network failure) is an expected case.
        return None, None
    return pair
|
20 |
+
|
21 |
+
@st.cache_resource
def _load_all_models():
    """Return the (src, tgt) -> (tokenizer, model) mapping used by load_model().

    BUG FIX: the previous version eagerly attempted to download every
    pairwise model — O(n^2) multi-hundred-megabyte downloads — at startup.
    This lazy mapping preserves the same dict-style interface (indexing and
    ``.get()``) but loads a pair only the first time it is requested, then
    memoizes the result (including (None, None) misses, matching the old
    eager dict's stored values).
    """

    class _LazyModelCache(dict):
        # dict subclass: __missing__ loads and memoizes a pair on first access.
        def __missing__(self, key):
            src, tgt = key
            if src == tgt or src not in LANGUAGES or tgt not in LANGUAGES:
                # The eager dict never held same-language or unknown pairs.
                pair = (None, None)
            else:
                pair = _load_model_pair(src, tgt)
            self[key] = pair
            return pair

        def get(self, key, default=None):
            src, tgt = key
            if src == tgt:
                # Mirror the eager dict: same-language keys were absent.
                return default
            return self[key]

    return _LazyModelCache()


all_models = _load_all_models()
|
31 |
+
|
32 |
+
def load_model(source_lang, target_lang):
    """Return a (tokenizer, model) pair translating source_lang -> target_lang.

    Uses the cached direct model when one exists; otherwise pivots through
    English with a wrapper object that duck-types ``MarianMTModel.generate``.

    Bug fixes vs. the previous version:
    - a cache miss is stored as the TRUTHY tuple (None, None), so
      ``cached or fallback`` never fell back and crashed on ``None.decode``;
    - ``dict.get(key, expensive_default)`` evaluated the expensive default
      (a full model download) eagerly even on a cache hit;
    - the unused ("en", "en") lookup (dead code) is removed;
    - the pivot model built ``torch.tensor`` out of raw strings (TypeError).
    """

    def _usable(pair):
        # A pair is usable only when both halves loaded successfully.
        return pair is not None and pair[0] is not None and pair[1] is not None

    def _get_pair(src, tgt):
        # Cached lookup; the fallback load runs only on a genuine miss.
        pair = all_models.get((src, tgt))
        if not _usable(pair):
            pair = _load_model_pair(src, tgt)
        return pair

    if source_lang == target_lang:
        return _load_default_model()

    direct = _get_pair(source_lang, target_lang)
    if _usable(direct):
        return direct

    # No direct model: pivot through English.
    tokenizer, _ = _load_default_model()

    def combined_translate(text):
        if source_lang != "en":
            src_pair = _get_pair(source_lang, "en")
            if not _usable(src_pair):
                src_pair = _load_default_model()
            src_tok, src_model = src_pair
            batch = src_tok(text, return_tensors="pt", padding=True, truncation=True, max_length=500)
            en_text = src_tok.decode(src_model.generate(**batch)[0], skip_special_tokens=True)
        else:
            en_text = text
        if target_lang != "en":
            tgt_pair = _get_pair("en", target_lang)
            if not _usable(tgt_pair):
                tgt_pair = _load_default_model()
            tgt_tok, tgt_model = tgt_pair
            batch = tgt_tok(en_text, return_tensors="pt", padding=True, truncation=True, max_length=500)
            return tgt_tok.decode(tgt_model.generate(**batch)[0], skip_special_tokens=True)
        return en_text

    class CombinedModel:
        """Duck-types MarianMTModel.generate for the pivot path."""

        def generate(self, **kwargs):
            texts = [
                combined_translate(tokenizer.decode(ids, skip_special_tokens=True))
                for ids in kwargs["input_ids"]
            ]
            # Re-encode with the tokenizer returned to the caller so its
            # decode() round-trips.  NOTE(review): the default tokenizer may
            # be lossy for scripts outside its vocabulary — mirrors the
            # original design; confirm against supported target languages.
            return tokenizer(texts, return_tensors="pt", padding=True, truncation=True, max_length=500)["input_ids"]

    return tokenizer, CombinedModel()
|
55 |
+
|
56 |
@st.cache_resource
|
57 |
def _load_default_model():
|
58 |
model_name = "Helsinki-NLP/opus-mt-en-hi"
|
|
|
60 |
model = MarianMTModel.from_pretrained(model_name)
|
61 |
return tokenizer, model
|
62 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
63 |
def translate(text, source_lang, target_lang):
|
64 |
if not text:
|
65 |
return ""
|