import streamlit as st
from transformers import MarianTokenizer, MarianMTModel
import torch

# Supported languages: code -> (native name, English name)
LANGUAGES = {
    "en": ("English", "English"), "fr": ("Français", "French"), "es": ("Español", "Spanish"),
    "de": ("Deutsch", "German"), "hi": ("हिन्दी", "Hindi"), "zh": ("中文", "Chinese"),
    "ar": ("العربية", "Arabic"), "ru": ("Русский", "Russian"), "ja": ("日本語", "Japanese")
}

# Load a specific translation model pair with caching
@st.cache_resource
def _load_model_pair(source_lang, target_lang):
    try:
        model_name = f"Helsinki-NLP/opus-mt-{source_lang}-{target_lang}"
        tokenizer = MarianTokenizer.from_pretrained(model_name)
        model = MarianMTModel.from_pretrained(model_name)
        return tokenizer, model
    except Exception:
        return None, None  # Checkpoint missing or failed to load; caller treats this pair as unavailable
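
# For example, ("en", "fr") maps to the "Helsinki-NLP/opus-mt-en-fr" checkpoint on the
# Hugging Face Hub; many of the other pairs attempted below have no published Marian
# checkpoint, which is why load failures are tolerated rather than raised.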

# Load all possible model combinations with caching
@st.cache_resource
def _load_all_models():
    models = {}
    for src in LANGUAGES.keys():
        for tgt in LANGUAGES.keys():
            if src != tgt:
                models[(src, tgt)] = _load_model_pair(src, tgt)
    return models

# Eagerly preload every language pair (9 x 8 = 72 combinations). The first run is
# slow and memory-hungry; st.cache_resource keeps subsequent reruns cheap.
all_models = _load_all_models()

# Perform pivot translation through an intermediate language when no direct model exists
def combined_translate(text, source_lang, target_lang):
    try:
        if source_lang == target_lang:  # No translation needed if languages are the same
            return text

        def _run_hop(pair, hop_text, max_length):
            # One tokenize -> generate -> decode pass through a single Marian model
            tokenizer, model = pair
            inputs = tokenizer(hop_text, return_tensors="pt", padding=True, truncation=True, max_length=max_length)
            return tokenizer.decode(model.generate(**inputs)[0], skip_special_tokens=True)

        # Pick ONE intermediate that supports both hops, so the second model actually
        # receives text in the language it expects (using a different intermediate for
        # each hop would feed the second model the wrong language).
        for inter in ["en", "fr", "es", "de", "ru"]:
            if inter in (source_lang, target_lang):
                continue  # A hop through source/target itself is just the (missing) direct pair
            first = all_models.get((source_lang, inter))
            second = all_models.get((inter, target_lang))
            if first and first[0] and first[1] and second and second[0] and second[1]:
                inter_text = _run_hop(first, text, 500)
                translated = _run_hop(second, inter_text, 1000)
                return translated if translated.strip() else text
        return text  # No usable pivot path; fall back to the source text
    except Exception:
        return text  # Suppress errors; return the source text unchanged
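
# Illustrative example (hypothetical inputs, assuming the fr-en and en-hi checkpoints
# both loaded): combined_translate("Bonjour le monde", "fr", "hi") pivots fr -> en -> hi
# and returns the Hindi text, or the original French if either hop is unavailable.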

# Class to handle combined translation
class CombinedModel:
    def __init__(self, source_lang, target_lang, default_tokenizer, default_model):
        self.source_lang = source_lang
        self.target_lang = target_lang
        self.default_tokenizer = default_tokenizer
        self.default_model = default_model

    def generate(self, **kwargs):
        try:
            input_ids = kwargs.get('input_ids')
            # `not input_ids` would raise on multi-element tensors, so test explicitly
            if input_ids is None or input_ids.size(0) == 0:
                return torch.tensor([])
            inputs = self.default_tokenizer.batch_decode(input_ids, skip_special_tokens=True)
            translated_texts = [combined_translate(text, self.source_lang, self.target_lang) for text in inputs]
            encoded_outputs = [self.default_tokenizer.encode(t, return_tensors="pt", truncation=True, max_length=500)[0] for t in translated_texts]
            if not encoded_outputs:
                return torch.tensor([])
            # Sequences differ in length, so pad before batching (torch.stack fails on ragged tensors)
            return torch.nn.utils.rnn.pad_sequence(encoded_outputs, batch_first=True, padding_value=self.default_tokenizer.pad_token_id)
        except Exception:
            return torch.tensor([])  # Suppress errors; return an empty tensor
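
# Note: CombinedModel intentionally mirrors MarianMTModel's generate() interface
# (token ids in, token ids out), so translate() below can use a direct model and a
# pivot model interchangeably without branching.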

# Load an appropriate translation model with caching
@st.cache_resource
def load_model(source_lang, target_lang):
    if source_lang == target_lang:
        return _load_default_model()
    # Prefer a direct model when one exists
    pair = all_models.get((source_lang, target_lang))
    if pair and pair[0] and pair[1]:
        return pair
    # Otherwise fall back to pivot translation. CombinedModel runs both hops
    # internally, so the caller still receives target-language output; returning
    # only a source->intermediate model here would stop the translation halfway.
    default_tokenizer, default_model = _load_default_model()
    return default_tokenizer, CombinedModel(source_lang, target_lang, default_tokenizer, default_model)

# Load the default translation model (English -> Hindi) with caching
@st.cache_resource
def _load_default_model():
    # No try/except here: if even the default model fails to load, let the error propagate
    model_name = "Helsinki-NLP/opus-mt-en-hi"
    tokenizer = MarianTokenizer.from_pretrained(model_name)
    model = MarianMTModel.from_pretrained(model_name)
    return tokenizer, model

# Translate text with caching
@st.cache_data
def translate(text, source_lang, target_lang):
    try:
        if not text:
            return ""
        if source_lang == target_lang:
            return text  # Same language: nothing to do (otherwise the default en-hi model would mangle the text)
        tokenizer, model = load_model(source_lang, target_lang)
        inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=500)
        if inputs['input_ids'].size(0) > 1:
            inputs = {k: v[0].unsqueeze(0) for k, v in inputs.items()}  # translate() handles one text; keep only the first sequence
        with torch.no_grad():
            translated_ids = model.generate(**inputs, max_length=1000 if target_lang in ["hi", "zh", "ja"] else 500, num_beams=4, early_stopping=True)
        result = tokenizer.decode(translated_ids[0], skip_special_tokens=True) if translated_ids.size(0) > 0 else None
        if result and result.strip():
            return result
        else:
            st.warning("This translation is not possible at this moment. Please try another language.")
            return text  # Return source text without additional note
    except Exception:
        st.warning("This translation is not possible at this moment. Please try another language.")
        return text  # Return source text without additional note
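
# --- Usage sketch (an assumption about the surrounding app, not original wiring) ---
# A minimal Streamlit front end for translate(); the widget labels here are made up.
source = st.selectbox("Source language", list(LANGUAGES.keys()),
                      format_func=lambda code: LANGUAGES[code][0])
target = st.selectbox("Target language", list(LANGUAGES.keys()),
                      format_func=lambda code: LANGUAGES[code][0], index=4)
user_text = st.text_area("Text to translate")
if st.button("Translate") and user_text.strip():
    st.write(translate(user_text, source, target))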