Update translation.py
Browse files · translation.py (+3 −3)
translation.py
CHANGED
@@ -45,7 +45,7 @@ def load_model(source_lang, target_lang):
|
|
45 |
tokenizer_model_pair = all_models.get(model_key)
|
46 |
if tokenizer_model_pair and tokenizer_model_pair[0] and tokenizer_model_pair[1]:
|
47 |
return tokenizer_model_pair
|
48 |
-
# Use direct English pivot
|
49 |
if source_lang != "en" and target_lang != "en":
|
50 |
en_pivot_pair = all_models.get((source_lang, "en")) or _load_model_pair(source_lang, "en")
|
51 |
if en_pivot_pair[0] and en_pivot_pair[1]:
|
@@ -60,7 +60,7 @@ def load_model(source_lang, target_lang):
|
|
60 |
en_text = text
|
61 |
if target_lang != "en":
|
62 |
en_to_tgt_tokenizer, en_to_tgt_model = all_models.get(("en", target_lang), _load_default_model())
|
63 |
-
return en_to_tgt_tokenizer.decode(en_to_tgt_model.generate(**en_to_tgt_tokenizer(en_text, return_tensors="pt", padding=True, truncation=True, max_length=500))[0], skip_special_tokens=True)
|
64 |
return en_text
|
65 |
return default_tokenizer, CombinedModel()
|
66 |
|
@@ -81,7 +81,7 @@ def translate(text, source_lang, target_lang):
|
|
81 |
tokenizer, model = load_model(source_lang, target_lang)
|
82 |
inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=500)
|
83 |
with torch.no_grad():
|
84 |
-
translated = model.generate(**inputs, max_length=500, num_beams=4, early_stopping=True)
|
85 |
result = tokenizer.decode(translated[0], skip_special_tokens=True)
|
86 |
return result if result.strip() else text
|
87 |
except Exception as e:
|
|
|
45 |
tokenizer_model_pair = all_models.get(model_key)
|
46 |
if tokenizer_model_pair and tokenizer_model_pair[0] and tokenizer_model_pair[1]:
|
47 |
return tokenizer_model_pair
|
48 |
+
# Use direct English pivot with defined combined_translate
|
49 |
if source_lang != "en" and target_lang != "en":
|
50 |
en_pivot_pair = all_models.get((source_lang, "en")) or _load_model_pair(source_lang, "en")
|
51 |
if en_pivot_pair[0] and en_pivot_pair[1]:
|
|
|
60 |
en_text = text
|
61 |
if target_lang != "en":
|
62 |
en_to_tgt_tokenizer, en_to_tgt_model = all_models.get(("en", target_lang), _load_default_model())
|
63 |
+
return en_to_tgt_tokenizer.decode(en_to_tgt_model.generate(**en_to_tgt_tokenizer(en_text, return_tensors="pt", padding=True, truncation=True, max_length=1000))[0], skip_special_tokens=True) # Increased max_length
|
64 |
return en_text
|
65 |
return default_tokenizer, CombinedModel()
|
66 |
|
|
|
81 |
tokenizer, model = load_model(source_lang, target_lang)
|
82 |
inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=500)
|
83 |
with torch.no_grad():
|
84 |
+
translated = model.generate(**inputs, max_length=1000 if target_lang == "hi" else 500, num_beams=6 if target_lang == "hi" else 4, early_stopping=True) # Adjusted for Hindi
|
85 |
result = tokenizer.decode(translated[0], skip_special_tokens=True)
|
86 |
return result if result.strip() else text
|
87 |
except Exception as e:
|