Update translation.py
Browse files- translation.py +26 -21
translation.py
CHANGED
@@ -34,15 +34,17 @@ all_models = _load_all_models()
|
|
34 |
# Define combined_translate outside load_model with explicit parameters
|
35 |
def combined_translate(text, source_lang, target_lang, default_tokenizer, default_model):
|
36 |
with torch.no_grad():
|
37 |
-
if source_lang !=
|
38 |
-
|
39 |
-
|
40 |
-
|
41 |
-
|
42 |
-
|
43 |
-
|
44 |
-
|
45 |
-
|
|
|
|
|
46 |
|
47 |
# Class to handle combined translation through English pivot
|
48 |
class CombinedModel:
|
@@ -54,9 +56,11 @@ class CombinedModel:
|
|
54 |
|
55 |
def generate(self, **kwargs):
|
56 |
input_ids = kwargs.get('input_ids')
|
57 |
-
if not input_ids:
|
58 |
return torch.tensor([])
|
59 |
-
|
|
|
|
|
60 |
|
61 |
# Function to load appropriate translation model with optimized caching
|
62 |
@st.cache_resource
|
@@ -67,13 +71,14 @@ def load_model(source_lang, target_lang):
|
|
67 |
tokenizer_model_pair = all_models.get(model_key)
|
68 |
if tokenizer_model_pair and tokenizer_model_pair[0] and tokenizer_model_pair[1]:
|
69 |
return tokenizer_model_pair
|
70 |
-
#
|
71 |
-
for
|
72 |
-
|
73 |
-
|
74 |
-
|
75 |
-
|
76 |
-
|
|
|
77 |
default_tokenizer, default_model = _load_default_model()
|
78 |
return default_tokenizer, CombinedModel(source_lang, target_lang, default_tokenizer, default_model)
|
79 |
|
@@ -93,11 +98,11 @@ def translate(text, source_lang, target_lang):
|
|
93 |
try:
|
94 |
tokenizer, model = load_model(source_lang, target_lang)
|
95 |
inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=500)
|
96 |
-
if inputs['input_ids'].size(0) > 1:
|
97 |
inputs = {k: v[0].unsqueeze(0) for k, v in inputs.items()}
|
98 |
with torch.no_grad():
|
99 |
-
|
100 |
-
result = tokenizer.decode(
|
101 |
return result if result.strip() else text
|
102 |
except Exception as e:
|
103 |
st.error(f"Translation error: {e}")
|
|
|
34 |
# Define combined_translate outside load_model with explicit parameters
def combined_translate(text, source_lang, target_lang, default_tokenizer, default_model):
    """Translate ``text`` from ``source_lang`` to ``target_lang`` via an English pivot.

    Looks up a source->en model and an en->target model in the module-level
    ``all_models`` registry, falling back to ``(default_tokenizer, default_model)``
    for any missing leg.  Returns the input unchanged when source and target are
    the same language, and falls back to the original text whenever a leg
    produces an empty/whitespace-only result.
    """
    with torch.no_grad():  # inference only — no gradients needed
        if source_lang == target_lang:
            return text  # nothing to translate

        # Leg 1: source -> English (skipped when the source is already English).
        if source_lang != "en":
            src_tok, src_model = all_models.get(
                (source_lang, "en"), (default_tokenizer, default_model)
            )
            src_inputs = src_tok(
                text, return_tensors="pt", padding=True, truncation=True, max_length=500
            )
            en_text = src_tok.decode(
                src_model.generate(**src_inputs)[0], skip_special_tokens=True
            )
        else:
            en_text = text

        # Leg 2: English -> target.  BUG FIX: the previous version fell through
        # and returned the *untranslated* source text when target_lang == "en";
        # return the English intermediate instead.
        if target_lang == "en":
            return en_text if en_text.strip() else text

        tgt_tok, tgt_model = all_models.get(
            ("en", target_lang), (default_tokenizer, default_model)
        )
        tgt_inputs = tgt_tok(
            en_text, return_tensors="pt", padding=True, truncation=True, max_length=1000
        )
        translated = tgt_tok.decode(
            tgt_model.generate(**tgt_inputs)[0], skip_special_tokens=True
        )
        return translated if translated.strip() else text
48 |
|
49 |
# Class to handle combined translation through English pivot
|
50 |
class CombinedModel:
|
|
|
56 |
|
57 |
def generate(self, **kwargs):
    """Mimic HF ``model.generate``: decode ``input_ids``, pivot-translate each
    text with :func:`combined_translate`, and re-encode the results as one
    padded 2-D tensor.

    Returns an empty tensor when ``input_ids`` is absent or empty.
    """
    input_ids = kwargs.get('input_ids')
    # BUG FIX: `not input_ids` raises "Boolean value of Tensor ... is ambiguous"
    # for batched tensors — test for None / zero rows explicitly instead.
    if input_ids is None or input_ids.size(0) == 0:
        return torch.tensor([])
    texts = self.default_tokenizer.batch_decode(input_ids, skip_special_tokens=True)
    translated = [
        combined_translate(t, self.source_lang, self.target_lang,
                           self.default_tokenizer, self.default_model)
        for t in texts
    ]
    encoded = [
        self.default_tokenizer.encode(t, return_tensors="pt", padding=True,
                                      truncation=True, max_length=500)[0]
        for t in translated
    ]
    # BUG FIX: torch.tensor() cannot build a tensor from a list of 1-D tensors
    # (and translations differ in length anyway) — pad to the longest sequence.
    pad_id = self.default_tokenizer.pad_token_id
    if pad_id is None:
        pad_id = 0  # NOTE(review): fallback pad id — confirm against tokenizer config
    max_len = max(seq.size(0) for seq in encoded)
    batch = torch.full((len(encoded), max_len), pad_id, dtype=encoded[0].dtype)
    for row, seq in enumerate(encoded):
        batch[row, : seq.size(0)] = seq
    return batch
|
64 |
|
65 |
# Function to load appropriate translation model with optimized caching
|
66 |
@st.cache_resource
|
|
|
71 |
tokenizer_model_pair = all_models.get(model_key)
|
72 |
if tokenizer_model_pair and tokenizer_model_pair[0] and tokenizer_model_pair[1]:
|
73 |
return tokenizer_model_pair
|
74 |
+
# Try to find the best path through any intermediate language
|
75 |
+
for inter in LANGUAGES.keys():
|
76 |
+
if inter != source_lang and inter != target_lang:
|
77 |
+
pair1 = all_models.get((source_lang, inter))
|
78 |
+
pair2 = all_models.get((inter, target_lang))
|
79 |
+
if pair1 and pair1[0] and pair1[1] and pair2 and pair2[0] and pair2[1]:
|
80 |
+
return pair1
|
81 |
+
# Fallback to pivot through English
|
82 |
default_tokenizer, default_model = _load_default_model()
|
83 |
return default_tokenizer, CombinedModel(source_lang, target_lang, default_tokenizer, default_model)
|
84 |
|
|
|
98 |
try:
|
99 |
tokenizer, model = load_model(source_lang, target_lang)
|
100 |
inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=500)
|
101 |
+
if inputs['input_ids'].size(0) > 1:
|
102 |
inputs = {k: v[0].unsqueeze(0) for k, v in inputs.items()}
|
103 |
with torch.no_grad():
|
104 |
+
translated_ids = model.generate(**inputs, max_length=1000 if target_lang in ["hi", "ja"] else 500, num_beams=4, early_stopping=True)
|
105 |
+
result = tokenizer.decode(translated_ids[0], skip_special_tokens=True)
|
106 |
return result if result.strip() else text
|
107 |
except Exception as e:
|
108 |
st.error(f"Translation error: {e}")
|