import streamlit as st
from transformers import MarianTokenizer, MarianMTModel
import torch

# Supported languages: code -> (native name, English name)
LANGUAGES = {
    "en": ("English", "English"),
    "fr": ("Français", "French"),
    "es": ("Español", "Spanish"),
    "de": ("Deutsch", "German"),
    "hi": ("हिन्दी", "Hindi"),
    "zh": ("中文", "Chinese"),
    "ar": ("العربية", "Arabic"),
    "ru": ("Русский", "Russian"),
    "ja": ("日本語", "Japanese"),
}


# Load one direct translation model pair, cached across reruns.
@st.cache_resource
def _load_model_pair(source_lang, target_lang):
    try:
        model_name = f"Helsinki-NLP/opus-mt-{source_lang}-{target_lang}"
        tokenizer = MarianTokenizer.from_pretrained(model_name)
        model = MarianMTModel.from_pretrained(model_name)
        return tokenizer, model
    except Exception:
        return None, None  # Checkpoint missing or download failed; callers fall back


# Try every source/target combination once, cached. Note: this attempts to
# download up to 72 checkpoints on first run, which is slow and disk-hungry.
@st.cache_resource
def _load_all_models():
    models = {}
    for src in LANGUAGES:
        for tgt in LANGUAGES:
            if src != tgt:
                models[(src, tgt)] = _load_model_pair(src, tgt)
    return models


# Preload all models
all_models = _load_all_models()


def _run_pair(tokenizer, model, text, max_length):
    """Translate a single string with one tokenizer/model pair."""
    batch = tokenizer(text, return_tensors="pt", padding=True,
                      truncation=True, max_length=max_length)
    with torch.no_grad():
        output_ids = model.generate(**batch)
    return tokenizer.decode(output_ids[0], skip_special_tokens=True)


# Translate through a pivot language when no direct model exists.
# (default_tokenizer/default_model are kept for interface compatibility
# with CombinedModel below; the lookup itself goes through all_models.)
def combined_translate(text, source_lang, target_lang, default_tokenizer, default_model):
    try:
        if source_lang == target_lang:  # No translation needed
            return text
        pivot_lang, inter_text = source_lang, text
        if source_lang != "en":
            # First leg: prefer a pivot for which the second leg also exists,
            # so both halves of the trip use the same intermediate language.
            for inter in ["en", "fr", "es", "de", "ru"]:
                first = all_models.get((source_lang, inter))
                second = all_models.get((inter, target_lang))
                if first and first[0] and first[1] and (
                        inter == target_lang or (second and second[0] and second[1])):
                    inter_text = _run_pair(first[0], first[1], text, 500)
                    pivot_lang = inter
                    break
        if target_lang != pivot_lang:
            pair = all_models.get((pivot_lang, target_lang))
            if pair and pair[0] and pair[1]:
                translated = _run_pair(pair[0], pair[1], inter_text, 1000)
                return translated if translated.strip() else text
        return inter_text
    except Exception:
        return text  # Suppress error, return source text


# Wrapper exposing a generate() interface that pivots via combined_translate.
class CombinedModel:
    def __init__(self, source_lang, target_lang, default_tokenizer, default_model):
        self.source_lang = source_lang
        self.target_lang = target_lang
        self.default_tokenizer = default_tokenizer
        self.default_model = default_model

    def generate(self, **kwargs):
        try:
            input_ids = kwargs.get("input_ids")
            # `not input_ids` raises on multi-element tensors; test for None instead.
            if input_ids is None or input_ids.size(0) == 0:
                return torch.tensor([])
            texts = self.default_tokenizer.batch_decode(input_ids, skip_special_tokens=True)
            translated = [combined_translate(t, self.source_lang, self.target_lang,
                                             self.default_tokenizer, self.default_model)
                          for t in texts]
            # Tokenize the whole batch at once so sequences are padded to a common
            # length; per-item encode() plus torch.stack fails on ragged lengths.
            return self.default_tokenizer(translated, return_tensors="pt", padding=True,
                                          truncation=True, max_length=500).input_ids
        except Exception:
            return torch.tensor([])  # Suppress error, return empty tensor
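# Illustrative sketch only (assumes the fr-en and en-hi OPUS-MT checkpoints
# downloaded successfully; names below mirror the functions in this module).
# CombinedModel mimics MarianMTModel.generate just closely enough for
# translate() below: it decodes the incoming ids, pivots the text through
# combined_translate, and re-encodes the result.
#
#   tok, mod = _load_default_model()
#   wrapper = CombinedModel("fr", "hi", tok, mod)
#   ids = tok("Bonjour le monde", return_tensors="pt").input_ids
#   out = wrapper.generate(input_ids=ids)
#   print(tok.decode(out[0], skip_special_tokens=True))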
# Load the appropriate translation model pair, cached.
@st.cache_resource
def load_model(source_lang, target_lang):
    if source_lang == target_lang:
        return _load_default_model()
    direct = all_models.get((source_lang, target_lang))
    if direct and direct[0] and direct[1]:
        return direct
    # No direct model: fall back to a wrapper that pivots through an
    # intermediate language. The two-leg lookup lives in combined_translate;
    # returning only the source->intermediate pair here would yield
    # half-translated text.
    default_tokenizer, default_model = _load_default_model()
    return default_tokenizer, CombinedModel(source_lang, target_lang,
                                            default_tokenizer, default_model)


# Load the default en-hi model, cached; errors propagate to the caller.
@st.cache_resource
def _load_default_model():
    model_name = "Helsinki-NLP/opus-mt-en-hi"
    tokenizer = MarianTokenizer.from_pretrained(model_name)
    model = MarianMTModel.from_pretrained(model_name)
    return tokenizer, model


# Translate text, caching results per (text, source, target) triple.
@st.cache_data
def translate(text, source_lang, target_lang):
    try:
        if not text:
            return ""
        tokenizer, model = load_model(source_lang, target_lang)
        inputs = tokenizer(text, return_tensors="pt", padding=True,
                           truncation=True, max_length=500)
        if inputs["input_ids"].size(0) > 1:
            # Keep only the first sequence; this function translates one string.
            inputs = {k: v[0].unsqueeze(0) for k, v in inputs.items()}
        with torch.no_grad():
            translated_ids = model.generate(
                **inputs,
                max_length=1000 if target_lang in ["hi", "zh", "ja"] else 500,
                num_beams=4,
                early_stopping=True,
            )
        result = (tokenizer.decode(translated_ids[0], skip_special_tokens=True)
                  if translated_ids.size(0) > 0 else None)
        if result and result.strip():
            return result
        st.warning("This translation is not possible at this moment. Please try another language.")
        return text  # Return source text unchanged
    except Exception:
        st.warning("This translation is not possible at this moment. Please try another language.")
        return text  # Return source text unchanged
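
# --- Minimal UI sketch ---
# Assumption: the original module defines no UI, so this wiring is illustrative
# only. It shows how LANGUAGES and translate() above might be hooked into a
# Streamlit page; Streamlit runs scripts with __name__ == "__main__", so the
# guard keeps the widgets out of plain imports.
if __name__ == "__main__":
    st.title("Translator")
    col1, col2 = st.columns(2)
    with col1:
        source_lang = st.selectbox("From", list(LANGUAGES.keys()),
                                   format_func=lambda c: LANGUAGES[c][0], index=0)
    with col2:
        target_lang = st.selectbox("To", list(LANGUAGES.keys()),
                                   format_func=lambda c: LANGUAGES[c][0], index=4)
    text = st.text_area("Text to translate")
    if st.button("Translate") and text.strip():
        st.write(translate(text, source_lang, target_lang))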