import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModelForSeq2SeqLM
from PyPDF2 import PdfReader

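# All three models are loaded eagerly at import time, so startup cost scales with
# this registry. Note that bigscience/T0pp is an 11B-parameter model; keeping it
# here (the deliberately "Broken Answer" entry) means a multi-gigabyte download
# and a correspondingly large memory footprint.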
models = {
    "Text Generator (Bloom)": {
        "model": AutoModelForCausalLM.from_pretrained("bigscience/bloom-560m"),
        "tokenizer": AutoTokenizer.from_pretrained("bigscience/bloom-560m"),
    },
    "PDF Summarizer (T5)": {
        "model": AutoModelForSeq2SeqLM.from_pretrained("aaliyaan/t5-small-finetuned-career"),
        "tokenizer": AutoTokenizer.from_pretrained("aaliyaan/t5-small-finetuned-career", use_fast=False),
    },
    "Broken Answer (T0pp)": {
        "model": AutoModelForSeq2SeqLM.from_pretrained("bigscience/T0pp"),
        "tokenizer": AutoTokenizer.from_pretrained("bigscience/T0pp", use_fast=False),
    },
}

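# Run one chat turn: fold any uploaded PDF into the prompt, generate a reply with
# the selected model, and append the (user, bot) pair to the chat history.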
def chat_with_model(model_choice, user_message, chat_history, file=None):
    if model_choice == "PDF Summarizer (T5)" and file is not None:
        pdf_text = extract_text_from_pdf(file)
        user_message += f"\n\nPDF Content:\n{pdf_text}"

    if not user_message.strip():
        return chat_history

    model_info = models[model_choice]
    tokenizer = model_info["tokenizer"]
    model = model_info["model"]

    inputs = tokenizer(user_message, return_tensors="pt", padding=True, truncation=True, max_length=512)

    # Use max_new_tokens rather than max_length: for the causal Bloom model,
    # max_length counts the prompt tokens too, so a long prompt (e.g. with PDF
    # content appended) would leave little or no room for the reply.
    outputs = model.generate(
        **inputs,
        max_new_tokens=150,
        num_beams=5,
        early_stopping=True,
        no_repeat_ngram_size=2,
    )

    if model.config.is_encoder_decoder:
        response = tokenizer.decode(outputs[0], skip_special_tokens=True)
    else:
        # Causal models echo the prompt in their output, so decode only the
        # newly generated tokens.
        prompt_length = inputs["input_ids"].shape[-1]
        response = tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True)

    # Gradio may pass None for an empty Chatbot, so fall back to a fresh list.
    chat_history = chat_history or []
    chat_history.append((user_message, response))
    return chat_history

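# Pull the text layer out of an uploaded PDF with PyPDF2.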
def extract_text_from_pdf(file):
    # gr.File may hand back a tempfile-like object (with a .name path) or a plain
    # path string, depending on the Gradio version; handle both.
    path = getattr(file, "name", file)
    reader = PdfReader(path)
    # extract_text() returns None for pages without a text layer; skip those.
    page_texts = (page.extract_text() for page in reader.pages)
    return "\n".join(text for text in page_texts if text)

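# Build the Gradio Blocks UI: model picker, chat history, message box, and a
# PDF upload that is only shown when the summarizer model is selected.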
def create_chat_interface():
    with gr.Blocks(css="""
        .chatbox {
            background-color: #f7f7f8;
            border-radius: 12px;
            padding: 16px;
            font-family: 'Segoe UI', Tahoma, sans-serif;
        }
        .chat-title {
            font-size: 24px;
            font-weight: bold;
            text-align: center;
            margin-bottom: 12px;
            color: #3a9fd6;
        }
    """) as interface:
        gr.Markdown("<div class='chat-title'>GPT-Style Chat Interface</div>")

        with gr.Row():
            model_choice = gr.Dropdown(
                choices=list(models.keys()),
                value="Text Generator (Bloom)",
                label="Select Model",
            )

        chat_history = gr.Chatbot(label="Chat History", elem_classes="chatbox")

        user_message = gr.Textbox(
            placeholder="Type your message here...",
            show_label=False,
            elem_classes="chatbox",
        )

        # Hidden by default; only the PDF summarizer needs a file upload.
        file_input = gr.File(label="Upload PDF", visible=False, file_types=[".pdf"])

        def toggle_pdf_input(selected_model):
            return gr.update(visible=(selected_model == "PDF Summarizer (T5)"))

        model_choice.change(fn=toggle_pdf_input, inputs=model_choice, outputs=file_input)

        send_button = gr.Button("Send")

        send_button.click(
            chat_with_model,
            inputs=[model_choice, user_message, chat_history, file_input],
            outputs=chat_history,
        )

    return interface

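# Bind to 0.0.0.0 so the app is reachable from outside the container/host;
# drop server_name to keep Gradio's default of 127.0.0.1 for local-only access.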
if __name__ == "__main__":
    interface = create_chat_interface()
    interface.launch(server_name="0.0.0.0", server_port=7860)