import torch
import spaces
from collections.abc import Iterator
from threading import Thread
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
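# Token budgets for the demo: cap generated length and truncate overly long inputs.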
MAX_MAX_NEW_TOKENS = 4096
DEFAULT_MAX_NEW_TOKENS = 2048
MAX_INPUT_TOKEN_LENGTH = 4096
DESCRIPTION = """\
# IndicTrans3-beta πŸš€
"""
# if not torch.cuda.is_available():
# DESCRIPTION += "\n<p>Running on CPU πŸ₯Ά This demo does not work on CPU.</p>"
# if torch.cuda.is_available():
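# Load the model in half precision; device_map="auto" places layers on available devices
# and spills anything that does not fit into the "offload" folder on disk.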
model_id = "ai4bharat/IndicTrans3-beta"
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16, device_map="auto", offload_folder="offload")
tokenizer = AutoTokenizer.from_pretrained(model_id)
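# Target languages shown in the dropdown, mapped to FLORES-style codes.
# The translation prompt uses the display name; the codes are kept for reference.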
LANGUAGES = {
    "Hindi": "hin_Deva",
    "Bengali": "ben_Beng",
    "Telugu": "tel_Telu",
    "Marathi": "mar_Deva",
    "Tamil": "tam_Taml",
    "Urdu": "urd_Arab",
    "Gujarati": "guj_Gujr",
    "Kannada": "kan_Knda",
    "Odia": "ori_Orya",
    "Malayalam": "mal_Mlym",
    "Punjabi": "pan_Guru",
    "Assamese": "asm_Beng",
    "Maithili": "mai_Mith",
    "Santali": "sat_Olck",
    "Kashmiri": "kas_Arab",
    "Nepali": "nep_Deva",
    "Sindhi": "snd_Arab",
    "Konkani": "kok_Deva",
    "Dogri": "dgo_Deva",
    "Manipuri": "mni_Beng",
    "Bodo": "brx_Deva",
}
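# Streamed translation: model.generate runs in a background thread and partial output
# is yielded to the UI as tokens arrive from the streamer.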
@spaces.GPU
def generate(
    tgt_lang: str,
    message: str,
    max_new_tokens: int = 1024,
    temperature: float = 0.6,
    top_p: float = 0.9,
    top_k: int = 50,
    repetition_penalty: float = 1.2,
) -> Iterator[str]:
    # Single-turn conversation asking the model to translate into the target language.
    conversation = [
        {"role": "user", "content": f"Translate the following text to {tgt_lang}: {message}"}
    ]
    input_ids = tokenizer.apply_chat_template(conversation, return_tensors="pt")
    if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
        input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
        gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
    input_ids = input_ids.to(model.device)

    streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        input_ids=input_ids,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        # Fall back to greedy decoding when temperature is 0; sampling with temperature=0 is invalid.
        do_sample=temperature > 0,
        top_p=top_p,
        top_k=top_k,
        temperature=temperature,
        num_beams=1,
        repetition_penalty=repetition_penalty,
    )
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()

    # Yield the accumulated translation as it streams in.
    outputs = []
    for text in streamer:
        outputs.append(text)
        yield "".join(outputs)
def store_feedback(rating, feedback_text):
    if not rating:
        gr.Warning("Please select a rating before submitting feedback.", duration=5)
        return None
    if not feedback_text or feedback_text.strip() == "":
        gr.Warning("Please provide some feedback before submitting.", duration=5)
        return None
    gr.Info("Feedback submitted successfully!")
    return "Thank you for your feedback!"
css = """
#col-container {max-width: 80%; margin-left: auto; margin-right: auto;}
#header {text-align: center;}
.message { font-size: 1.2em; }
#feedback-section { margin-top: 30px; border-top: 1px solid #ddd; padding-top: 20px; }
"""
with gr.Blocks(theme=gr.themes.Default(), css=css) as demo:
    gr.Markdown(DESCRIPTION, elem_id="header")
    gr.Markdown(
        "Translate text into multiple Indic languages using the latest IndicTrans3 model from AI4Bharat. "
        "The model is trained on the --- dataset and supports translation into 22 Indic languages. "
        "IndicTrans3 sets a state-of-the-art benchmark on multiple translation tasks and handles "
        "complex translations with ease.",
        elem_id="description",
    )
    with gr.Column(elem_id="col-container"):
        with gr.Row():
            with gr.Column():
                text_input = gr.Textbox(
                    placeholder="Enter text to translate...",
                    label="Input text",
                    lines=10,
                    max_lines=100,
                    elem_id="input-text",
                )
            with gr.Column():
                tgt_lang = gr.Dropdown(
                    list(LANGUAGES.keys()),
                    value="Hindi",
                    label="Translate To",
                    elem_id="translate-to",
                )
                text_output = gr.Textbox(
                    label="",
                    lines=10,
                    max_lines=100,
                    elem_id="output-text",
                )
        btn_submit = gr.Button("Translate")
        btn_submit.click(
            fn=generate,
            inputs=[
                tgt_lang,
                text_input,
                # Hidden generation parameters: max_new_tokens, temperature, top_p, top_k, repetition_penalty.
                gr.Number(value=4096, visible=False),
                gr.Number(value=0, visible=False),
                gr.Number(value=0.9, visible=False),
                gr.Number(value=50, visible=False),
                gr.Number(value=1.0, visible=False),  # repetition penalty must be > 0; 1.0 disables it
            ],
            outputs=text_output,
        )
        gr.Examples(
            examples=[
                ["Telugu", "Hello, how are you today? I hope you're doing well."],
                ["Punjabi", "Hello, how are you today? I hope you're doing well."],
                ["Hindi", "Hello, how are you today? I hope you're doing well."],
                ["Marathi", "Hello, how are you today? I hope you're doing well."],
                ["Malayalam", "Hello, how are you today? I hope you're doing well."],
            ],
            inputs=[tgt_lang, text_input],
            outputs=text_output,
            fn=generate,
            cache_examples=True,
            examples_per_page=5,
        )
        with gr.Column(elem_id="feedback-section"):
            gr.Markdown("## Rate Translation & Provide Feedback πŸ“")
            gr.Markdown("Help us improve the translation quality by providing your feedback and rating.")
            with gr.Row():
                rating = gr.Radio(
                    ["1", "2", "3", "4", "5"],
                    label="Translation Rating (1-5)",
                )
                feedback_text = gr.Textbox(
                    placeholder="Share your feedback about the translation...",
                    label="Feedback",
                    lines=3,
                )
            feedback_submit = gr.Button("Submit Feedback")
            feedback_result = gr.Textbox(label="", visible=False)
            feedback_submit.click(
                fn=store_feedback,
                inputs=[rating, feedback_text],
                outputs=feedback_result,
            )
demo.launch()