Update translation.py
Browse files- translation.py +14 -5
translation.py
CHANGED
@@ -77,8 +77,11 @@ class CombinedModel:
|
|
77 |
if not input_ids or input_ids.size(0) == 0:
|
78 |
return torch.tensor([])
|
79 |
inputs = self.default_tokenizer.batch_decode(input_ids, skip_special_tokens=True)
|
80 |
-
|
81 |
-
|
|
|
|
|
|
|
82 |
except Exception as e:
|
83 |
st.error(f"Generation error in CombinedModel: {e}")
|
84 |
return torch.tensor([])
|
@@ -129,8 +132,14 @@ def translate(text, source_lang, target_lang):
|
|
129 |
inputs = {k: v[0].unsqueeze(0) for k, v in inputs.items()}
|
130 |
with torch.no_grad():
|
131 |
translated_ids = model.generate(**inputs, max_length=1000 if target_lang in ["hi", "zh", "ja"] else 500, num_beams=4, early_stopping=True)
|
132 |
-
result = tokenizer.decode(translated_ids[0], skip_special_tokens=True)
|
133 |
-
|
|
|
|
|
|
|
|
|
|
|
134 |
except Exception as e:
|
135 |
st.error(f"Translation failed: {e}")
|
136 |
-
|
|
|
|
77 |
if not input_ids or input_ids.size(0) == 0:
|
78 |
return torch.tensor([])
|
79 |
inputs = self.default_tokenizer.batch_decode(input_ids, skip_special_tokens=True)
|
80 |
+
translated_texts = [combined_translate(text, self.source_lang, self.target_lang, self.default_tokenizer, self.default_model) for text in inputs]
|
81 |
+
encoded_outputs = [self.default_tokenizer.encode(t, return_tensors="pt", padding=True, truncation=True, max_length=500)[0] for t in translated_texts]
|
82 |
+
if not encoded_outputs:
|
83 |
+
return torch.tensor([])
|
84 |
+
return torch.stack(encoded_outputs) # Stack tensors to ensure proper shape
|
85 |
except Exception as e:
|
86 |
st.error(f"Generation error in CombinedModel: {e}")
|
87 |
return torch.tensor([])
|
|
|
132 |
inputs = {k: v[0].unsqueeze(0) for k, v in inputs.items()}
|
133 |
with torch.no_grad():
|
134 |
translated_ids = model.generate(**inputs, max_length=1000 if target_lang in ["hi", "zh", "ja"] else 500, num_beams=4, early_stopping=True)
|
135 |
+
result = tokenizer.decode(translated_ids[0], skip_special_tokens=True) if translated_ids.size(0) > 0 else None
|
136 |
+
if result and result.strip():
|
137 |
+
return result
|
138 |
+
else:
|
139 |
+
message = "This translation is not possible at this moment. Please try another language."
|
140 |
+
st.warning(message)
|
141 |
+
return f"{text} (Note: {message})"
|
142 |
except Exception as e:
|
143 |
st.error(f"Translation failed: {e}")
|
144 |
+
message = "This translation is not possible at this moment. Please try another language."
|
145 |
+
return f"{text} (Note: {message})"
|