import os
import torch
import spaces
import gradio as gr
from threading import Thread
from collections.abc import Iterator
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

MAX_MAX_NEW_TOKENS = 4096
MAX_INPUT_TOKEN_LENGTH = 4096
DEFAULT_MAX_NEW_TOKENS = 2048

HF_TOKEN = os.environ["HF_TOKEN"]

model_id = "ai4bharat/IndicTrans3-beta"
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16, device_map="auto", token=HF_TOKEN)
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-3B-Instruct")
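
# NOTE: the chat tokenizer is loaded from the Llama-3.2-3B-Instruct base model,
# so apply_chat_template() below formats requests with the Llama 3.2 chat
# template, while generation itself runs on the IndicTrans3-beta weights.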

LANGUAGES = {
    "Hindi": "hin_Deva",
    "Bengali": "ben_Beng",
    "Telugu": "tel_Telu",
    "Marathi": "mar_Deva",
    "Tamil": "tam_Taml",
    "Urdu": "urd_Arab",
    "Gujarati": "guj_Gujr",
    "Kannada": "kan_Knda",
    "Odia": "ori_Orya",
    "Malayalam": "mal_Mlym",
    "Punjabi": "pan_Guru",
    "Assamese": "asm_Beng",
    "Maithili": "mai_Mith",
    "Santali": "sat_Olck",
    "Kashmiri": "kas_Arab",
    "Nepali": "nep_Deva",
    "Sindhi": "snd_Arab",
    "Konkani": "kok_Deva",
    "Dogri": "dgo_Deva",
    "Manipuri": "mni_Beng",
    "Bodo": "brx_Deva",
}
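
# Only the dictionary keys (English language names) are used below when building
# the translation prompt; the script codes are kept for reference.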


def format_message_for_translation(message, target_lang):
    return f"Translate the following text to {target_lang}: {message}"


@spaces.GPU  # allocate a ZeroGPU device for this call; the `spaces` import implies the Space runs on ZeroGPU
def translate_message(
    message: str,
    chat_history: list[dict],
    target_language: str = "Hindi",
    max_new_tokens: int = 1024,
    temperature: float = 0.6,
    top_p: float = 0.9,
    top_k: int = 50,
    repetition_penalty: float = 1.2,
) -> Iterator[str]:
    conversation = []
    translation_request = format_message_for_translation(message, target_language)
    print(f"Translation request: {translation_request}")
    conversation.append({"role": "user", "content": translation_request})
    input_ids = tokenizer.apply_chat_template(conversation, return_tensors="pt", add_generation_prompt=True)
    if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
        input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
        gr.Warning(f"Trimmed input as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
    input_ids = input_ids.to(model.device)

    streamer = TextIteratorStreamer(tokenizer, timeout=240.0, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        input_ids=input_ids,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        top_p=top_p,
        top_k=top_k,
        temperature=temperature,
        num_beams=1,
        repetition_penalty=repetition_penalty,
    )
    # Run generation in a background thread and stream tokens back as they arrive.
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()

    outputs = []
    for text in streamer:
        outputs.append(text)
        yield "".join(outputs)


def store_feedback(rating, feedback_text):
    if not rating:
        gr.Warning("Please select a rating before submitting feedback.", duration=5)
        return None
    if not feedback_text or feedback_text.strip() == "":
        gr.Warning("Please provide some feedback before submitting.", duration=5)
        return None
    gr.Info("Feedback submitted successfully!")
    return "Thank you for your feedback!"
css = """ | |
# body { | |
# background-color: #f7f7f7; | |
# } | |
.feedback-section { | |
margin-top: 30px; | |
border-top: 1px solid #ddd; | |
padding-top: 20px; | |
} | |
.container { | |
max-width: 90%; | |
margin: 0 auto; | |
} | |
.language-selector { | |
margin-bottom: 20px; | |
padding: 10px; | |
background-color: #ffffff; | |
border-radius: 8px; | |
box-shadow: 0 2px 5px rgba(0,0,0,0.1); | |
} | |
.advanced-options { | |
margin-top: 20px; | |
} | |
""" | |
DESCRIPTION = """\ | |
IndicTrans3 is the latest state-of-the-art (SOTA) translation model from AI4Bharat, designed to handle translations across <b>22 Indic languages</b> with high accuracy. It supports <b>document-level machine translation (MT)</b> and is built to match the performance of other leading SOTA models. <br> | |
π’ <b>Training data will be released soon!</b> | |
<h3>πΉ Features</h3> | |
β Supports <b>22 Indic languages</b> | |
β Enables <b>document-level translation</b> | |
β Achieves <b>SOTA performance</b> in Indic MT | |
β Optimized for <b>real-world applications</b> | |
<h3>π Try It Out!</h3> | |
1οΈβ£ Enter text in any supported language | |
2οΈβ£ Select the target language | |
3οΈβ£ Click <b>Translate</b> and get high-quality results! | |
Built for <b>linguistic diversity and accessibility</b>, IndicTrans3 is a major step forward in <b>Indic language AI</b>. | |
π‘ <b>Source:</b> AI4Bharat | Powered by Hugging Face | |
""" | |


with gr.Blocks(css=css) as demo:
    with gr.Column(elem_classes="container"):
        gr.Markdown("# 🌍 IndicTrans3-beta 🌍: Multilingual Translation for 22 Indic Languages")
        gr.Markdown(DESCRIPTION)

        target_language = gr.Dropdown(
            list(LANGUAGES.keys()),
            value="Hindi",
            label="Which language would you like to translate to?",
            elem_id="language-dropdown",
        )

        chatbot = gr.Chatbot(height=400, elem_id="chatbot")

        with gr.Row():
            msg = gr.Textbox(
                placeholder="Enter text to translate...",
                show_label=False,
                container=False,
                scale=9,
            )
            submit_btn = gr.Button("Translate", scale=1)

        gr.Examples(
            examples=[
                "The Taj Mahal stands majestically along the banks of river Yamuna, a timeless symbol of eternal love.",
                "Kumbh Mela is the world's largest gathering of people, where millions of pilgrims bathe in sacred rivers for spiritual purification.",
                "India's classical dance forms like Bharatanatyam, Kathak, and Odissi beautifully blend rhythm, expression, and storytelling.",
                "Ayurveda, the ancient Indian medical system, focuses on holistic wellness through natural herbs and balanced living.",
                "During Diwali, homes across India are decorated with oil lamps, colorful rangoli patterns, and twinkling lights to celebrate the victory of light over darkness.",
            ],
            inputs=msg,
        )
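
        # Clicking an example only fills the textbox; the user still has to press
        # Translate (or Enter) to run the translation.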
with gr.Accordion("Provide Feedback", open=True): | |
gr.Markdown("## Rate Translation & Provide Feedback π") | |
gr.Markdown("Help us improve the translation quality by providing your feedback.") | |
with gr.Row(): | |
rating = gr.Radio( | |
["1", "2", "3", "4", "5"], | |
label="Translation Rating (1-5)" | |
) | |
feedback_text = gr.Textbox( | |
placeholder="Share your feedback about the translation...", | |
label="Feedback", | |
lines=3 | |
) | |
feedback_submit = gr.Button("Submit Feedback") | |
feedback_result = gr.Textbox(label="", visible=False) | |
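
            # feedback_result stays hidden (visible=False); it only receives
            # store_feedback's return value once the button is wired up further down.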
with gr.Accordion("Advanced Options", open=False, elem_classes="advanced-options"): | |
max_new_tokens = gr.Slider( | |
label="Max new tokens", | |
minimum=1, | |
maximum=MAX_MAX_NEW_TOKENS, | |
step=1, | |
value=DEFAULT_MAX_NEW_TOKENS, | |
) | |
temperature = gr.Slider( | |
label="Temperature", | |
minimum=0.1, | |
maximum=1.0, | |
step=0.1, | |
value=0.1, | |
) | |
top_p = gr.Slider( | |
label="Top-p (nucleus sampling)", | |
minimum=0.05, | |
maximum=1.0, | |
step=0.05, | |
value=0.9, | |
) | |
top_k = gr.Slider( | |
label="Top-k", | |
minimum=1, | |
maximum=100, | |
step=1, | |
value=50, | |
) | |
repetition_penalty = gr.Slider( | |
label="Repetition penalty", | |
minimum=1.0, | |
maximum=2.0, | |
step=0.05, | |
value=1.0, | |
) | |
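
        # The slider values are read at event time and passed straight to
        # translate_message, so they override its signature defaults (e.g. the UI
        # starts at temperature 0.1 rather than 0.6).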

        chat_state = gr.State([])

        def user(user_message, history, target_lang):
            return "", history + [[user_message, None]]

        def bot(history, target_lang, max_tokens, temp, top_p_val, top_k_val, rep_penalty):
            user_message = history[-1][0]
            history[-1][1] = ""
            for chunk in translate_message(
                user_message,
                history[:-1],
                target_lang,
                max_tokens,
                temp,
                top_p_val,
                top_k_val,
                rep_penalty,
            ):
                history[-1][1] = chunk
                yield history
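
        # Event flow: `user` appends the new message to the chat history and clears
        # the textbox, then `bot` streams the translation by overwriting the last
        # assistant turn with progressively longer text on every yield.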

        msg.submit(
            user,
            [msg, chatbot, target_language],
            [msg, chatbot],
            queue=False,
        ).then(
            bot,
            [chatbot, target_language, max_new_tokens, temperature, top_p, top_k, repetition_penalty],
            chatbot,
        )

        submit_btn.click(
            user,
            [msg, chatbot, target_language],
            [msg, chatbot],
            queue=False,
        ).then(
            bot,
            [chatbot, target_language, max_new_tokens, temperature, top_p, top_k, repetition_penalty],
            chatbot,
        )
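
        # Pressing Enter in the textbox and clicking "Translate" trigger the same
        # user -> bot chain; feedback submission is wired separately below.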

        feedback_submit.click(
            fn=store_feedback,
            inputs=[rating, feedback_text],
            outputs=feedback_result,
        )

if __name__ == "__main__":
    demo.launch()