# Hugging Face Spaces app (page status when captured: Running)
import torch
import gradio as gr
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, AutoModelForQuestionAnswering

# Load Translation Model
# Tokenizer and seq2seq weights are loaded once at import time and reused by translate_text().
translation_model_name = "VietAI/envit5-translation"
translation_tokenizer = AutoTokenizer.from_pretrained(translation_model_name)
translation_model = AutoModelForSeq2SeqLM.from_pretrained(translation_model_name)
# Translation Function
# VietAI/envit5-translation translates between English and Vietnamese. Per its model card, the
# input is tagged with the source language ("en: ..." / "vi: ...") and the generated output comes
# back tagged with the target language, so an English instruction sentence is not the expected
# input format.
_ENVIT5_LANG_CODES = {
    "en": "en",
    "eng": "en",
    "english": "en",
    "vi": "vi",
    "vie": "vi",
    "vietnamese": "vi",
    "tiếng việt": "vi",
}


def _envit5_lang_code(lang):
    """Return the envit5 tag ("en"/"vi") for a free-text language name, or None if unknown."""
    return _ENVIT5_LANG_CODES.get((lang or "").strip().casefold())


def _strip_envit5_tag(text):
    """Remove a leading "en: " / "vi: " language tag from generated text, if present."""
    for code in ("en", "vi"):
        tag = f"{code}: "
        if text.startswith(tag):
            return text[len(tag):]
    return text


def translate_text(text, source_lang, target_lang):
    """Translate *text* from *source_lang* to *target_lang* with the envit5 model.

    Args:
        text: Text to translate. Blank or missing text returns "" without running the model.
        source_lang: Free-text source language. "en"/"English" and "vi"/"Vietnamese"
            (case-insensitive) are mapped to envit5's input tag. Any other value falls back
            to a plain-language instruction prompt.
        target_lang: Free-text target language. It is only used in the fallback prompt,
            because the tagged input carries just the source language.

    Returns:
        The decoded translation, with any leading "en: "/"vi: " tag removed.
    """
    if not text or not text.strip():
        return ""
    src_code = _envit5_lang_code(source_lang)
    if src_code is not None:
        prompt = f"{src_code}: {text}"
    else:
        # Unrecognised language: keep the previous instruction-style prompt.
        prompt = f"Translate the following text from {source_lang} to {target_lang}: {text}"
    inputs = translation_tokenizer(prompt, return_tensors="pt", padding=True, truncation=True)
    with torch.no_grad():
        output = translation_model.generate(**inputs, max_length=256)
    decoded = translation_tokenizer.decode(output[0], skip_special_tokens=True)
    return _strip_envit5_tag(decoded)
# Load Question Answering Model
# Extractive QA checkpoint; loaded once at import time and reused by answer_question().
qa_model_name = "atharvamundada99/bert-large-question-answering-finetuned-legal"
qa_tokenizer = AutoTokenizer.from_pretrained(qa_model_name)
qa_model = AutoModelForQuestionAnswering.from_pretrained(qa_model_name)
# Question Answering Function
def answer_question(question, context):
    """Extract the answer to *question* from *context* with the extractive QA model.

    Args:
        question: The question text.
        context: The passage the answer is extracted from.

    Returns:
        The predicted answer span decoded back to text. Returns a fallback apology instead when
        either input is blank or the predicted span contains no regular (non-special) tokens.
    """
    no_answer = "Sorry, I couldn't find a relevant answer."
    if not (question and question.strip()) or not (context and context.strip()):
        return no_answer
    inputs = qa_tokenizer(question, context, return_tensors="pt", truncation=True)
    with torch.no_grad():
        outputs = qa_model(**inputs)
    # Independent argmax over start and end logits. If end precedes start, the slice is empty
    # and the fallback message below is returned.
    answer_start = int(torch.argmax(outputs.start_logits))
    answer_end = int(torch.argmax(outputs.end_logits)) + 1
    # skip_special_tokens drops [CLS]/[SEP]. A span pointing at them (commonly the
    # "no answer" position for SQuAD-style heads) then becomes empty instead of the
    # literal text "[CLS]".
    answer = qa_tokenizer.decode(
        inputs["input_ids"][0][answer_start:answer_end], skip_special_tokens=True
    )
    return answer if answer.strip() else no_answer
# Load Summarization Model
# Seq2seq summarization checkpoint; loaded once at import time and reused by summarize_text().
summarization_model_name = "Falconsai/medical_summarization"
summarization_tokenizer = AutoTokenizer.from_pretrained(summarization_model_name)
summarization_model = AutoModelForSeq2SeqLM.from_pretrained(summarization_model_name)
# Summarization Function
def summarize_text(text):
    """Summarize *text* with the summarization model using 4-beam search.

    Args:
        text: Text to summarize. Input is truncated to at most 1024 tokens.

    Returns:
        The decoded summary (50-150 generated tokens). Returns "" for blank or missing input
        without calling the model, because min_length=50 would otherwise force the model to
        produce at least 50 tokens of text unrelated to any input.
    """
    if not text or not text.strip():
        return ""
    inputs = summarization_tokenizer(text, return_tensors="pt", max_length=1024, truncation=True)
    with torch.no_grad():
        summary_ids = summarization_model.generate(
            **inputs, max_length=150, min_length=50, length_penalty=2.0, num_beams=4
        )
    return summarization_tokenizer.decode(summary_ids[0], skip_special_tokens=True)
# Function to toggle UI visibility based on selected task
def select_task(task):
    """Build visibility updates for the translation, QA and summarization panels, in that order.

    Only the panel whose name equals *task* is made visible. When *task* matches none of
    them, all three panels are hidden.
    """
    panel_tasks = ("Translation", "Question Answering", "Summarization")
    return tuple(gr.update(visible=(task == name)) for name in panel_tasks)
# Function to clear inputs and outputs
def clear_fields():
    """Return a 4-tuple of empty strings, one per Gradio component being reset."""
    return ("",) * 4
def clear_fields_summary():
    """Return blanks for the summarization panel's input and output textboxes.

    The summarization Clear button is wired to two outputs (text_input_summary and
    summary_output), so it must return two values. A single "" is rejected by Gradio
    because it supplies too few output values.
    """
    return "", ""
# Gradio UI
# Each task panel starts hidden; selecting a task in the radio shows only that panel.
with gr.Blocks() as demo:
    gr.Markdown("## AI-Powered Language Processing")
    task_buttons = gr.Radio(["Translation", "Question Answering", "Summarization"], label="Choose a task")

    with gr.Group(visible=False) as translation_ui:
        source_lang = gr.Textbox(label="Source Language")
        target_lang = gr.Textbox(label="Target Language")
        text_input = gr.Textbox(label="Enter Text")
        translate_button = gr.Button("Translate")
        translation_output = gr.Textbox(label="Translated Text")
        clear_button_t = gr.Button("Clear")
        # clear_fields returns exactly four blanks, matching these four outputs.
        clear_button_t.click(clear_fields, inputs=[], outputs=[source_lang, target_lang, text_input, translation_output])
        translate_button.click(translate_text, inputs=[text_input, source_lang, target_lang], outputs=translation_output)

    with gr.Group(visible=False) as qa_ui:
        question_input = gr.Textbox(label="Enter Question")
        context_input = gr.Textbox(label="Enter Context")
        answer_button = gr.Button("Get Answer")
        qa_output = gr.Textbox(label="Answer")
        clear_button_qa = gr.Button("Clear")
        # Return one blank per output (3). The 4-value clear_fields does not match this event.
        clear_button_qa.click(lambda: ("", "", ""), inputs=[], outputs=[question_input, context_input, qa_output])
        answer_button.click(answer_question, inputs=[question_input, context_input], outputs=qa_output)

    with gr.Group(visible=False) as summarization_ui:
        text_input_summary = gr.Textbox(label="Enter Text")
        summarize_button = gr.Button("Summarize")
        summary_output = gr.Textbox(label="Summary")
        clear_button_s = gr.Button("Clear")
        # Return one blank per output (2), so Gradio gets enough values for both textboxes.
        clear_button_s.click(lambda: ("", ""), inputs=[], outputs=[text_input_summary, summary_output])
        summarize_button.click(summarize_text, inputs=[text_input_summary], outputs=summary_output)

    task_buttons.change(select_task, inputs=[task_buttons], outputs=[translation_ui, qa_ui, summarization_ui])

demo.launch(share=True)