Update translation.py
Browse files- translation.py +9 -5
translation.py
CHANGED
@@ -36,7 +36,8 @@ class CombinedModel:
|
|
36 |
def generate(self, **kwargs):
    # Mimic the HuggingFace `generate()` interface: decode each row of
    # input ids back to text with the module-global `tokenizer`, run it
    # through `combined_translate`, and wrap the results in a tensor.
    # NOTE(review): `torch.tensor()` over a list of *strings* raises a
    # TypeError at runtime — presumably the caller re-encodes these
    # translations first; verify against the call site.
    # NOTE(review): relies on a module-global `tokenizer` rather than one
    # passed in — assumes it matches the model that produced `input_ids`;
    # TODO confirm.
    return torch.tensor([combined_translate(tokenizer.decode(x, skip_special_tokens=True)) for x in kwargs['input_ids']])
|
38 |
|
39 |
-
# Function to load appropriate translation model with optimized
|
|
|
40 |
def load_model(source_lang, target_lang):
|
41 |
if source_lang == target_lang:
|
42 |
return _load_default_model()
|
@@ -44,17 +45,19 @@ def load_model(source_lang, target_lang):
|
|
44 |
tokenizer_model_pair = all_models.get(model_key)
|
45 |
if tokenizer_model_pair and tokenizer_model_pair[0] and tokenizer_model_pair[1]:
|
46 |
return tokenizer_model_pair
|
47 |
-
#
|
48 |
def combined_translate(text):
|
49 |
en_tokenizer, en_model = all_models.get(("en", "en"), _load_default_model())
|
50 |
if source_lang != "en":
|
51 |
src_to_en_tokenizer, src_to_en_model = all_models.get((source_lang, "en"), _load_model_pair(source_lang, "en")) or _load_default_model()
|
52 |
-
|
|
|
53 |
else:
|
54 |
en_text = text
|
55 |
if target_lang != "en":
|
56 |
en_to_tgt_tokenizer, en_to_tgt_model = all_models.get(("en", target_lang), _load_model_pair("en", target_lang)) or _load_default_model()
|
57 |
-
|
|
|
58 |
return en_text
|
59 |
default_tokenizer, _ = _load_default_model()
|
60 |
return default_tokenizer, CombinedModel()
|
@@ -67,7 +70,8 @@ def _load_default_model():
|
|
67 |
model = MarianMTModel.from_pretrained(model_name)
|
68 |
return tokenizer, model
|
69 |
|
70 |
-
#
|
|
|
71 |
def translate(text, source_lang, target_lang):
|
72 |
if not text:
|
73 |
return ""
|
|
|
36 |
def generate(self, **kwargs):
|
37 |
return torch.tensor([combined_translate(tokenizer.decode(x, skip_special_tokens=True)) for x in kwargs['input_ids']])
|
38 |
|
39 |
+
# Function to load appropriate translation model with optimized caching
|
40 |
+
@st.cache_resource
|
41 |
def load_model(source_lang, target_lang):
|
42 |
if source_lang == target_lang:
|
43 |
return _load_default_model()
|
|
|
45 |
tokenizer_model_pair = all_models.get(model_key)
|
46 |
if tokenizer_model_pair and tokenizer_model_pair[0] and tokenizer_model_pair[1]:
|
47 |
return tokenizer_model_pair
|
48 |
# Optimized pivot through English using preloaded models
def combined_translate(text):
    """Translate *text* by pivoting through English.

    Runs source_lang -> English with one Marian model pair, then
    English -> target_lang with another, reusing pairs cached in
    ``all_models``. ``source_lang`` and ``target_lang`` are free
    variables from the enclosing ``load_model`` scope (closure —
    TODO confirm nesting; this view is a mangled diff).

    Returns the translated string (or *text* unchanged when both
    languages are English).
    """
    def _pair(src, tgt):
        # BUG FIX: dict.get(key, default) evaluates the default EAGERLY,
        # so the original called _load_model_pair()/_load_default_model()
        # — loading models from disk — on EVERY call, even on a cache
        # hit. Look the key up first and load only on a miss.
        pair = all_models.get((src, tgt))
        if not pair:
            pair = _load_model_pair(src, tgt)
        # Preserve the original's `... or _load_default_model()` fallback
        # for a missing/falsy pair.
        return pair or _load_default_model()

    # (The original also fetched an unused ("en","en") tokenizer/model
    # pair here — dead code, removed.)
    if source_lang != "en":
        src_tok, src_model = _pair(source_lang, "en")
        with torch.no_grad():
            batch = src_tok(text, return_tensors="pt", padding=True,
                            truncation=True, max_length=500)
            en_text = src_tok.decode(src_model.generate(**batch)[0],
                                     skip_special_tokens=True)
    else:
        en_text = text

    if target_lang != "en":
        tgt_tok, tgt_model = _pair("en", target_lang)
        with torch.no_grad():
            batch = tgt_tok(en_text, return_tensors="pt", padding=True,
                            truncation=True, max_length=500)
            return tgt_tok.decode(tgt_model.generate(**batch)[0],
                                  skip_special_tokens=True)
    # target is English: return the pivot text as-is.
    return en_text
|
62 |
default_tokenizer, _ = _load_default_model()
|
63 |
return default_tokenizer, CombinedModel()
|
|
|
70 |
model = MarianMTModel.from_pretrained(model_name)
|
71 |
return tokenizer, model
|
72 |
|
73 |
+
# Cache translation results to improve speed
|
74 |
+
@st.cache_data
|
75 |
def translate(text, source_lang, target_lang):
|
76 |
if not text:
|
77 |
return ""
|