|
import gradio as gr
import torch
from langdetect import detect
from transformers import (
    MBart50TokenizerFast,
    MBartForConditionalGeneration,
    MBartTokenizer,
    pipeline,
)
|
|
|
# Select the torch device string once; the model and pipelines below reuse it.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
|
|
|
|
|
# mBART-50 many-to-many checkpoint shared by the tokenizer and the model.
model_name = "facebook/mbart-large-50-many-to-many-mmt"

# This is an mBART-50 checkpoint, so it must be paired with the MBart50
# tokenizer. The plain MBartTokenizer only defines the original 25 mBART
# language codes, so looking up the additional mBART-50 codes (e.g. "af_ZA")
# in ``lang_code_to_id`` fails, and it formats inputs for mBART-25.
tokenizer = MBart50TokenizerFast.from_pretrained(model_name)

# Load the seq2seq weights and move them to the selected device.
model = MBartForConditionalGeneration.from_pretrained(model_name).to(device)
|
|
|
|
|
# Two NLLB translation pipelines. Each receives device index 0 when CUDA was
# selected above and -1 otherwise.
nllb_distilled_pipe = pipeline(
    "translation",
    model="facebook/nllb-200-distilled-600M",
    device=-1 if device != "cuda" else 0,
)

narsil_nllb_pipe = pipeline(
    "translation",
    model="Narsil/nllb",
    device=-1 if device != "cuda" else 0,
)
|
|
|
|
|
|
|
# Display name -> mBART language code. Keys are what the UI dropdowns show;
# values are assigned to ``tokenizer.src_lang`` and looked up in
# ``tokenizer.lang_code_to_id`` by translate_mbart.
MBART_LANGUAGE_OPTIONS = {
    "Arabic": "ar_AR",
    "Czech": "cs_CZ",
    "German": "de_DE",
    "English": "en_XX",
    "Spanish": "es_XX",
    "Estonian": "et_EE",
    "Finnish": "fi_FI",
    "French": "fr_XX",
    "Gujarati": "gu_IN",
    "Hindi": "hi_IN",
    "Italian": "it_IT",
    "Japanese": "ja_XX",
    "Kazakh": "kk_KZ",
    "Korean": "ko_KR",
    "Lithuanian": "lt_LT",
    "Latvian": "lv_LV",
    "Burmese": "my_MM",
    "Nepali": "ne_NP",
    "Dutch": "nl_XX",
    "Romanian": "ro_RO",
    "Russian": "ru_RU",
    "Sinhala": "si_LK",
    "Turkish": "tr_TR",
    "Vietnamese": "vi_VN",
    "Chinese": "zh_CN",
    "Afrikaans": "af_ZA",
    "Azerbaijani": "az_AZ",
    "Bengali": "bn_IN",
    "Persian": "fa_IR",
    "Hebrew": "he_IL",
    "Croatian": "hr_HR",
    "Indonesian": "id_ID",
    "Georgian": "ka_GE",
    "Khmer": "km_KH",
    "Macedonian": "mk_MK",
    "Malayalam": "ml_IN",
    "Mongolian": "mn_MN",
    "Marathi": "mr_IN",
    "Polish": "pl_PL",
    "Pashto": "ps_AF",
    "Portuguese": "pt_XX",
    "Swedish": "sv_SE",
    "Swahili": "sw_KE",
    "Tamil": "ta_IN",
    "Telugu": "te_IN",
    "Thai": "th_TH",
    "Tagalog": "tl_XX",
    "Ukrainian": "uk_UA",
    "Urdu": "ur_PK",
    "Xhosa": "xh_ZA",
    "Galician": "gl_ES",
    "Slovene": "sl_SI",
}
|
|
|
|
|
# Display name -> NLLB language code for the distilled NLLB pipeline.
# The values are passed as ``src_lang`` / ``tgt_lang`` to the pipeline.
# NLLB uses FLORES-200 codes; Arabic is "arb_Arab" (there is no "ara_Arab").
NLLB_DISTILLED_LANGUAGE_OPTIONS = {
    "Arabic": "arb_Arab", "Bulgarian": "bul_Cyrl", "Czech": "ces_Latn", "Danish": "dan_Latn",
    "German": "deu_Latn", "Greek": "ell_Grek", "English": "eng_Latn", "Finnish": "fin_Latn",
    "French": "fra_Latn", "Hindi": "hin_Deva", "Hungarian": "hun_Latn", "Italian": "ita_Latn",
    "Japanese": "jpn_Jpan", "Korean": "kor_Hang", "Dutch": "nld_Latn", "Polish": "pol_Latn",
    "Portuguese": "por_Latn", "Russian": "rus_Cyrl", "Spanish": "spa_Latn", "Swedish": "swe_Latn",
    "Thai": "tha_Thai", "Turkish": "tur_Latn", "Ukrainian": "ukr_Cyrl", "Vietnamese": "vie_Latn",
    "Chinese": "zho_Hans",
}

# Display name -> NLLB language code for the Narsil/nllb pipeline.
NARSIL_NLLB_LANGUAGE_OPTIONS = {
    "Amharic": "amh_Ethi", "Arabic": "arb_Arab", "Bengali": "ben_Beng", "Bulgarian": "bul_Cyrl",
    "Catalan": "cat_Latn", "Czech": "ces_Latn", "Danish": "dan_Latn", "German": "deu_Latn",
    "Greek": "ell_Grek", "English": "eng_Latn", "Finnish": "fin_Latn", "French": "fra_Latn",
    "Hebrew": "heb_Hebr", "Hindi": "hin_Deva", "Hungarian": "hun_Latn", "Italian": "ita_Latn",
    "Japanese": "jpn_Jpan", "Korean": "kor_Hang", "Marathi": "mar_Deva", "Dutch": "nld_Latn",
    "Norwegian": "nob_Latn", "Polish": "pol_Latn", "Portuguese": "por_Latn", "Romanian": "ron_Latn",
    "Russian": "rus_Cyrl", "Spanish": "spa_Latn", "Swedish": "swe_Latn", "Tamil": "tam_Taml",
    "Telugu": "tel_Telu", "Thai": "tha_Thai", "Turkish": "tur_Latn", "Ukrainian": "ukr_Cyrl",
    "Urdu": "urd_Arab", "Vietnamese": "vie_Latn", "Chinese": "zho_Hans",
}

# langdetect result -> mBART language code, used when the source language is
# auto-detected. Both "zh-cn" and "zh" map to the single mBART Chinese code.
LANGDETECT_TO_MBART = {
    "ar": "ar_AR", "cs": "cs_CZ", "de": "de_DE", "en": "en_XX",
    "es": "es_XX", "et": "et_EE", "fi": "fi_FI", "fr": "fr_XX",
    "gu": "gu_IN", "hi": "hi_IN", "it": "it_IT", "ja": "ja_XX",
    "kk": "kk_KZ", "ko": "ko_KR", "lt": "lt_LT", "lv": "lv_LV",
    "my": "my_MM", "ne": "ne_NP", "nl": "nl_XX", "ro": "ro_RO",
    "ru": "ru_RU", "si": "si_LK", "tr": "tr_TR", "vi": "vi_VN",
    "zh-cn": "zh_CN", "zh": "zh_CN", "af": "af_ZA", "az": "az_AZ",
    "bn": "bn_IN", "fa": "fa_IR", "he": "he_IL", "hr": "hr_HR",
    "id": "id_ID", "ka": "ka_GE", "km": "km_KH", "mk": "mk_MK",
    "ml": "ml_IN", "mn": "mn_MN", "mr": "mr_IN", "pl": "pl_PL",
    "ps": "ps_AF", "pt": "pt_XX", "sv": "sv_SE", "sw": "sw_KE",
    "ta": "ta_IN", "te": "te_IN", "th": "th_TH", "tl": "tl_XX",
    "uk": "uk_UA", "ur": "ur_PK", "xh": "xh_ZA", "gl": "gl_ES",
    "sl": "sl_SI",
}

# langdetect result -> NLLB code for the distilled pipeline. Every value must
# also appear in NLLB_DISTILLED_LANGUAGE_OPTIONS, because the translate
# function maps the code back to a display name.
LANGDETECT_TO_NLLB_DISTILLED = {
    "ar": "arb_Arab", "bg": "bul_Cyrl", "cs": "ces_Latn", "da": "dan_Latn",
    "de": "deu_Latn", "el": "ell_Grek", "en": "eng_Latn", "fi": "fin_Latn",
    "fr": "fra_Latn", "hi": "hin_Deva", "hu": "hun_Latn", "it": "ita_Latn",
    "ja": "jpn_Jpan", "ko": "kor_Hang", "nl": "nld_Latn", "pl": "pol_Latn",
    "pt": "por_Latn", "ru": "rus_Cyrl", "es": "spa_Latn", "sv": "swe_Latn",
    "th": "tha_Thai", "tr": "tur_Latn", "uk": "ukr_Cyrl", "vi": "vie_Latn",
    "zh": "zho_Hans", "zh-cn": "zho_Hans",
}

# langdetect result -> NLLB code for the Narsil/nllb pipeline. Every value
# must also appear in NARSIL_NLLB_LANGUAGE_OPTIONS for the same reason.
LANGDETECT_TO_NARSIL_NLLB = {
    "am": "amh_Ethi", "ar": "arb_Arab", "bn": "ben_Beng", "bg": "bul_Cyrl",
    "ca": "cat_Latn", "cs": "ces_Latn", "da": "dan_Latn", "de": "deu_Latn",
    "el": "ell_Grek", "en": "eng_Latn", "fi": "fin_Latn", "fr": "fra_Latn",
    "he": "heb_Hebr", "hi": "hin_Deva", "hu": "hun_Latn", "it": "ita_Latn",
    "ja": "jpn_Jpan", "ko": "kor_Hang", "mr": "mar_Deva", "nl": "nld_Latn",
    "no": "nob_Latn", "pl": "pol_Latn", "pt": "por_Latn", "ro": "ron_Latn",
    "ru": "rus_Cyrl", "es": "spa_Latn", "sv": "swe_Latn", "ta": "tam_Taml",
    "te": "tel_Telu", "th": "tha_Thai", "tr": "tur_Latn", "uk": "ukr_Cyrl",
    "ur": "urd_Arab", "vi": "vie_Latn", "zh": "zho_Hans", "zh-cn": "zho_Hans",
}
|
|
|
def translate_mbart(text, source_lang, target_lang):
    """Translate ``text`` with the module-level MBart tokenizer and model.

    Args:
        text: The text to translate.
        source_lang: Display name of the source language (a key of
            ``MBART_LANGUAGE_OPTIONS``), or ``"Auto-detect"`` to have
            ``langdetect`` choose it.
        target_lang: Display name of the target language (a key of
            ``MBART_LANGUAGE_OPTIONS``).

    Returns:
        A string: the formatted translation on success, otherwise a
        human-readable message explaining why nothing was translated.
        Failures are reported in the returned string rather than raised.
    """
    # Empty or whitespace-only input has nothing to translate (or detect).
    if not text or not text.strip():
        return "Please enter text to translate."

    if source_lang == "Auto-detect":
        # Guard only the detector call. ``except Exception`` (rather than a
        # bare ``except``) lets KeyboardInterrupt / SystemExit propagate.
        try:
            detected_lang = detect(text)
        except Exception:
            return "Could not detect language. Please select a source language manually."

        if detected_lang not in LANGDETECT_TO_MBART:
            return f"Detected language '{detected_lang}' is not supported by MBart."

        src_lang_code = LANGDETECT_TO_MBART[detected_lang]
        # Map the code back to its display name; fall back to the raw
        # detected code if the two tables ever disagree.
        detected_name = next(
            (name for name, code in MBART_LANGUAGE_OPTIONS.items() if code == src_lang_code),
            detected_lang,
        )
        source_lang_display = f"Auto-detected: {detected_name}"
    else:
        if source_lang not in MBART_LANGUAGE_OPTIONS:
            return f"Language '{source_lang}' is not supported by MBart."
        src_lang_code = MBART_LANGUAGE_OPTIONS[source_lang]
        source_lang_display = source_lang

    if target_lang not in MBART_LANGUAGE_OPTIONS:
        return f"Target language '{target_lang}' is not supported by MBart."
    tgt_lang_code = MBART_LANGUAGE_OPTIONS[target_lang]

    try:
        # Setting the source language is inside the try so a code the
        # tokenizer rejects is reported like any other translation error.
        # NOTE: this assigns to the shared module-level tokenizer.
        tokenizer.src_lang = src_lang_code

        encoded = tokenizer(text, return_tensors="pt").to(device)

        # Force the first generated token to be the target-language code.
        generated_tokens = model.generate(
            **encoded,
            forced_bos_token_id=tokenizer.lang_code_to_id[tgt_lang_code],
            max_length=1024,
        )

        translation = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0]
    except Exception as e:
        return f"An error occurred during MBart translation: {str(e)}"

    return f"Source Language: {source_lang_display}\nTranslation ({target_lang}):\n\n{translation}"
|
|
|
def translate_nllb_distilled(text, source_lang, target_lang):
    """Translate ``text`` with the module-level NLLB Distilled pipeline.

    Args:
        text: The text to translate.
        source_lang: Display name of the source language (a key of
            ``NLLB_DISTILLED_LANGUAGE_OPTIONS``), or ``"Auto-detect"`` to
            have ``langdetect`` choose it.
        target_lang: Display name of the target language (a key of
            ``NLLB_DISTILLED_LANGUAGE_OPTIONS``).

    Returns:
        A string: the formatted translation on success, otherwise a
        human-readable message explaining why nothing was translated.
        Failures are reported in the returned string rather than raised.
    """
    # Empty or whitespace-only input has nothing to translate (or detect).
    if not text or not text.strip():
        return "Please enter text to translate."

    if source_lang == "Auto-detect":
        # Guard only the detector call. ``except Exception`` (rather than a
        # bare ``except``) lets KeyboardInterrupt / SystemExit propagate.
        try:
            detected_lang = detect(text)
        except Exception:
            return "Could not detect language. Please select a source language manually."

        if detected_lang not in LANGDETECT_TO_NLLB_DISTILLED:
            return f"Detected language '{detected_lang}' is not supported by NLLB Distilled."

        src_lang_code = LANGDETECT_TO_NLLB_DISTILLED[detected_lang]
        # Map the code back to its display name; fall back to the raw
        # detected code if the two tables ever disagree.
        detected_name = next(
            (
                name
                for name, code in NLLB_DISTILLED_LANGUAGE_OPTIONS.items()
                if code == src_lang_code
            ),
            detected_lang,
        )
        source_lang_display = f"Auto-detected: {detected_name}"
    else:
        if source_lang not in NLLB_DISTILLED_LANGUAGE_OPTIONS:
            return f"Language '{source_lang}' is not supported by NLLB Distilled."
        src_lang_code = NLLB_DISTILLED_LANGUAGE_OPTIONS[source_lang]
        source_lang_display = source_lang

    if target_lang not in NLLB_DISTILLED_LANGUAGE_OPTIONS:
        return f"Target language '{target_lang}' is not supported by NLLB Distilled."
    tgt_lang_code = NLLB_DISTILLED_LANGUAGE_OPTIONS[target_lang]

    try:
        result = nllb_distilled_pipe(text, src_lang=src_lang_code, tgt_lang=tgt_lang_code)
        # The translated string is read from the first result entry.
        translation = result[0]['translation_text']
    except Exception as e:
        return f"An error occurred during NLLB Distilled translation: {str(e)}"

    return f"Source Language: {source_lang_display}\nTranslation ({target_lang}):\n\n{translation}"
|
|
|
def translate_narsil_nllb(text, source_lang, target_lang):
    """Translate ``text`` with the module-level Narsil/NLLB pipeline.

    Args:
        text: The text to translate.
        source_lang: Display name of the source language (a key of
            ``NARSIL_NLLB_LANGUAGE_OPTIONS``), or ``"Auto-detect"`` to have
            ``langdetect`` choose it.
        target_lang: Display name of the target language (a key of
            ``NARSIL_NLLB_LANGUAGE_OPTIONS``).

    Returns:
        A string: the formatted translation on success, otherwise a
        human-readable message explaining why nothing was translated.
        Failures are reported in the returned string rather than raised.
    """
    # Empty or whitespace-only input has nothing to translate (or detect).
    if not text or not text.strip():
        return "Please enter text to translate."

    if source_lang == "Auto-detect":
        # Guard only the detector call. ``except Exception`` (rather than a
        # bare ``except``) lets KeyboardInterrupt / SystemExit propagate.
        try:
            detected_lang = detect(text)
        except Exception:
            return "Could not detect language. Please select a source language manually."

        if detected_lang not in LANGDETECT_TO_NARSIL_NLLB:
            return f"Detected language '{detected_lang}' is not supported by Narsil/NLLB."

        src_lang_code = LANGDETECT_TO_NARSIL_NLLB[detected_lang]
        # Map the code back to its display name; fall back to the raw
        # detected code if the two tables ever disagree.
        detected_name = next(
            (
                name
                for name, code in NARSIL_NLLB_LANGUAGE_OPTIONS.items()
                if code == src_lang_code
            ),
            detected_lang,
        )
        source_lang_display = f"Auto-detected: {detected_name}"
    else:
        if source_lang not in NARSIL_NLLB_LANGUAGE_OPTIONS:
            return f"Language '{source_lang}' is not supported by Narsil/NLLB."
        src_lang_code = NARSIL_NLLB_LANGUAGE_OPTIONS[source_lang]
        source_lang_display = source_lang

    if target_lang not in NARSIL_NLLB_LANGUAGE_OPTIONS:
        return f"Target language '{target_lang}' is not supported by Narsil/NLLB."
    tgt_lang_code = NARSIL_NLLB_LANGUAGE_OPTIONS[target_lang]

    try:
        result = narsil_nllb_pipe(text, src_lang=src_lang_code, tgt_lang=tgt_lang_code)
        # The translated string is read from the first result entry.
        translation = result[0]['translation_text']
    except Exception as e:
        return f"An error occurred during Narsil/NLLB translation: {str(e)}"

    return f"Source Language: {source_lang_display}\nTranslation ({target_lang}):\n\n{translation}"
|
|
|
def translate_all(text, source_lang, target_lang):
    """Run one translation request through all three back-ends.

    Returns:
        A 3-tuple of result strings in the order MBart, NLLB Distilled,
        Narsil/NLLB, matching the three output boxes wired to this callback.
    """
    backends = (translate_mbart, translate_nllb_distilled, translate_narsil_nllb)
    return tuple(run(text, source_lang, target_lang) for run in backends)
|
|
|
|
|
# Alphabetical union of every display name known to at least one model.
all_languages = sorted(
    {
        *MBART_LANGUAGE_OPTIONS,
        *NLLB_DISTILLED_LANGUAGE_OPTIONS,
        *NARSIL_NLLB_LANGUAGE_OPTIONS,
    }
)

# The source dropdown additionally offers automatic detection as its first
# entry; the target dropdown shares the plain language list.
source_languages = ["Auto-detect", *all_languages]
target_languages = all_languages
|
|
|
# Gradio UI: one input panel, three side-by-side result panels (one per
# model) and a collapsible list of each model's supported languages.
# NOTE(review): the source's indentation was flattened, so the nesting of
# rows/columns below is a reconstruction — confirm against the intended layout.
with gr.Blocks(title="Multi-Model Translation") as app:
    gr.Markdown("# Multilingual Translation System")
    gr.Markdown("Enter text to translate. If source language is not specified, it will be auto-detected.")

    # Input panel: text box, language selectors and the trigger button.
    with gr.Row():
        with gr.Column():
            input_text = gr.Textbox(
                label="Text to translate",
                lines=10,
                placeholder="Enter text here...",
            )

            with gr.Row():
                source_lang = gr.Dropdown(
                    label="Source Language",
                    choices=source_languages,
                    value="Auto-detect",
                )
                target_lang = gr.Dropdown(
                    label="Target Language",
                    choices=target_languages,
                    value="English",
                )

            translate_btn = gr.Button("Translate")

    # Result panels, created in the same order translate_all returns results.
    with gr.Row():
        _output_boxes = []
        for _model_title in ("MBart", "NLLB Distilled", "Narsil/NLLB"):
            with gr.Column():
                gr.Markdown(f"### {_model_title} Translation")
                _output_boxes.append(
                    gr.Textbox(label=f"{_model_title} Translation Output", lines=10)
                )
        mbart_output, nllb_distilled_output, narsil_nllb_output = _output_boxes

    # One click fans the request out to all three models.
    translate_btn.click(
        fn=translate_all,
        inputs=[input_text, source_lang, target_lang],
        outputs=[mbart_output, nllb_distilled_output, narsil_nllb_output],
    )

    gr.Markdown("### Supported Languages")
    gr.Markdown("Note: Not all languages are supported by all models. If a language is not supported by a model, it will show an error message.")

    # Collapsed, alphabetically sorted language list per model.
    for _accordion_title, _language_table in (
        ("MBart Supported Languages", MBART_LANGUAGE_OPTIONS),
        ("NLLB Distilled Supported Languages", NLLB_DISTILLED_LANGUAGE_OPTIONS),
        ("Narsil/NLLB Supported Languages", NARSIL_NLLB_LANGUAGE_OPTIONS),
    ):
        with gr.Accordion(_accordion_title, open=False):
            gr.Markdown(", ".join(sorted(_language_table)))
|
|
|
|
|
# Launch the Gradio app (with debug=True) only when this file is run as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    app.launch(debug=True)