Krishna086 commited on
Commit
b38d3f8
·
verified ·
1 Parent(s): 17b4050

Update translation.py

Browse files
Files changed (1) hide show
  1. translation.py +26 -21
translation.py CHANGED
@@ -34,15 +34,17 @@ all_models = _load_all_models()
34
  # Define combined_translate outside load_model with explicit parameters
35
  def combined_translate(text, source_lang, target_lang, default_tokenizer, default_model):
36
  with torch.no_grad():
37
- if source_lang != "en":
38
- src_to_en_tokenizer, src_to_en_model = all_models.get((source_lang, "en"), (default_tokenizer, default_model))
39
- en_text = src_to_en_tokenizer.decode(src_to_en_model.generate(**src_to_en_tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=500))[0], skip_special_tokens=True)
40
- else:
41
- en_text = text
42
- if target_lang != "en":
43
- en_to_tgt_tokenizer, en_to_tgt_model = all_models.get(("en", target_lang), (default_tokenizer, default_model))
44
- return en_to_tgt_tokenizer.decode(en_to_tgt_model.generate(**en_to_tgt_tokenizer(en_text, return_tensors="pt", padding=True, truncation=True, max_length=1000))[0], skip_special_tokens=True)
45
- return en_text
 
 
46
 
47
  # Class to handle combined translation through English pivot
48
  class CombinedModel:
@@ -54,9 +56,11 @@ class CombinedModel:
54
 
55
  def generate(self, **kwargs):
56
  input_ids = kwargs.get('input_ids')
57
- if not input_ids:
58
  return torch.tensor([])
59
- return torch.tensor([combined_translate(self.default_tokenizer.decode(x, skip_special_tokens=True), self.source_lang, self.target_lang, self.default_tokenizer, self.default_model) for x in input_ids])
 
 
60
 
61
  # Function to load appropriate translation model with optimized caching
62
  @st.cache_resource
@@ -67,13 +71,14 @@ def load_model(source_lang, target_lang):
67
  tokenizer_model_pair = all_models.get(model_key)
68
  if tokenizer_model_pair and tokenizer_model_pair[0] and tokenizer_model_pair[1]:
69
  return tokenizer_model_pair
70
- # Prefer direct model if available, then pivot
71
- for src in [source_lang, "en"]:
72
- for tgt in [target_lang, "en"]:
73
- if src != tgt:
74
- pair = all_models.get((src, tgt))
75
- if pair and pair[0] and pair[1]:
76
- return pair
 
77
  default_tokenizer, default_model = _load_default_model()
78
  return default_tokenizer, CombinedModel(source_lang, target_lang, default_tokenizer, default_model)
79
 
@@ -93,11 +98,11 @@ def translate(text, source_lang, target_lang):
93
  try:
94
  tokenizer, model = load_model(source_lang, target_lang)
95
  inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=500)
96
- if inputs['input_ids'].size(0) > 1: # Ensure single sequence
97
  inputs = {k: v[0].unsqueeze(0) for k, v in inputs.items()}
98
  with torch.no_grad():
99
- translated = model.generate(**inputs, max_length=1000 if target_lang == "ja" else 500, num_beams=4, early_stopping=True)
100
- result = tokenizer.decode(translated[0], skip_special_tokens=True)
101
  return result if result.strip() else text
102
  except Exception as e:
103
  st.error(f"Translation error: {e}")
 
34
# Define combined_translate outside load_model with explicit parameters
def combined_translate(text, source_lang, target_lang, default_tokenizer, default_model):
    """Translate `text` from source_lang to target_lang, pivoting through English.

    Looks up per-pair (tokenizer, model) in the module-level `all_models`
    registry, falling back to the supplied default pair. Returns the input
    unchanged when the languages are identical, and falls back to the input
    when the final translation is empty/whitespace.
    """
    with torch.no_grad():  # inference only — skip autograd bookkeeping
        if source_lang == target_lang:  # nothing to translate
            return text
        # Leg 1: source -> English (skipped when source already is English).
        if source_lang != "en":
            src_to_en_tokenizer, src_to_en_model = all_models.get((source_lang, "en"), (default_tokenizer, default_model))
            en_text = src_to_en_tokenizer.decode(src_to_en_model.generate(**src_to_en_tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=500))[0], skip_special_tokens=True)
        else:
            en_text = text
        # Leg 2: English -> target (skipped when target is English).
        if target_lang != "en":
            en_to_tgt_tokenizer, en_to_tgt_model = all_models.get(("en", target_lang), (default_tokenizer, default_model))
            translated = en_to_tgt_tokenizer.decode(en_to_tgt_model.generate(**en_to_tgt_tokenizer(en_text, return_tensors="pt", padding=True, truncation=True, max_length=1000))[0], skip_special_tokens=True)
            return translated if translated.strip() else text
        # BUG FIX: when the target IS English, return the pivoted English
        # text — the previous code returned the untranslated input here,
        # silently discarding leg 1's work.
        return en_text
48
 
49
  # Class to handle combined translation through English pivot
50
  class CombinedModel:
 
56
 
57
  def generate(self, **kwargs):
58
  input_ids = kwargs.get('input_ids')
59
+ if not input_ids or input_ids.size(0) == 0:
60
  return torch.tensor([])
61
+ inputs = self.default_tokenizer.batch_decode(input_ids, skip_special_tokens=True)
62
+ translated = [combined_translate(text, self.source_lang, self.target_lang, self.default_tokenizer, self.default_model) for text in inputs]
63
+ return torch.tensor([self.default_tokenizer.encode(t, return_tensors="pt", padding=True, truncation=True, max_length=500)[0] for t in translated])
64
 
65
  # Function to load appropriate translation model with optimized caching
66
  @st.cache_resource
 
71
  tokenizer_model_pair = all_models.get(model_key)
72
  if tokenizer_model_pair and tokenizer_model_pair[0] and tokenizer_model_pair[1]:
73
  return tokenizer_model_pair
74
+ # Try to find the best path through any intermediate language
75
+ for inter in LANGUAGES.keys():
76
+ if inter != source_lang and inter != target_lang:
77
+ pair1 = all_models.get((source_lang, inter))
78
+ pair2 = all_models.get((inter, target_lang))
79
+ if pair1 and pair1[0] and pair1[1] and pair2 and pair2[0] and pair2[1]:
80
+ return pair1
81
+ # Fallback to pivot through English
82
  default_tokenizer, default_model = _load_default_model()
83
  return default_tokenizer, CombinedModel(source_lang, target_lang, default_tokenizer, default_model)
84
 
 
98
  try:
99
  tokenizer, model = load_model(source_lang, target_lang)
100
  inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=500)
101
+ if inputs['input_ids'].size(0) > 1:
102
  inputs = {k: v[0].unsqueeze(0) for k, v in inputs.items()}
103
  with torch.no_grad():
104
+ translated_ids = model.generate(**inputs, max_length=1000 if target_lang in ["hi", "ja"] else 500, num_beams=4, early_stopping=True)
105
+ result = tokenizer.decode(translated_ids[0], skip_special_tokens=True)
106
  return result if result.strip() else text
107
  except Exception as e:
108
  st.error(f"Translation error: {e}")