Krishna086 committed on
Commit
8bd35b3
·
verified ·
1 Parent(s): 035bcdc

Update translation.py

Browse files
Files changed (1) hide show
  1. translation.py +14 -5
translation.py CHANGED
@@ -77,8 +77,11 @@ class CombinedModel:
77
  if not input_ids or input_ids.size(0) == 0:
78
  return torch.tensor([])
79
  inputs = self.default_tokenizer.batch_decode(input_ids, skip_special_tokens=True)
80
- translated = [combined_translate(text, self.source_lang, self.target_lang, self.default_tokenizer, self.default_model) for text in inputs]
81
- return torch.tensor([self.default_tokenizer.encode(t, return_tensors="pt", padding=True, truncation=True, max_length=500)[0] for t in translated])
 
 
 
82
  except Exception as e:
83
  st.error(f"Generation error in CombinedModel: {e}")
84
  return torch.tensor([])
@@ -129,8 +132,14 @@ def translate(text, source_lang, target_lang):
129
  inputs = {k: v[0].unsqueeze(0) for k, v in inputs.items()}
130
  with torch.no_grad():
131
  translated_ids = model.generate(**inputs, max_length=1000 if target_lang in ["hi", "zh", "ja"] else 500, num_beams=4, early_stopping=True)
132
- result = tokenizer.decode(translated_ids[0], skip_special_tokens=True)
133
- return result if result.strip() else text
 
 
 
 
 
134
  except Exception as e:
135
  st.error(f"Translation failed: {e}")
136
- return text
 
 
77
  if not input_ids or input_ids.size(0) == 0:
78
  return torch.tensor([])
79
  inputs = self.default_tokenizer.batch_decode(input_ids, skip_special_tokens=True)
80
+ translated_texts = [combined_translate(text, self.source_lang, self.target_lang, self.default_tokenizer, self.default_model) for text in inputs]
81
+ encoded_outputs = [self.default_tokenizer.encode(t, return_tensors="pt", padding=True, truncation=True, max_length=500)[0] for t in translated_texts]
82
+ if not encoded_outputs:
83
+ return torch.tensor([])
84
+ return torch.stack(encoded_outputs) # Stack tensors to ensure proper shape
85
  except Exception as e:
86
  st.error(f"Generation error in CombinedModel: {e}")
87
  return torch.tensor([])
 
132
  inputs = {k: v[0].unsqueeze(0) for k, v in inputs.items()}
133
  with torch.no_grad():
134
  translated_ids = model.generate(**inputs, max_length=1000 if target_lang in ["hi", "zh", "ja"] else 500, num_beams=4, early_stopping=True)
135
+ result = tokenizer.decode(translated_ids[0], skip_special_tokens=True) if translated_ids.size(0) > 0 else None
136
+ if result and result.strip():
137
+ return result
138
+ else:
139
+ message = "This translation is not possible at this moment. Please try another language."
140
+ st.warning(message)
141
+ return f"{text} (Note: {message})"
142
  except Exception as e:
143
  st.error(f"Translation failed: {e}")
144
+ message = "This translation is not possible at this moment. Please try another language."
145
+ return f"{text} (Note: {message})"