# NOTE: removed non-code extraction artifacts (file-size header, git-blame
# hashes, and a line-number gutter) that were not part of the source.
import streamlit as st
from transformers import MarianTokenizer, MarianMTModel
import torch
# Define supported languages.
# Maps ISO 639-1 language code -> (native-script display name, English name).
# The code (dict key) is what the Helsinki-NLP model names are built from
# (e.g. "opus-mt-en-fr"), so keys must match the Hub's naming scheme.
LANGUAGES = {
"en": ("English", "English"), "fr": ("Français", "French"), "es": ("Español", "Spanish"),
"de": ("Deutsch", "German"), "hi": ("हिन्दी", "Hindi"), "zh": ("中文", "Chinese"),
"ar": ("العربية", "Arabic"), "ru": ("Русский", "Russian"), "ja": ("日本語", "Japanese")
}
# Load a specific translation model pair with caching
@st.cache_resource
def _load_model_pair(source_lang, target_lang):
    """Try to load the Helsinki-NLP MarianMT model for source -> target.

    Returns:
        (tokenizer, model) on success, or (None, None) when the pair does
        not exist on the Hub or fails to load. Callers treat (None, None)
        as "no direct model" and fall back to pivot translation, so the
        broad suppression here is deliberate best-effort behavior.
    """
    model_name = f"Helsinki-NLP/opus-mt-{source_lang}-{target_lang}"
    try:
        tokenizer = MarianTokenizer.from_pretrained(model_name)
        model = MarianMTModel.from_pretrained(model_name)
        return tokenizer, model
    except Exception:
        # Many language pairs simply have no direct model; signal that
        # with (None, None) rather than crashing startup.
        return None, None
# Load all possible model combinations with caching
@st.cache_resource
def _load_all_models():
    """Build a {(src, tgt): (tokenizer, model)} table for every ordered
    language pair; an entry is (None, None) when no direct model loads."""
    return {
        (src, tgt): _load_model_pair(src, tgt)
        for src in LANGUAGES
        for tgt in LANGUAGES
        if src != tgt
    }
# Preload all models.
# NOTE(review): this eagerly attempts to load up to 72 MarianMT model pairs
# at import time, which is extremely heavy (downloads + RAM) — confirm this
# is intended versus loading pairs lazily on first use.
all_models = _load_all_models()
# Perform combined translation through intermediate languages
def combined_translate(text, source_lang, target_lang, default_tokenizer, default_model):
    """Translate *text* from source_lang to target_lang via a pivot language.

    Args:
        text: the string to translate.
        source_lang / target_lang: ISO 639-1 codes (keys of LANGUAGES).
        default_tokenizer / default_model: unused, kept for interface
            compatibility with existing callers.

    Returns:
        The translated text, or *text* unchanged on any failure — this
        module's convention is best-effort with errors suppressed.
    """
    def _run_leg(pair, s, max_length):
        # One tokenize -> generate -> decode hop with a (tokenizer, model) pair.
        tokenizer, model = pair
        batch = tokenizer(s, return_tensors="pt", padding=True, truncation=True, max_length=max_length)
        return tokenizer.decode(model.generate(**batch)[0], skip_special_tokens=True)

    try:
        if source_lang == target_lang:  # No translation needed
            return text
        # Bug fix: the previous code picked a pivot for each leg
        # independently, so leg 1 could emit e.g. French while leg 2's
        # model expected a different intermediate language. Require ONE
        # pivot that works for both legs.
        for inter in ("en", "fr", "es", "de", "ru"):
            leg1 = None if source_lang == inter else all_models.get((source_lang, inter))
            leg2 = None if target_lang == inter else all_models.get((inter, target_lang))
            leg1_ok = source_lang == inter or (leg1 and leg1[0] and leg1[1])
            leg2_ok = target_lang == inter or (leg2 and leg2[0] and leg2[1])
            if not (leg1_ok and leg2_ok):
                continue
            inter_text = text if source_lang == inter else _run_leg(leg1, text, 500)
            result = inter_text if target_lang == inter else _run_leg(leg2, inter_text, 1000)
            return result if result.strip() else text
        return text  # no workable pivot found: fall back to source text
    except Exception:
        return text  # best-effort: suppress errors, return source text
# Class to handle combined translation
class CombinedModel:
    """Drop-in stand-in for a MarianMTModel that pivots through intermediates.

    Exposes a `generate(**kwargs)` method like a real model: it decodes the
    incoming ids with the default tokenizer, routes the plain text through
    `combined_translate`, and re-encodes the results.
    """

    def __init__(self, source_lang, target_lang, default_tokenizer, default_model):
        self.source_lang = source_lang
        self.target_lang = target_lang
        self.default_tokenizer = default_tokenizer
        self.default_model = default_model

    def generate(self, **kwargs):
        """Translate a batch of encoded inputs; returns stacked output ids.

        Returns an empty tensor on any failure (best-effort, matching the
        suppress-errors convention used throughout this module).
        """
        try:
            input_ids = kwargs.get('input_ids')
            # Bug fix: `not input_ids` raises RuntimeError on a multi-element
            # tensor ("Boolean value of Tensor ... is ambiguous"); the broad
            # except then returned an empty tensor for EVERY real batch.
            if input_ids is None or input_ids.size(0) == 0:
                return torch.tensor([])
            inputs = self.default_tokenizer.batch_decode(input_ids, skip_special_tokens=True)
            translated_texts = [combined_translate(text, self.source_lang, self.target_lang, self.default_tokenizer, self.default_model) for text in inputs]
            encoded_outputs = [self.default_tokenizer.encode(t, return_tensors="pt", padding=True, truncation=True, max_length=500)[0] for t in translated_texts]
            if not encoded_outputs:
                return torch.tensor([])
            # Bug fix: torch.stack requires equal shapes, but encoded outputs
            # vary in sequence length — pad to the longest before stacking.
            max_len = max(seq.size(0) for seq in encoded_outputs)
            pad_id = getattr(self.default_tokenizer, "pad_token_id", 0) or 0
            padded = [torch.nn.functional.pad(seq, (0, max_len - seq.size(0)), value=pad_id) for seq in encoded_outputs]
            return torch.stack(padded)
        except Exception:
            return torch.tensor([])  # best-effort: suppress errors, return empty tensor
# Load appropriate translation model with caching
@st.cache_resource
def load_model(source_lang, target_lang):
    """Return a (tokenizer, model) pair able to translate source -> target.

    Resolution order:
      1. Same language: the default model.
      2. A direct Helsinki-NLP pair, if one loaded successfully.
      3. A CombinedModel wrapper that pivots through an intermediate language.

    Raises:
        Whatever _load_default_model raises when even the fallback model
        cannot be loaded (propagated for higher-level handling).
    """
    if source_lang == target_lang:
        return _load_default_model()
    direct = all_models.get((source_lang, target_lang))
    if direct and direct[0] and direct[1]:
        return direct
    # Bug fix: the old code returned the source->intermediate pair when a
    # two-hop route existed, which produced output in the *intermediate*
    # language rather than the target. Pivot translation is delegated to
    # CombinedModel / combined_translate instead.
    default_tokenizer, default_model = _load_default_model()
    return default_tokenizer, CombinedModel(source_lang, target_lang, default_tokenizer, default_model)
# Load default translation model with caching
@st.cache_resource
def _load_default_model():
    """Load the fallback en->hi MarianMT (tokenizer, model) pair.

    Exceptions propagate to the caller; the previous `try/except: raise`
    was a no-op, so removing it is behavior-neutral.
    """
    model_name = "Helsinki-NLP/opus-mt-en-hi"
    tokenizer = MarianTokenizer.from_pretrained(model_name)
    model = MarianMTModel.from_pretrained(model_name)
    return tokenizer, model
# Translate text with caching
@st.cache_data
def translate(text, source_lang, target_lang):
    """Translate *text* from source_lang to target_lang.

    Returns "" for empty input. On any failure, or when the model produces
    an empty result, shows a Streamlit warning and returns the source text
    unchanged (best-effort, matching this module's conventions).
    """
    try:
        if not text:
            return ""
        tokenizer, model = load_model(source_lang, target_lang)
        encoded = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=500)
        # Keep only the first row if tokenization produced a larger batch.
        if encoded['input_ids'].size(0) > 1:
            encoded = {key: tensor[0].unsqueeze(0) for key, tensor in encoded.items()}
        # Scripts such as Devanagari/CJK can need more output tokens.
        output_budget = 1000 if target_lang in ["hi", "zh", "ja"] else 500
        with torch.no_grad():
            output_ids = model.generate(**encoded, max_length=output_budget, num_beams=4, early_stopping=True)
        result = None
        if output_ids.size(0) > 0:
            result = tokenizer.decode(output_ids[0], skip_special_tokens=True)
        if result and result.strip():
            return result
        st.warning("This translation is not possible at this moment. Please try another language.")
        return text  # Return source text without additional note
    except Exception:
        st.warning("This translation is not possible at this moment. Please try another language.")
        return text  # Return source text without additional note