import gradio as gr
import torch
from transformers import MBartForConditionalGeneration, MBart50TokenizerFast, pipeline
from langdetect import detect
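# Dependencies: gradio, torch, transformers, langdetect
# (sentencepiece/protobuf are typically also needed by the MBart and NLLB tokenizers)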
device = "cuda" if torch.cuda.is_available() else "cpu"
# Load MBart model and tokenizer
model_name = "facebook/mbart-large-50-many-to-many-mmt"
tokenizer = MBart50TokenizerFast.from_pretrained(model_name)
model = MBartForConditionalGeneration.from_pretrained(model_name).to(device)
# Load NLLB models
nllb_distilled_pipe = pipeline("translation", model="facebook/nllb-200-distilled-600M", device=0 if device == "cuda" else -1)
narsil_nllb_pipe = pipeline("translation", model="Narsil/nllb", device=0 if device == "cuda" else -1)
# Create dictionaries mapping language names to language codes for each model
# MBart languages
MBART_LANGUAGE_OPTIONS = {
"Arabic": "ar_AR", "Czech": "cs_CZ", "German": "de_DE", "English": "en_XX",
"Spanish": "es_XX", "Estonian": "et_EE", "Finnish": "fi_FI", "French": "fr_XX",
"Gujarati": "gu_IN", "Hindi": "hi_IN", "Italian": "it_IT", "Japanese": "ja_XX",
"Kazakh": "kk_KZ", "Korean": "ko_KR", "Lithuanian": "lt_LT", "Latvian": "lv_LV",
"Burmese": "my_MM", "Nepali": "ne_NP", "Dutch": "nl_XX", "Romanian": "ro_RO",
"Russian": "ru_RU", "Sinhala": "si_LK", "Turkish": "tr_TR", "Vietnamese": "vi_VN",
"Chinese": "zh_CN", "Afrikaans": "af_ZA", "Azerbaijani": "az_AZ", "Bengali": "bn_IN",
"Persian": "fa_IR", "Hebrew": "he_IL", "Croatian": "hr_HR", "Indonesian": "id_ID",
"Georgian": "ka_GE", "Khmer": "km_KH", "Macedonian": "mk_MK", "Malayalam": "ml_IN",
"Mongolian": "mn_MN", "Marathi": "mr_IN", "Polish": "pl_PL", "Pashto": "ps_AF",
"Portuguese": "pt_XX", "Swedish": "sv_SE", "Swahili": "sw_KE", "Tamil": "ta_IN",
"Telugu": "te_IN", "Thai": "th_TH", "Tagalog": "tl_XX", "Ukrainian": "uk_UA",
"Urdu": "ur_PK", "Xhosa": "xh_ZA", "Galician": "gl_ES", "Slovene": "sl_SI"
}
# NLLB Distilled language codes
NLLB_DISTILLED_LANGUAGE_OPTIONS = {
"Arabic": "ara_Arab", "Bulgarian": "bul_Cyrl", "Czech": "ces_Latn", "Danish": "dan_Latn",
"German": "deu_Latn", "Greek": "ell_Grek", "English": "eng_Latn", "Finnish": "fin_Latn",
"French": "fra_Latn", "Hindi": "hin_Deva", "Hungarian": "hun_Latn", "Italian": "ita_Latn",
"Japanese": "jpn_Jpan", "Korean": "kor_Hang", "Dutch": "nld_Latn", "Polish": "pol_Latn",
"Portuguese": "por_Latn", "Russian": "rus_Cyrl", "Spanish": "spa_Latn", "Swedish": "swe_Latn",
"Thai": "tha_Thai", "Turkish": "tur_Latn", "Ukrainian": "ukr_Cyrl", "Vietnamese": "vie_Latn",
"Chinese": "zho_Hans"
}
# Narsil/nllb language codes
NARSIL_NLLB_LANGUAGE_OPTIONS = {
"Amharic": "amh_Ethi", "Arabic": "ara_Arab", "Bengali": "ben_Beng", "Bulgarian": "bul_Cyrl",
"Catalan": "cat_Latn", "Czech": "ces_Latn", "Danish": "dan_Latn", "German": "deu_Latn",
"Greek": "ell_Grek", "English": "eng_Latn", "Finnish": "fin_Latn", "French": "fra_Latn",
"Hebrew": "heb_Hebr", "Hindi": "hin_Deva", "Hungarian": "hun_Latn", "Italian": "ita_Latn",
"Japanese": "jpn_Jpan", "Korean": "kor_Hang", "Marathi": "mar_Deva", "Dutch": "nld_Latn",
"Norwegian": "nob_Latn", "Polish": "pol_Latn", "Portuguese": "por_Latn", "Romanian": "ron_Latn",
"Russian": "rus_Cyrl", "Spanish": "spa_Latn", "Swedish": "swe_Latn", "Tamil": "tam_Taml",
"Telugu": "tel_Telu", "Thai": "tha_Thai", "Turkish": "tur_Latn", "Ukrainian": "ukr_Cyrl",
"Urdu": "urd_Arab", "Vietnamese": "vie_Latn", "Chinese": "zho_Hans"
}
# Map from langdetect codes to model-specific codes
LANGDETECT_TO_MBART = {
'ar': 'ar_AR', 'cs': 'cs_CZ', 'de': 'de_DE', 'en': 'en_XX',
'es': 'es_XX', 'et': 'et_EE', 'fi': 'fi_FI', 'fr': 'fr_XX',
'gu': 'gu_IN', 'hi': 'hi_IN', 'it': 'it_IT', 'ja': 'ja_XX',
'kk': 'kk_KZ', 'ko': 'ko_KR', 'lt': 'lt_LT', 'lv': 'lv_LV',
'my': 'my_MM', 'ne': 'ne_NP', 'nl': 'nl_XX', 'ro': 'ro_RO',
'ru': 'ru_RU', 'si': 'si_LK', 'tr': 'tr_TR', 'vi': 'vi_VN',
'zh-cn': 'zh_CN', 'zh': 'zh_CN', 'af': 'af_ZA', 'az': 'az_AZ',
'bn': 'bn_IN', 'fa': 'fa_IR', 'he': 'he_IL', 'hr': 'hr_HR',
'id': 'id_ID', 'ka': 'ka_GE', 'km': 'km_KH', 'mk': 'mk_MK',
'ml': 'ml_IN', 'mn': 'mn_MN', 'mr': 'mr_IN', 'pl': 'pl_PL',
'ps': 'ps_AF', 'pt': 'pt_XX', 'sv': 'sv_SE', 'sw': 'sw_KE',
'ta': 'ta_IN', 'te': 'te_IN', 'th': 'th_TH', 'tl': 'tl_XX',
'uk': 'uk_UA', 'ur': 'ur_PK', 'xh': 'xh_ZA', 'gl': 'gl_ES',
'sl': 'sl_SI'
}
# Create mappings from langdetect codes to NLLB codes
LANGDETECT_TO_NLLB_DISTILLED = {
'ar': 'ara_Arab', 'bg': 'bul_Cyrl', 'cs': 'ces_Latn', 'da': 'dan_Latn',
'de': 'deu_Latn', 'el': 'ell_Grek', 'en': 'eng_Latn', 'fi': 'fin_Latn',
'fr': 'fra_Latn', 'hi': 'hin_Deva', 'hu': 'hun_Latn', 'it': 'ita_Latn',
'ja': 'jpn_Jpan', 'ko': 'kor_Hang', 'nl': 'nld_Latn', 'pl': 'pol_Latn',
'pt': 'por_Latn', 'ru': 'rus_Cyrl', 'es': 'spa_Latn', 'sv': 'swe_Latn',
'th': 'tha_Thai', 'tr': 'tur_Latn', 'uk': 'ukr_Cyrl', 'vi': 'vie_Latn',
'zh': 'zho_Hans', 'zh-cn': 'zho_Hans'
}
LANGDETECT_TO_NARSIL_NLLB = {
'am': 'amh_Ethi', 'ar': 'ara_Arab', 'bn': 'ben_Beng', 'bg': 'bul_Cyrl',
'ca': 'cat_Latn', 'cs': 'ces_Latn', 'da': 'dan_Latn', 'de': 'deu_Latn',
'el': 'ell_Grek', 'en': 'eng_Latn', 'fi': 'fin_Latn', 'fr': 'fra_Latn',
'he': 'heb_Hebr', 'hi': 'hin_Deva', 'hu': 'hun_Latn', 'it': 'ita_Latn',
'ja': 'jpn_Jpan', 'ko': 'kor_Hang', 'mr': 'mar_Deva', 'nl': 'nld_Latn',
'no': 'nob_Latn', 'pl': 'pol_Latn', 'pt': 'por_Latn', 'ro': 'ron_Latn',
'ru': 'rus_Cyrl', 'es': 'spa_Latn', 'sv': 'swe_Latn', 'ta': 'tam_Taml',
'te': 'tel_Telu', 'th': 'tha_Thai', 'tr': 'tur_Latn', 'uk': 'ukr_Cyrl',
'ur': 'urd_Arab', 'vi': 'vie_Latn', 'zh': 'zho_Hans', 'zh-cn': 'zho_Hans'
}
def translate_mbart(text, source_lang, target_lang):
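    """Translate `text` with MBart-50.

    If `source_lang` is "Auto-detect", langdetect is used to guess the source.
    Returns a formatted string with the source language and the translation,
    or an error message if a language is unsupported or translation fails.
    """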
if not text:
return "Please enter text to translate."
# If source language is not specified, detect it
if source_lang == "Auto-detect":
try:
detected_lang = detect(text)
if detected_lang in LANGDETECT_TO_MBART:
src_lang_code = LANGDETECT_TO_MBART[detected_lang]
source_lang_display = f"Auto-detected: {[k for k, v in MBART_LANGUAGE_OPTIONS.items() if v == src_lang_code][0]}"
else:
return f"Detected language '{detected_lang}' is not supported by MBart."
        except Exception:
return "Could not detect language. Please select a source language manually."
else:
if source_lang not in MBART_LANGUAGE_OPTIONS:
return f"Language '{source_lang}' is not supported by MBart."
src_lang_code = MBART_LANGUAGE_OPTIONS[source_lang]
source_lang_display = source_lang
if target_lang not in MBART_LANGUAGE_OPTIONS:
return f"Target language '{target_lang}' is not supported by MBart."
tgt_lang_code = MBART_LANGUAGE_OPTIONS[target_lang]
# Set the source language
tokenizer.src_lang = src_lang_code
try:
# Tokenize the input text
encoded = tokenizer(text, return_tensors="pt").to(device)
# Generate translation
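        # forced_bos_token_id makes the decoder start with the target-language token,
        # which is how MBart-50 selects the output language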
generated_tokens = model.generate(
**encoded,
            forced_bos_token_id=tokenizer.convert_tokens_to_ids(tgt_lang_code),  # works across transformers versions; lang_code_to_id was removed in newer releases
max_length=1024,
)
# Decode the generated tokens
translation = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0]
return f"Source Language: {source_lang_display}\nTranslation ({target_lang}):\n\n{translation}"
except Exception as e:
return f"An error occurred during MBart translation: {str(e)}"
def translate_nllb_distilled(text, source_lang, target_lang):
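    """Translate `text` with the NLLB-200 distilled 600M pipeline.

    Auto-detects the source language via langdetect when requested; returns
    the formatted translation or an error message for unsupported languages.
    """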
if not text:
return "Please enter text to translate."
# If source language is not specified, detect it
if source_lang == "Auto-detect":
try:
detected_lang = detect(text)
if detected_lang in LANGDETECT_TO_NLLB_DISTILLED:
src_lang_code = LANGDETECT_TO_NLLB_DISTILLED[detected_lang]
source_lang_display = f"Auto-detected: {[k for k, v in NLLB_DISTILLED_LANGUAGE_OPTIONS.items() if v == src_lang_code][0]}"
else:
return f"Detected language '{detected_lang}' is not supported by NLLB Distilled."
        except Exception:
return "Could not detect language. Please select a source language manually."
else:
if source_lang not in NLLB_DISTILLED_LANGUAGE_OPTIONS:
return f"Language '{source_lang}' is not supported by NLLB Distilled."
src_lang_code = NLLB_DISTILLED_LANGUAGE_OPTIONS[source_lang]
source_lang_display = source_lang
if target_lang not in NLLB_DISTILLED_LANGUAGE_OPTIONS:
return f"Target language '{target_lang}' is not supported by NLLB Distilled."
tgt_lang_code = NLLB_DISTILLED_LANGUAGE_OPTIONS[target_lang]
try:
# Translate using NLLB Distilled pipeline
result = nllb_distilled_pipe(text, src_lang=src_lang_code, tgt_lang=tgt_lang_code)
translation = result[0]['translation_text']
return f"Source Language: {source_lang_display}\nTranslation ({target_lang}):\n\n{translation}"
except Exception as e:
return f"An error occurred during NLLB Distilled translation: {str(e)}"
def translate_narsil_nllb(text, source_lang, target_lang):
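    """Translate `text` with the Narsil/nllb pipeline.

    Auto-detects the source language via langdetect when requested; returns
    the formatted translation or an error message for unsupported languages.
    """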
if not text:
return "Please enter text to translate."
# If source language is not specified, detect it
if source_lang == "Auto-detect":
try:
detected_lang = detect(text)
if detected_lang in LANGDETECT_TO_NARSIL_NLLB:
src_lang_code = LANGDETECT_TO_NARSIL_NLLB[detected_lang]
source_lang_display = f"Auto-detected: {[k for k, v in NARSIL_NLLB_LANGUAGE_OPTIONS.items() if v == src_lang_code][0]}"
else:
return f"Detected language '{detected_lang}' is not supported by Narsil/NLLB."
        except Exception:
return "Could not detect language. Please select a source language manually."
else:
if source_lang not in NARSIL_NLLB_LANGUAGE_OPTIONS:
return f"Language '{source_lang}' is not supported by Narsil/NLLB."
src_lang_code = NARSIL_NLLB_LANGUAGE_OPTIONS[source_lang]
source_lang_display = source_lang
if target_lang not in NARSIL_NLLB_LANGUAGE_OPTIONS:
return f"Target language '{target_lang}' is not supported by Narsil/NLLB."
tgt_lang_code = NARSIL_NLLB_LANGUAGE_OPTIONS[target_lang]
try:
# Translate using Narsil/NLLB pipeline
result = narsil_nllb_pipe(text, src_lang=src_lang_code, tgt_lang=tgt_lang_code)
translation = result[0]['translation_text']
return f"Source Language: {source_lang_display}\nTranslation ({target_lang}):\n\n{translation}"
except Exception as e:
return f"An error occurred during Narsil/NLLB translation: {str(e)}"
def translate_all(text, source_lang, target_lang):
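    """Run the same request through all three models.

    Returns the MBart, NLLB distilled, and Narsil/NLLB results in that order,
    matching the three output textboxes in the interface.
    """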
# Call all translation functions
mbart_result = translate_mbart(text, source_lang, target_lang)
nllb_distilled_result = translate_nllb_distilled(text, source_lang, target_lang)
narsil_nllb_result = translate_narsil_nllb(text, source_lang, target_lang)
return mbart_result, nllb_distilled_result, narsil_nllb_result
# Get all languages supported by at least one model
all_languages = sorted(list(set(MBART_LANGUAGE_OPTIONS.keys()) |
set(NLLB_DISTILLED_LANGUAGE_OPTIONS.keys()) |
set(NARSIL_NLLB_LANGUAGE_OPTIONS.keys())))
# Create the Gradio interface
source_languages = ["Auto-detect"] + all_languages
target_languages = all_languages
with gr.Blocks(title="Multi-Model Translation") as app:
gr.Markdown("# Multilingual Translation System")
gr.Markdown("Enter text to translate. If source language is not specified, it will be auto-detected.")
with gr.Row():
with gr.Column():
input_text = gr.Textbox(label="Text to translate", lines=10, placeholder="Enter text here...")
with gr.Row():
source_lang = gr.Dropdown(
choices=source_languages,
value="Auto-detect",
label="Source Language"
)
target_lang = gr.Dropdown(
choices=target_languages,
value="English",
label="Target Language"
)
translate_btn = gr.Button("Translate")
with gr.Row():
with gr.Column():
gr.Markdown("### MBart Translation")
mbart_output = gr.Textbox(label="MBart Translation Output", lines=10)
with gr.Column():
gr.Markdown("### NLLB Distilled Translation")
nllb_distilled_output = gr.Textbox(label="NLLB Distilled Translation Output", lines=10)
with gr.Column():
gr.Markdown("### Narsil/NLLB Translation")
narsil_nllb_output = gr.Textbox(label="Narsil/NLLB Translation Output", lines=10)
translate_btn.click(
fn=translate_all,
inputs=[input_text, source_lang, target_lang],
outputs=[mbart_output, nllb_distilled_output, narsil_nllb_output]
)
gr.Markdown("### Supported Languages")
gr.Markdown("Note: Not all languages are supported by all models. If a language is not supported by a model, it will show an error message.")
with gr.Accordion("MBart Supported Languages", open=False):
gr.Markdown(", ".join(sorted(MBART_LANGUAGE_OPTIONS.keys())))
with gr.Accordion("NLLB Distilled Supported Languages", open=False):
gr.Markdown(", ".join(sorted(NLLB_DISTILLED_LANGUAGE_OPTIONS.keys())))
with gr.Accordion("Narsil/NLLB Supported Languages", open=False):
gr.Markdown(", ".join(sorted(NARSIL_NLLB_LANGUAGE_OPTIONS.keys())))
# Launch the app
if __name__ == "__main__":
app.launch(debug=True)