Create app.py
app.py
ADDED
@@ -0,0 +1,99 @@
import streamlit as st
from PyPDF2 import PdfReader
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

# App configuration
st.set_page_config(page_title="PDF Chatbot", layout="wide")
st.markdown(
    """
    <style>
    body {
        background: linear-gradient(90deg, rgba(255,224,230,1) 0%, rgba(224,255,255,1) 50%, rgba(224,240,255,1) 100%);
        color: #000;
    }
    </style>
    """,
    unsafe_allow_html=True
)

# Title and "Created by" section
st.markdown("<h1 style='text-align: center; color: #FF69B4;'>📄 PDF RAG Chatbot</h1>", unsafe_allow_html=True)
st.markdown(
    "<h4 style='text-align: center;'>Created by: <a href='https://www.linkedin.com/in/datascientisthameshraj/' style='color:#FF4500;'>Engr. Hamesh Raj</a></h4>",
    unsafe_allow_html=True
)

# Sidebar for PDF file upload
uploaded_files = st.sidebar.file_uploader("Upload PDF files", type="pdf", accept_multiple_files=True)

# Query input
query = st.text_input("Ask a question about the uploaded PDFs:")

# Initialize session state to store the conversation history
if "conversation" not in st.session_state:
    st.session_state.conversation = []

# Extract and concatenate the text of every page of the uploaded PDFs
def extract_text_from_pdfs(files):
    pdf_text = ""
    for file in files:
        reader = PdfReader(file)
        for page in reader.pages:
            # extract_text() can return None for pages without a text layer
            pdf_text += (page.extract_text() or "") + "\n"
    return pdf_text

# Load the model and tokenizer once and reuse them across reruns
@st.cache_resource
def load_model():
    tokenizer = AutoTokenizer.from_pretrained("himmeow/vi-gemma-2b-RAG")
    model = AutoModelForCausalLM.from_pretrained(
        "himmeow/vi-gemma-2b-RAG",
        device_map="auto",  # accelerate places the weights, so no manual .to("cuda")
        torch_dtype=torch.bfloat16
    )
    return tokenizer, model

# Process and respond to the user query
if st.button("Submit"):
    if uploaded_files and query:
        pdf_text = extract_text_from_pdfs(uploaded_files)
        tokenizer, model = load_model()

        prompt = """
### Instruction and Input:
Based on the following context/document:
{}
Please answer the question: {}

### Response:
{}
"""

        input_text = prompt.format(pdf_text, query, " ")
        inputs = tokenizer(input_text, return_tensors="pt").to(model.device)

        outputs = model.generate(
            **inputs,
            max_new_tokens=500,
            no_repeat_ngram_size=5,
        )

        # Decode only the newly generated tokens, not the echoed prompt
        answer = tokenizer.decode(
            outputs[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True
        )

        # Store the conversation, newest first
        st.session_state.conversation.insert(0, {"question": query, "answer": answer})

# Display conversation
if st.session_state.conversation:
    st.markdown("## Previous Conversations")
    for qa in st.session_state.conversation:
        st.markdown(f"**Q: {qa['question']}**")
        st.markdown(f"**A: {qa['answer']}**")
        st.markdown("---")
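To try this Space locally, a minimal sketch (the diff does not include the Space's requirements.txt, so this dependency list is an assumption): install streamlit, PyPDF2, transformers, torch, and accelerate (needed for device_map="auto"), then start the app with "streamlit run app.py". Note that load_model() is only called on the first Submit, so the first query also downloads the himmeow/vi-gemma-2b-RAG weights from the Hugging Face Hub and will be slow.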