|
import streamlit as st |
|
from transformers import MarianTokenizer, MarianMTModel |
|
import torch |
|
|
|
# Supported language codes, mapped to (name in that language, English name).
# Insertion order is preserved and matches the original declaration order.
LANGUAGES = {
    "en": ("English", "English"),
    "fr": ("Français", "French"),
    "es": ("Español", "Spanish"),
    "de": ("Deutsch", "German"),
    "hi": ("हिन्दी", "Hindi"),
    "zh": ("中文", "Chinese"),
    "ar": ("العربية", "Arabic"),
    "ru": ("Русский", "Russian"),
    "ja": ("日本語", "Japanese"),
}
|
|
|
@st.cache_resource
def _load_default_model():
    """Return the cached fallback ``(tokenizer, model)`` pair.

    The fallback checkpoint is Helsinki-NLP's English -> Hindi MarianMT model.
    ``st.cache_resource`` keeps the loaded objects alive across reruns.
    """
    checkpoint = "Helsinki-NLP/opus-mt-en-hi"
    # Tuple items evaluate left to right: the tokenizer loads before the model.
    return (
        MarianTokenizer.from_pretrained(checkpoint),
        MarianMTModel.from_pretrained(checkpoint),
    )
|
|
|
@st.cache_resource
def load_model(source_lang, target_lang):
    """Load a MarianMT ``(tokenizer, model)`` pair for ``source_lang`` -> ``target_lang``.

    Tries the direct ``Helsinki-NLP/opus-mt-{source_lang}-{target_lang}``
    checkpoint first. If it cannot be loaded:

    * when neither language is English, falls back to the ``en -> target_lang``
      pair (itself resolved through this function, so it may in turn fall back
      to the default model). NOTE: that model expects English input, so this
      is an approximation, not a true two-step pivot translation;
    * otherwise falls back to the default English -> Hindi pair.

    Identical source and target languages also return the default pair;
    callers that want an identity "translation" must short-circuit that case.

    Args:
        source_lang: Source language code (e.g. ``"de"``).
        target_lang: Target language code (e.g. ``"fr"``).

    Returns:
        A ``(MarianTokenizer, MarianMTModel)`` tuple. Failures to load the
        requested checkpoint are logged and replaced by a fallback. Errors
        raised while loading the default model itself propagate.
    """
    import logging

    if source_lang == target_lang:
        return _load_default_model()

    model_name = f"Helsinki-NLP/opus-mt-{source_lang}-{target_lang}"
    try:
        tokenizer = MarianTokenizer.from_pretrained(model_name)
        model = MarianMTModel.from_pretrained(model_name)
    except Exception:
        # Best-effort fallback is deliberate, but record why it happened.
        logging.getLogger(__name__).warning(
            "Could not load %s; using a fallback model", model_name, exc_info=True
        )
        if source_lang != "en" and target_lang != "en":
            # The previous code also loaded a source->en model here and threw
            # it away; only the en->target pair was ever returned.
            return load_model("en", target_lang)
        return _load_default_model()
    return tokenizer, model
|
|
|
def translate(text, source_lang, target_lang):
    """Translate ``text`` from ``source_lang`` to ``target_lang`` with MarianMT.

    Args:
        text: Input string. Falsy input (``""``/``None``) returns ``""``.
        source_lang: Language code of ``text`` (e.g. ``"en"``).
        target_lang: Desired output language code.

    Returns:
        The translated string. When ``source_lang == target_lang`` the input
        is returned unchanged. On any error while loading the model or
        generating, the error is logged and the original ``text`` is returned
        (best-effort behavior).

    Note:
        The tokenizer truncates input to 500 tokens, so very long text is
        only partially translated.
    """
    import logging

    if not text:
        return ""
    # load_model() maps same-language pairs to the default en->hi model, which
    # would turn e.g. English->English into Hindi output. Nothing to translate.
    if source_lang == target_lang:
        return text
    try:
        tokenizer, model = load_model(source_lang, target_lang)
        inputs = tokenizer(
            text, return_tensors="pt", padding=True, truncation=True, max_length=500
        )
        with torch.no_grad():
            translated = model.generate(
                **inputs, max_length=500, num_beams=2, early_stopping=True
            )
        return tokenizer.decode(translated[0], skip_special_tokens=True)
    except Exception:
        # Keep the best-effort fallback, but don't hide the failure entirely.
        logging.getLogger(__name__).exception(
            "Translation %s -> %s failed; returning input unchanged",
            source_lang,
            target_lang,
        )
        return text