File size: 4,596 Bytes
e29acd9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
import torch
import gradio as gr
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, AutoModelForQuestionAnswering

# Load the sequence-to-sequence translation model once at import time; the
# tokenizer and model are module-level globals used by translate_text().
# NOTE(review): "envit5" presumably denotes an English<->Vietnamese T5 model —
# confirm the supported language pairs against the model card.
translation_model_name = "VietAI/envit5-translation"
translation_tokenizer = AutoTokenizer.from_pretrained(translation_model_name)
translation_model = AutoModelForSeq2SeqLM.from_pretrained(translation_model_name)

# Translation Function
def translate_text(text, source_lang, target_lang):
    """Translate *text* with the envit5 translation model.

    VietAI/envit5-translation is an English<->Vietnamese model. It expects its
    input prefixed with the source-language code (``"en: ..."`` or
    ``"vi: ..."``) and prefixes its output with the target-language code.
    A free-form instruction prompt is not part of that input format. So when
    *source_lang* names English or Vietnamese, the model-native prefix is
    used. Any other value falls back to the generic instruction prompt.

    Args:
        text: Text to translate.
        source_lang: Source language name or code, e.g. "English" or "en".
        target_lang: Target language name. Only used by the fallback prompt;
            with the native prefix the model infers the target itself.

    Returns:
        The translated text with any leading "en:"/"vi:" tag removed, or an
        empty string when *text* is empty or blank.
    """
    if not text or not text.strip():
        return ""

    # Accept both common names and ISO codes for the two supported languages.
    lang_codes = {
        "en": "en",
        "eng": "en",
        "english": "en",
        "vi": "vi",
        "vie": "vi",
        "vietnamese": "vi",
        "tiếng việt": "vi",
    }
    source_code = lang_codes.get((source_lang or "").strip().lower())
    if source_code is not None:
        prompt = f"{source_code}: {text}"
    else:
        prompt = f"Translate the following text from {source_lang} to {target_lang}: {text}"

    inputs = translation_tokenizer(prompt, return_tensors="pt", padding=True, truncation=True)

    with torch.no_grad():
        output = translation_model.generate(**inputs, max_length=256)

    translated = translation_tokenizer.decode(output[0], skip_special_tokens=True)
    # The model tags its output with the target language; strip that tag.
    for tag in ("en:", "vi:"):
        if translated.startswith(tag):
            return translated[len(tag):].strip()
    return translated

# Load the extractive question-answering model (it yields start/end logits,
# consumed by answer_question()) once at import time; the tokenizer and model
# are module-level globals.
# NOTE(review): the checkpoint name suggests BERT-large fine-tuned on legal
# text — confirm against its model card if the answer domain matters.
qa_model_name = "atharvamundada99/bert-large-question-answering-finetuned-legal"
qa_tokenizer = AutoTokenizer.from_pretrained(qa_model_name)
qa_model = AutoModelForQuestionAnswering.from_pretrained(qa_model_name)

# Question Answering Function
def answer_question(question, context):
    """Extract the answer to *question* from *context*.

    Runs the extractive QA model and picks the highest-scoring span
    ``(start, end)`` subject to three constraints:

    * ``start <= end``;
    * both tokens lie inside the context;
    * the span is at most ``max_answer_tokens`` long.

    Taking two independent argmaxes over the start and end logits can
    instead yield an inverted (empty) span, or one that points into the
    question or at ``[CLS]``/``[SEP]``. Context that exceeds the model's
    maximum input length is truncated; the question never is.

    Args:
        question: The question to answer.
        context: The passage the answer is extracted from.

    Returns:
        The answer text, or a fallback apology when either input is blank or
        no non-empty span can be extracted.
    """
    no_answer = "Sorry, I couldn't find a relevant answer."
    if not question or not question.strip() or not context or not context.strip():
        return no_answer

    # "only_second" truncates the context only, keeping the question intact.
    inputs = qa_tokenizer(question, context, return_tensors="pt", truncation="only_second")

    with torch.no_grad():
        outputs = qa_model(**inputs)

    start_logits = outputs.start_logits[0]
    end_logits = outputs.end_logits[0]

    # Mark which tokens belong to the context (sequence id 1).
    try:
        context_mask = torch.tensor(
            [seq_id == 1 for seq_id in inputs.sequence_ids(0)], dtype=torch.bool
        )
    except ValueError:
        # Slow tokenizers do not track sequence ids; use segment ids instead.
        context_mask = inputs["token_type_ids"][0] == 1

    if not bool(context_mask.any()):
        return no_answer

    neg_inf = float("-inf")
    start_logits = start_logits.masked_fill(~context_mask, neg_inf)
    end_logits = end_logits.masked_fill(~context_mask, neg_inf)

    # scores[i, j] is the score of the span starting at token i, ending at j.
    scores = start_logits[:, None] + end_logits[None, :]
    seq_len = scores.shape[0]
    max_answer_tokens = 64
    positions = torch.arange(seq_len)
    span_len = positions[None, :] - positions[:, None]  # j - i
    valid = (span_len >= 0) & (span_len < max_answer_tokens)
    scores = scores.masked_fill(~valid, neg_inf)

    # argmax over the flattened matrix; recover (row, column) = (start, end).
    answer_start, answer_end = divmod(int(torch.argmax(scores)), seq_len)
    answer = qa_tokenizer.decode(
        inputs["input_ids"][0][answer_start : answer_end + 1], skip_special_tokens=True
    )

    return answer.strip() or no_answer

# Load the sequence-to-sequence summarization model once at import time; the
# tokenizer and model are module-level globals used by summarize_text().
# NOTE(review): the checkpoint name suggests it was fine-tuned on medical
# text, so summaries of other domains may be weaker — confirm if that matters.
summarization_model_name = "Falconsai/medical_summarization"
summarization_tokenizer = AutoTokenizer.from_pretrained(summarization_model_name)
summarization_model = AutoModelForSeq2SeqLM.from_pretrained(summarization_model_name)

# Summarization Function
def summarize_text(text):
    """Summarize *text* with the summarization model using beam search.

    Input longer than 1024 tokens is truncated. Generation uses 4 beams and
    produces between 50 and 150 tokens. It applies a length penalty of 2.0;
    per the transformers docs, a positive value favours longer sequences.

    Args:
        text: The text to summarize.

    Returns:
        The decoded summary, or an empty string when *text* is empty or blank.
        Without that guard, ``min_length=50`` would force the model to invent
        at least 50 tokens from nothing.
    """
    if not text or not text.strip():
        return ""

    inputs = summarization_tokenizer(text, return_tensors="pt", max_length=1024, truncation=True)
    with torch.no_grad():
        summary_ids = summarization_model.generate(
            **inputs,
            max_length=150,
            min_length=50,
            length_penalty=2.0,
            num_beams=4,
        )

    return summarization_tokenizer.decode(summary_ids[0], skip_special_tokens=True)

# Visibility toggle for the three task panels
def select_task(task):
    """Map the chosen task label to visibility updates for the three panels.

    Returns a 3-tuple of ``gr.update`` values in the order (translation,
    question answering, summarization). Only the panel whose label equals
    *task* is made visible.
    """
    panel_labels = ("Translation", "Question Answering", "Summarization")
    return tuple(gr.update(visible=(label == task)) for label in panel_labels)

# Reset handler for a "Clear" button
def clear_fields():
    """Return a 4-tuple of empty strings, one per component being cleared."""
    blank = ""
    return (blank,) * 4

def clear_fields_summary():
    """Return two empty strings for the summarization panel's Clear button.

    That button is wired to two output components (the input textbox and the
    summary textbox). Gradio needs one return value per output, so a single
    "" was not enough for the handler.
    """
    return "", ""

# Gradio UI
with gr.Blocks() as demo:
    gr.Markdown("## AI-Powered Language Processing")

    # Task selector; its change event (wired at the bottom) shows one panel.
    task_buttons = gr.Radio(["Translation", "Question Answering", "Summarization"], label="Choose a task")

    # Translation panel (hidden until selected).
    with gr.Group(visible=False) as translation_ui:
        source_lang = gr.Textbox(label="Source Language")
        target_lang = gr.Textbox(label="Target Language")
        text_input = gr.Textbox(label="Enter Text")
        translate_button = gr.Button("Translate")
        translation_output = gr.Textbox(label="Translated Text")

        clear_button_t = gr.Button("Clear")
        # clear_fields returns four values, one per output below.
        clear_button_t.click(clear_fields, inputs=[], outputs=[source_lang, target_lang, text_input, translation_output])

        translate_button.click(translate_text, inputs=[text_input, source_lang, target_lang], outputs=translation_output)

    # Question-answering panel (hidden until selected).
    with gr.Group(visible=False) as qa_ui:
        question_input = gr.Textbox(label="Enter Question")
        context_input = gr.Textbox(label="Enter Context")
        answer_button = gr.Button("Get Answer")
        qa_output = gr.Textbox(label="Answer")

        clear_button_qa = gr.Button("Clear")
        # Three outputs here, so return exactly three blanks instead of reusing
        # the four-value clear_fields handler.
        clear_button_qa.click(lambda: ("", "", ""), inputs=[], outputs=[question_input, context_input, qa_output])

        answer_button.click(answer_question, inputs=[question_input, context_input], outputs=qa_output)

    # Summarization panel (hidden until selected).
    with gr.Group(visible=False) as summarization_ui:
        text_input_summary = gr.Textbox(label="Enter Text")
        summarize_button = gr.Button("Summarize")
        summary_output = gr.Textbox(label="Summary")

        clear_button_s = gr.Button("Clear")
        clear_button_s.click(clear_fields_summary, inputs=[], outputs=[text_input_summary, summary_output])

        summarize_button.click(summarize_text, inputs=[text_input_summary], outputs=summary_output)

    # select_task returns one visibility update per panel, in this order.
    task_buttons.change(select_task, inputs=[task_buttons], outputs=[translation_ui, qa_ui, summarization_ui])

# Start the server only when run as a script, not when the module is imported.
if __name__ == "__main__":
    demo.launch(share=True)