import gradio as gr
import torch
from transformers import MBartForConditionalGeneration, MBart50TokenizerFast, pipeline
from langdetect import detect

device = "cuda" if torch.cuda.is_available() else "cpu"

# Load the MBart-50 many-to-many model and its tokenizer
# (MBart50TokenizerFast is the tokenizer for this checkpoint and exposes the
# full set of 50 language codes used below)
model_name = "facebook/mbart-large-50-many-to-many-mmt"
tokenizer = MBart50TokenizerFast.from_pretrained(model_name)
model = MBartForConditionalGeneration.from_pretrained(model_name).to(device)

# Load the NLLB models as translation pipelines
nllb_distilled_pipe = pipeline(
    "translation",
    model="facebook/nllb-200-distilled-600M",
    device=0 if device == "cuda" else -1,
)
narsil_nllb_pipe = pipeline(
    "translation",
    model="Narsil/nllb",
    device=0 if device == "cuda" else -1,
)

# Create dictionaries mapping language names to language codes for each model

# MBart languages
MBART_LANGUAGE_OPTIONS = {
    "Arabic": "ar_AR", "Czech": "cs_CZ", "German": "de_DE", "English": "en_XX",
    "Spanish": "es_XX", "Estonian": "et_EE", "Finnish": "fi_FI", "French": "fr_XX",
    "Gujarati": "gu_IN", "Hindi": "hi_IN", "Italian": "it_IT", "Japanese": "ja_XX",
    "Kazakh": "kk_KZ", "Korean": "ko_KR", "Lithuanian": "lt_LT", "Latvian": "lv_LV",
    "Burmese": "my_MM", "Nepali": "ne_NP", "Dutch": "nl_XX", "Romanian": "ro_RO",
    "Russian": "ru_RU", "Sinhala": "si_LK", "Turkish": "tr_TR", "Vietnamese": "vi_VN",
    "Chinese": "zh_CN", "Afrikaans": "af_ZA", "Azerbaijani": "az_AZ", "Bengali": "bn_IN",
    "Persian": "fa_IR", "Hebrew": "he_IL", "Croatian": "hr_HR", "Indonesian": "id_ID",
    "Georgian": "ka_GE", "Khmer": "km_KH", "Macedonian": "mk_MK", "Malayalam": "ml_IN",
    "Mongolian": "mn_MN", "Marathi": "mr_IN", "Polish": "pl_PL", "Pashto": "ps_AF",
    "Portuguese": "pt_XX", "Swedish": "sv_SE", "Swahili": "sw_KE", "Tamil": "ta_IN",
    "Telugu": "te_IN", "Thai": "th_TH", "Tagalog": "tl_XX", "Ukrainian": "uk_UA",
    "Urdu": "ur_PK", "Xhosa": "xh_ZA", "Galician": "gl_ES", "Slovene": "sl_SI",
}

# NLLB Distilled language codes
NLLB_DISTILLED_LANGUAGE_OPTIONS = {
    "Arabic": "ara_Arab", "Bulgarian": "bul_Cyrl", "Czech": "ces_Latn", "Danish": "dan_Latn",
    "German": "deu_Latn", "Greek": "ell_Grek", "English": "eng_Latn", "Finnish": "fin_Latn",
    "French": "fra_Latn", "Hindi": "hin_Deva", "Hungarian": "hun_Latn", "Italian": "ita_Latn",
    "Japanese": "jpn_Jpan", "Korean": "kor_Hang", "Dutch": "nld_Latn", "Polish": "pol_Latn",
    "Portuguese": "por_Latn", "Russian": "rus_Cyrl", "Spanish": "spa_Latn", "Swedish": "swe_Latn",
    "Thai": "tha_Thai", "Turkish": "tur_Latn", "Ukrainian": "ukr_Cyrl", "Vietnamese": "vie_Latn",
    "Chinese": "zho_Hans",
}

# Narsil/nllb language codes
NARSIL_NLLB_LANGUAGE_OPTIONS = {
    "Amharic": "amh_Ethi", "Arabic": "ara_Arab", "Bengali": "ben_Beng", "Bulgarian": "bul_Cyrl",
    "Catalan": "cat_Latn", "Czech": "ces_Latn", "Danish": "dan_Latn", "German": "deu_Latn",
    "Greek": "ell_Grek", "English": "eng_Latn", "Finnish": "fin_Latn", "French": "fra_Latn",
    "Hebrew": "heb_Hebr", "Hindi": "hin_Deva", "Hungarian": "hun_Latn", "Italian": "ita_Latn",
    "Japanese": "jpn_Jpan", "Korean": "kor_Hang", "Marathi": "mar_Deva", "Dutch": "nld_Latn",
    "Norwegian": "nob_Latn", "Polish": "pol_Latn", "Portuguese": "por_Latn", "Romanian": "ron_Latn",
    "Russian": "rus_Cyrl", "Spanish": "spa_Latn", "Swedish": "swe_Latn", "Tamil": "tam_Taml",
    "Telugu": "tel_Telu", "Thai": "tha_Thai", "Turkish": "tur_Latn", "Ukrainian": "ukr_Cyrl",
    "Urdu": "urd_Arab", "Vietnamese": "vie_Latn", "Chinese": "zho_Hans",
}

# Map from langdetect codes to model-specific codes
LANGDETECT_TO_MBART = {
    'ar': 'ar_AR', 'cs': 'cs_CZ', 'de': 'de_DE', 'en': 'en_XX', 'es': 'es_XX',
    'et': 'et_EE', 'fi': 'fi_FI', 'fr': 'fr_XX', 'gu': 'gu_IN',
    'hi': 'hi_IN', 'it': 'it_IT', 'ja': 'ja_XX', 'kk': 'kk_KZ', 'ko': 'ko_KR',
    'lt': 'lt_LT', 'lv': 'lv_LV', 'my': 'my_MM', 'ne': 'ne_NP', 'nl': 'nl_XX',
    'ro': 'ro_RO', 'ru': 'ru_RU', 'si': 'si_LK', 'tr': 'tr_TR', 'vi': 'vi_VN',
    'zh-cn': 'zh_CN', 'zh': 'zh_CN', 'af': 'af_ZA', 'az': 'az_AZ', 'bn': 'bn_IN',
    'fa': 'fa_IR', 'he': 'he_IL', 'hr': 'hr_HR', 'id': 'id_ID', 'ka': 'ka_GE',
    'km': 'km_KH', 'mk': 'mk_MK', 'ml': 'ml_IN', 'mn': 'mn_MN', 'mr': 'mr_IN',
    'pl': 'pl_PL', 'ps': 'ps_AF', 'pt': 'pt_XX', 'sv': 'sv_SE', 'sw': 'sw_KE',
    'ta': 'ta_IN', 'te': 'te_IN', 'th': 'th_TH', 'tl': 'tl_XX', 'uk': 'uk_UA',
    'ur': 'ur_PK', 'xh': 'xh_ZA', 'gl': 'gl_ES', 'sl': 'sl_SI',
}

# Create mappings from langdetect codes to NLLB codes
LANGDETECT_TO_NLLB_DISTILLED = {
    'ar': 'ara_Arab', 'bg': 'bul_Cyrl', 'cs': 'ces_Latn', 'da': 'dan_Latn',
    'de': 'deu_Latn', 'el': 'ell_Grek', 'en': 'eng_Latn', 'fi': 'fin_Latn',
    'fr': 'fra_Latn', 'hi': 'hin_Deva', 'hu': 'hun_Latn', 'it': 'ita_Latn',
    'ja': 'jpn_Jpan', 'ko': 'kor_Hang', 'nl': 'nld_Latn', 'pl': 'pol_Latn',
    'pt': 'por_Latn', 'ru': 'rus_Cyrl', 'es': 'spa_Latn', 'sv': 'swe_Latn',
    'th': 'tha_Thai', 'tr': 'tur_Latn', 'uk': 'ukr_Cyrl', 'vi': 'vie_Latn',
    'zh': 'zho_Hans', 'zh-cn': 'zho_Hans',
}

LANGDETECT_TO_NARSIL_NLLB = {
    'am': 'amh_Ethi', 'ar': 'ara_Arab', 'bn': 'ben_Beng', 'bg': 'bul_Cyrl',
    'ca': 'cat_Latn', 'cs': 'ces_Latn', 'da': 'dan_Latn', 'de': 'deu_Latn',
    'el': 'ell_Grek', 'en': 'eng_Latn', 'fi': 'fin_Latn', 'fr': 'fra_Latn',
    'he': 'heb_Hebr', 'hi': 'hin_Deva', 'hu': 'hun_Latn', 'it': 'ita_Latn',
    'ja': 'jpn_Jpan', 'ko': 'kor_Hang', 'mr': 'mar_Deva', 'nl': 'nld_Latn',
    'no': 'nob_Latn', 'pl': 'pol_Latn', 'pt': 'por_Latn', 'ro': 'ron_Latn',
    'ru': 'rus_Cyrl', 'es': 'spa_Latn', 'sv': 'swe_Latn', 'ta': 'tam_Taml',
    'te': 'tel_Telu', 'th': 'tha_Thai', 'tr': 'tur_Latn', 'uk': 'ukr_Cyrl',
    'ur': 'urd_Arab', 'vi': 'vie_Latn', 'zh': 'zho_Hans', 'zh-cn': 'zho_Hans',
}


def translate_mbart(text, source_lang, target_lang):
    if not text:
        return "Please enter text to translate."

    # If the source language is not specified, detect it
    if source_lang == "Auto-detect":
        try:
            detected_lang = detect(text)
            if detected_lang in LANGDETECT_TO_MBART:
                src_lang_code = LANGDETECT_TO_MBART[detected_lang]
                source_lang_display = f"Auto-detected: {[k for k, v in MBART_LANGUAGE_OPTIONS.items() if v == src_lang_code][0]}"
            else:
                return f"Detected language '{detected_lang}' is not supported by MBart."
        except Exception:
            return "Could not detect language. Please select a source language manually."
    else:
        if source_lang not in MBART_LANGUAGE_OPTIONS:
            return f"Language '{source_lang}' is not supported by MBart."
        src_lang_code = MBART_LANGUAGE_OPTIONS[source_lang]
        source_lang_display = source_lang

    if target_lang not in MBART_LANGUAGE_OPTIONS:
        return f"Target language '{target_lang}' is not supported by MBart."
    tgt_lang_code = MBART_LANGUAGE_OPTIONS[target_lang]

    # Set the source language on the tokenizer
    tokenizer.src_lang = src_lang_code

    try:
        # Tokenize the input text
        encoded = tokenizer(text, return_tensors="pt").to(device)

        # Generate the translation, forcing the target language as the first token
        generated_tokens = model.generate(
            **encoded,
            forced_bos_token_id=tokenizer.lang_code_to_id[tgt_lang_code],
            max_length=1024,
        )

        # Decode the generated tokens
        translation = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0]

        return f"Source Language: {source_lang_display}\nTranslation ({target_lang}):\n\n{translation}"
    except Exception as e:
        return f"An error occurred during MBart translation: {str(e)}"


def translate_nllb_distilled(text, source_lang, target_lang):
    if not text:
        return "Please enter text to translate."

    # If the source language is not specified, detect it
    if source_lang == "Auto-detect":
        try:
            detected_lang = detect(text)
            if detected_lang in LANGDETECT_TO_NLLB_DISTILLED:
                src_lang_code = LANGDETECT_TO_NLLB_DISTILLED[detected_lang]
                source_lang_display = f"Auto-detected: {[k for k, v in NLLB_DISTILLED_LANGUAGE_OPTIONS.items() if v == src_lang_code][0]}"
            else:
                return f"Detected language '{detected_lang}' is not supported by NLLB Distilled."
        except Exception:
            return "Could not detect language. Please select a source language manually."
    else:
        if source_lang not in NLLB_DISTILLED_LANGUAGE_OPTIONS:
            return f"Language '{source_lang}' is not supported by NLLB Distilled."
        src_lang_code = NLLB_DISTILLED_LANGUAGE_OPTIONS[source_lang]
        source_lang_display = source_lang

    if target_lang not in NLLB_DISTILLED_LANGUAGE_OPTIONS:
        return f"Target language '{target_lang}' is not supported by NLLB Distilled."

    tgt_lang_code = NLLB_DISTILLED_LANGUAGE_OPTIONS[target_lang]

    try:
        # Translate using the NLLB Distilled pipeline
        result = nllb_distilled_pipe(text, src_lang=src_lang_code, tgt_lang=tgt_lang_code)
        translation = result[0]['translation_text']

        return f"Source Language: {source_lang_display}\nTranslation ({target_lang}):\n\n{translation}"
    except Exception as e:
        return f"An error occurred during NLLB Distilled translation: {str(e)}"


def translate_narsil_nllb(text, source_lang, target_lang):
    if not text:
        return "Please enter text to translate."

    # If the source language is not specified, detect it
    if source_lang == "Auto-detect":
        try:
            detected_lang = detect(text)
            if detected_lang in LANGDETECT_TO_NARSIL_NLLB:
                src_lang_code = LANGDETECT_TO_NARSIL_NLLB[detected_lang]
                source_lang_display = f"Auto-detected: {[k for k, v in NARSIL_NLLB_LANGUAGE_OPTIONS.items() if v == src_lang_code][0]}"
            else:
                return f"Detected language '{detected_lang}' is not supported by Narsil/NLLB."
        except Exception:
            return "Could not detect language. Please select a source language manually."
    else:
        if source_lang not in NARSIL_NLLB_LANGUAGE_OPTIONS:
            return f"Language '{source_lang}' is not supported by Narsil/NLLB."
        src_lang_code = NARSIL_NLLB_LANGUAGE_OPTIONS[source_lang]
        source_lang_display = source_lang

    if target_lang not in NARSIL_NLLB_LANGUAGE_OPTIONS:
        return f"Target language '{target_lang}' is not supported by Narsil/NLLB."
    tgt_lang_code = NARSIL_NLLB_LANGUAGE_OPTIONS[target_lang]

    try:
        # Translate using the Narsil/NLLB pipeline
        result = narsil_nllb_pipe(text, src_lang=src_lang_code, tgt_lang=tgt_lang_code)
        translation = result[0]['translation_text']

        return f"Source Language: {source_lang_display}\nTranslation ({target_lang}):\n\n{translation}"
    except Exception as e:
        return f"An error occurred during Narsil/NLLB translation: {str(e)}"


def translate_all(text, source_lang, target_lang):
    # Run all three translation models on the same input
    mbart_result = translate_mbart(text, source_lang, target_lang)
    nllb_distilled_result = translate_nllb_distilled(text, source_lang, target_lang)
    narsil_nllb_result = translate_narsil_nllb(text, source_lang, target_lang)

    return mbart_result, nllb_distilled_result, narsil_nllb_result


# Get all languages supported by at least one model
all_languages = sorted(
    set(MBART_LANGUAGE_OPTIONS.keys())
    | set(NLLB_DISTILLED_LANGUAGE_OPTIONS.keys())
    | set(NARSIL_NLLB_LANGUAGE_OPTIONS.keys())
)

# Create the Gradio interface
source_languages = ["Auto-detect"] + all_languages
target_languages = all_languages

with gr.Blocks(title="Multi-Model Translation") as app:
    gr.Markdown("# Multilingual Translation System")
    gr.Markdown("Enter text to translate. If the source language is not specified, it will be auto-detected.")

    with gr.Row():
        with gr.Column():
            input_text = gr.Textbox(label="Text to translate", lines=10, placeholder="Enter text here...")
            with gr.Row():
                source_lang = gr.Dropdown(
                    choices=source_languages,
                    value="Auto-detect",
                    label="Source Language"
                )
                target_lang = gr.Dropdown(
                    choices=target_languages,
                    value="English",
                    label="Target Language"
                )
            translate_btn = gr.Button("Translate")

    with gr.Row():
        with gr.Column():
            gr.Markdown("### MBart Translation")
            mbart_output = gr.Textbox(label="MBart Translation Output", lines=10)
        with gr.Column():
            gr.Markdown("### NLLB Distilled Translation")
            nllb_distilled_output = gr.Textbox(label="NLLB Distilled Translation Output", lines=10)
        with gr.Column():
            gr.Markdown("### Narsil/NLLB Translation")
            narsil_nllb_output = gr.Textbox(label="Narsil/NLLB Translation Output", lines=10)

    translate_btn.click(
        fn=translate_all,
        inputs=[input_text, source_lang, target_lang],
        outputs=[mbart_output, nllb_distilled_output, narsil_nllb_output]
    )

    gr.Markdown("### Supported Languages")
    gr.Markdown("Note: Not all languages are supported by all models. If a language is not supported by a model, it will show an error message.")

    with gr.Accordion("MBart Supported Languages", open=False):
        gr.Markdown(", ".join(sorted(MBART_LANGUAGE_OPTIONS.keys())))
    with gr.Accordion("NLLB Distilled Supported Languages", open=False):
        gr.Markdown(", ".join(sorted(NLLB_DISTILLED_LANGUAGE_OPTIONS.keys())))
    with gr.Accordion("Narsil/NLLB Supported Languages", open=False):
        gr.Markdown(", ".join(sorted(NARSIL_NLLB_LANGUAGE_OPTIONS.keys())))

# Launch the app
if __name__ == "__main__":
    app.launch(debug=True)
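
# Quick smoke test (hypothetical example, not part of the UI flow): the three
# translate_* functions can also be called directly with the same arguments the
# Gradio dropdowns pass in. The sample sentence and language choices below are
# illustrative only.
#
#   print(translate_mbart("Hello, how are you?", "Auto-detect", "French"))
#   print(translate_nllb_distilled("Hello, how are you?", "English", "German"))
#
# Assumed runtime dependencies (usual PyPI package names, not pinned here):
# gradio, torch, transformers, langdetect, plus sentencepiece for the
# MBart/NLLB tokenizers, depending on the transformers version.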