Krishna086 committed on
Commit 8d43482 · verified · 1 Parent(s): 6c6d6f8

Update translation.py

Files changed (1): translation.py (+9 -5)
translation.py CHANGED
@@ -36,7 +36,8 @@ class CombinedModel:
     def generate(self, **kwargs):
         return torch.tensor([combined_translate(tokenizer.decode(x, skip_special_tokens=True)) for x in kwargs['input_ids']])
 
-# Function to load appropriate translation model with optimized pivot
+# Function to load appropriate translation model with optimized caching
+@st.cache_resource
 def load_model(source_lang, target_lang):
     if source_lang == target_lang:
         return _load_default_model()
@@ -44,17 +45,19 @@ def load_model(source_lang, target_lang):
     tokenizer_model_pair = all_models.get(model_key)
     if tokenizer_model_pair and tokenizer_model_pair[0] and tokenizer_model_pair[1]:
         return tokenizer_model_pair
-    # Optimize pivot through English using preloaded models
+    # Optimized pivot through English using preloaded models
     def combined_translate(text):
         en_tokenizer, en_model = all_models.get(("en", "en"), _load_default_model())
         if source_lang != "en":
             src_to_en_tokenizer, src_to_en_model = all_models.get((source_lang, "en"), _load_model_pair(source_lang, "en")) or _load_default_model()
-            en_text = src_to_en_tokenizer.decode(src_to_en_model.generate(**src_to_en_tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=500))[0], skip_special_tokens=True)
+            with torch.no_grad():
+                en_text = src_to_en_tokenizer.decode(src_to_en_model.generate(**src_to_en_tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=500))[0], skip_special_tokens=True)
         else:
             en_text = text
         if target_lang != "en":
             en_to_tgt_tokenizer, en_to_tgt_model = all_models.get(("en", target_lang), _load_model_pair("en", target_lang)) or _load_default_model()
-            return en_to_tgt_tokenizer.decode(en_to_tgt_model.generate(**en_to_tgt_tokenizer(en_text, return_tensors="pt", padding=True, truncation=True, max_length=500))[0], skip_special_tokens=True)
+            with torch.no_grad():
+                return en_to_tgt_tokenizer.decode(en_to_tgt_model.generate(**en_to_tgt_tokenizer(en_text, return_tensors="pt", padding=True, truncation=True, max_length=500))[0], skip_special_tokens=True)
         return en_text
     default_tokenizer, _ = _load_default_model()
     return default_tokenizer, CombinedModel()
@@ -67,7 +70,8 @@ def _load_default_model():
     model = MarianMTModel.from_pretrained(model_name)
     return tokenizer, model
 
-# Function to perform the translation
+# Cache translation results to improve speed
+@st.cache_data
 def translate(text, source_lang, target_lang):
     if not text:
         return ""