Update app.py
app.py CHANGED
@@ -1,7 +1,7 @@
 import streamlit as st
+from PyPDF2 import PdfReader
 from transformers import AutoTokenizer, AutoModelForCausalLM
 import torch
-from PyPDF2 import PdfReader
 
 # Initialize the tokenizer and model from the saved checkpoint
 tokenizer = AutoTokenizer.from_pretrained("himmeow/vi-gemma-2b-RAG")
@@ -15,71 +15,97 @@ model = AutoModelForCausalLM.from_pretrained(
 if torch.cuda.is_available():
     model.to("cuda")
 
-#
-
-
-
-
-
-
-
-
-
-
-#
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+# Set up the Streamlit app layout
+st.set_page_config(page_title="RAG PDF Chatbot", layout="wide")
+
+# Sidebar with file upload and app title with creator details
+st.sidebar.title("π PDF Upload")
+uploaded_files = st.sidebar.file_uploader("Upload PDF files", type=["pdf"], accept_multiple_files=True)
+
+# Multicolor sidebar background
+st.sidebar.markdown("""
+<style>
+.sidebar .sidebar-content {
+    background: linear-gradient(135deg, #ff9a9e, #fad0c4 40%, #fad0c4 60%, #ff9a9e);
+    color: white;
+}
+</style>
+""", unsafe_allow_html=True)
+
+st.sidebar.markdown("""
+### Created by: [Engr. Hamesh Raj](https://www.linkedin.com/in/datascientisthameshraj/)
+""")
+
+# Main title
+st.markdown("""
+<h1 style='text-align: center; color: #ff6f61;'>π RAG PDF Chatbot</h1>
+""", unsafe_allow_html=True)
+
+# Multicolor background for the main content
+st.markdown("""
+<style>
+body {
+    background: linear-gradient(135deg, #89f7fe 0%, #66a6ff 100%);
+}
+</style>
+""", unsafe_allow_html=True)
+
+# Input field for user queries
+query = st.text_input("Enter your query here:")
+submit_button = st.button("Submit")
+
+# Initialize chat history
+if 'chat_history' not in st.session_state:
+    st.session_state.chat_history = []
+
+# Function to extract text from PDF files
+def extract_text_from_pdfs(files):
+    text = ""
+    for uploaded_file in files:
+        reader = PdfReader(uploaded_file)
+        for page in reader.pages:
+            text += page.extract_text() + "\n"
+    return text
+
+# Handle the query submission
+if submit_button and query:
+    # Extract text from uploaded PDFs
+    if uploaded_files:
+        pdf_text = extract_text_from_pdfs(uploaded_files)
+
+        # Prepare the input prompt
+        prompt = f"""
+### Instruction and Input:
+Based on the following context/document:
+{pdf_text}
+Please answer the question: {query}
 
-
-
-        # Format the input text
-        input_text = f"{user_query}\n\n### Response:\n"
+### Response:
+"""
 
-        # Encode the input text
-        input_ids = tokenizer(
+        # Encode the input text
+        input_ids = tokenizer(prompt, return_tensors="pt")
 
         # Use GPU for input ids if available
         if torch.cuda.is_available():
             input_ids = input_ids.to("cuda")
 
-        # Generate
+        # Generate the response
         outputs = model.generate(
             **input_ids,
-            max_new_tokens=
-            no_repeat_ngram_size=5,
+            max_new_tokens=500,
+            no_repeat_ngram_size=5,
         )
 
-        # Decode
-
+        # Decode the response
+        response = tokenizer.decode(outputs[0], skip_special_tokens=True)
 
-        #
-        st.
-        st.write(f"**A{len(st.session_state) + 1}: {answer.strip()}**")
-
-        # Store in session state for chat history
-        if "history" not in st.session_state:
-            st.session_state.history = []
-
-        st.session_state.history.append({
-            "question": user_query,
-            "answer": answer.strip()
-        })
+        # Update chat history
+        st.session_state.chat_history.append((query, response))
 
 # Display chat history
-if
-    for i,
-        st.
-        st.
+if st.session_state.chat_history:
+    for i, (q, a) in enumerate(st.session_state.chat_history):
+        st.markdown(f"**Question {i + 1}:** {q}")
+        st.markdown(f"**Answer:** {a}")
+        st.write("---")
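A note on the rewritten handler: the entire extracted PDF text is interpolated into the prompt, so a large upload can exceed the model's context window before generation even starts. The commit itself does no truncation; a minimal guard might trim the PDF text with the same tokenizer before building the prompt. In the sketch below, MAX_CONTEXT_TOKENS is an illustrative assumption, not a value defined anywhere in the Space:

# Hypothetical pre-truncation step (not part of this commit); the token budget is assumed
MAX_CONTEXT_TOKENS = 1500
context_ids = tokenizer(pdf_text, truncation=True, max_length=MAX_CONTEXT_TOKENS)["input_ids"]
pdf_text = tokenizer.decode(context_ids, skip_special_tokens=True)

Truncating pdf_text itself, rather than the finished prompt, keeps the trailing question and the "### Response:" marker intact. If the Space's requirements.txt does not already list it, the new import also implies adding PyPDF2 alongside streamlit, torch, and transformers.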