zliang committed on
Commit 5599ea4 · verified · 1 Parent(s): 7263d31

Update app.py

Files changed (1)
  1. app.py +330 -176

app.py CHANGED
@@ -1,3 +1,7 @@
  import os
  os.system("python -m spacy download en_core_web_sm")
  import io
@@ -13,30 +17,21 @@ from langchain_core.output_parsers import StrOutputParser
  from langchain_community.document_loaders import PyMuPDFLoader
  from langchain_openai import OpenAIEmbeddings
  from langchain_text_splitters import RecursiveCharacterTextSplitter
  from langchain_core.prompts import ChatPromptTemplate
  from langchain_openai import ChatOpenAI
  import re
  from PIL import Image

- openai_api_key = os.environ.get("openai_api_key")
- # Cached resources
- @st.cache_resource
- def load_models():
-     return {
-         "yolo": YOLO("best.pt"),
-         "embeddings": OpenAIEmbeddings(model="text-embedding-3-small",api_key=openai_api_key),
-         "llm": ChatOpenAI(model="gpt-4-turbo", temperature=0.3,api_key=openai_api_key)
-     }
-
- models = load_models()


- # Constants
- FIGURE_CLASS_INDEX = 4
- TABLE_CLASS_INDEX = 3
- CHUNK_SIZE = 1000
- CHUNK_OVERLAP = 200
- NUM_CLUSTERS = 8

  # Utility functions
  def clean_text(text):
@@ -44,182 +39,341 @@ def clean_text(text):

  def remove_references(text):
      reference_patterns = [
-         r'\bReferences\b', r'\breferences\b', r'\bBibliography\b',
-         r'\bCitations\b', r'\bWorks Cited\b'
      ]
-     return re.sub('|'.join(reference_patterns), '', text, flags=re.IGNORECASE)

- @st.cache_data
- def process_pdf(file_path):
-     """Process PDF once and cache results"""
-     loader = PyMuPDFLoader(file_path)
      docs = loader.load()
      full_text = "\n".join(doc.page_content for doc in docs)
-     cleaned_text = clean_text(remove_references(full_text))
-
-     text_splitter = RecursiveCharacterTextSplitter(
-         chunk_size=CHUNK_SIZE,
-         chunk_overlap=CHUNK_OVERLAP,
-         separators=["\n\n", "\n", ". ", "! ", "? ", " "]
      )
-     split_contents = text_splitter.split_text(cleaned_text)
-
-     return {
-         "text": cleaned_text,
-         "chunks": split_contents,
-         "embeddings": models["embeddings"].embed_documents(split_contents)
-     }

- @st.cache_data
- def extract_visuals(file_path):
-     """Extract figures and tables with caching"""
-     doc = fitz.open(file_path)
      all_figures = []
      all_tables = []

-     for page in doc:
-         low_res_pix = page.get_pixmap(dpi=50)
          low_res_img = np.frombuffer(low_res_pix.samples, dtype=np.uint8).reshape(low_res_pix.height, low_res_pix.width, 3)
-
-         results = models["yolo"].predict(low_res_img)
-         boxes = [
-             (int(box.xyxy[0][0]), int(box.xyxy[0][1]),
-              int(box.xyxy[0][2]), int(box.xyxy[0][3]), int(box.cls[0]))
-             for result in results for box in result.boxes
-             if box.conf[0] > 0.8 and int(box.cls[0]) in {FIGURE_CLASS_INDEX, TABLE_CLASS_INDEX}
-         ]

          if boxes:
-             high_res_pix = page.get_pixmap(dpi=300)
              high_res_img = np.frombuffer(high_res_pix.samples, dtype=np.uint8).reshape(high_res_pix.height, high_res_pix.width, 3)
-
-             for x1, y1, x2, y2, cls in boxes:
-                 img = high_res_img[int(y1*6):int(y2*6), int(x1*6):int(x2*6)]
-                 if cls == FIGURE_CLASS_INDEX:
-                     all_figures.append(img)
-                 else:
-                     all_tables.append(img)
-
-     return {"figures": all_figures, "tables": all_tables}
-
- def generate_summary(chunks, embeddings):
-     """Generate summary using clustered chunks"""
-     kmeans = KMeans(n_clusters=NUM_CLUSTERS, init='k-means++').fit(embeddings)
-     cluster_indices = [np.argmin(np.linalg.norm(embeddings - center, axis=1))
-                        for center in kmeans.cluster_centers_]
-     selected_chunks = [chunks[i] for i in cluster_indices]
-
-     prompt = ChatPromptTemplate.from_template(
-         """Create a structured summary with key points from these context sections:
-         {contexts}
-         Format:
-         ## Summary
-         [concise overview]
-         ## Key Points
-         - [main point 1]
-         - [main point 2]
-         ..."""
-     )
-     chain = prompt | models["llm"] | StrOutputParser()
-     return chain.invoke({"contexts": '\n\n'.join(selected_chunks)})
-
- def answer_question(question, chunks, embeddings):
-     """Answer question using semantic search"""
-     query_embedding = models["embeddings"].embed_query(question)
-     similarities = cosine_similarity([query_embedding], embeddings)[0]
-     top_indices = np.argsort(similarities)[-5:][::-1]
-     context = '\n'.join([chunks[i] for i in top_indices if similarities[i] > 0.6])

-     prompt = ChatPromptTemplate.from_template(
-         """Answer this question: {question}
-         Using only this context: {context}
-         - Be precise and include relevant details
-         - Cite sources as [Source 1], [Source 2], etc."""
-     )
-     chain = prompt | models["llm"] | StrOutputParser()
-     return chain.invoke({"question": question, "context": context})
-
- # Streamlit UI
- #st.set_page_config(page_title="PDF Assistant", layout="wide")
- st.title("📄 Smart PDF Assistant")
-
- if "chat" not in st.session_state:
-     st.session_state.chat = []
- if "processed_data" not in st.session_state:
-     st.session_state.processed_data = None
-
- # File upload section
- with st.sidebar:
-     uploaded_file = st.file_uploader("Upload PDF", type="pdf")
-     if uploaded_file:
-         with tempfile.NamedTemporaryFile(delete=False) as tmp:
-             tmp.write(uploaded_file.getbuffer())
-             st.session_state.processed_data = process_pdf(tmp.name)
-             visuals = extract_visuals(tmp.name)
-
- # Chat interface
- col1, col2 = st.columns([3, 1])
- with col1:
-     st.subheader("Document Interaction")
-
-     for msg in st.session_state.chat:
-         with st.chat_message(msg["role"]):
-             if "image" in msg:
-                 st.image(msg["image"], caption=msg.get("caption"))
-             else:
-                 st.markdown(msg["content"])
-
-     if prompt := st.chat_input("Ask about the document..."):
-         st.session_state.chat.append({"role": "user", "content": prompt})
-         with st.spinner("Analyzing..."):
-             response = answer_question(
-                 prompt,
-                 st.session_state.processed_data["chunks"],
-                 st.session_state.processed_data["embeddings"]
-             )
-         st.session_state.chat.append({"role": "assistant", "content": response})
-         st.rerun()
-
- with col2:
-     st.subheader("Document Insights")
-
-     if st.button("Generate Summary"):
-         with st.spinner("Summarizing..."):
-             summary = generate_summary(
-                 st.session_state.processed_data["chunks"],
-                 st.session_state.processed_data["embeddings"]
-             )
-             st.session_state.chat.append({
-                 "role": "assistant",
-                 "content": f"## Document Summary\n{summary}"
-             })
-             st.rerun()
-
-     if visuals["figures"]:
-         with st.expander(f"📷 Figures ({len(visuals['figures'])})"):
-             for idx, fig in enumerate(visuals["figures"], 1):
-                 st.image(fig, caption=f"Figure {idx}")
-
-     if visuals["tables"]:
-         with st.expander(f"📊 Tables ({len(visuals['tables'])})"):
-             for idx, tbl in enumerate(visuals["tables"], 1):
-                 st.image(tbl, caption=f"Table {idx}")

- # Custom styling
- st.markdown("""
  <style>
- [data-testid=stSidebar] {
-     background: #fafafa;
-     border-right: 1px solid #eee;
  }
- .stChatMessage {
-     padding: 1rem;
-     margin: 0.5rem 0;
-     border-radius: 10px;
-     box-shadow: 0 2px 5px rgba(0,0,0,0.1);
  }
- [data-testid=stVerticalBlock] > div:has(>.stChatMessage) {
-     gap: 0.5rem;
  }
  </style>
- """, unsafe_allow_html=True)
+
+
+
+
  import os
  os.system("python -m spacy download en_core_web_sm")
  import io

  from langchain_community.document_loaders import PyMuPDFLoader
  from langchain_openai import OpenAIEmbeddings
  from langchain_text_splitters import RecursiveCharacterTextSplitter
+ from langchain_text_splitters import SpacyTextSplitter
  from langchain_core.prompts import ChatPromptTemplate
  from langchain_openai import ChatOpenAI
  import re
  from PIL import Image
+ from streamlit_chat import message

+ # Load the trained model

+ model = YOLO("best.pt")
+ openai_api_key = os.environ.get("openai_api_key")

+ # Define the class indices for figures, tables, and text
+ figure_class_index = 4
+ table_class_index = 3

  # Utility functions
  def clean_text(text):

  def remove_references(text):
      reference_patterns = [
+         r'\bReferences\b', r'\breferences\b', r'\bBibliography\b', r'\bCitations\b',
+         r'\bWorks Cited\b', r'\bReference\b', r'\breference\b'
      ]
+     lines = text.split('\n')
+     for i, line in enumerate(lines):
+         if any(re.search(pattern, line, re.IGNORECASE) for pattern in reference_patterns):
+             return '\n'.join(lines[:i])
+     return text
+
+ def save_uploaded_file(uploaded_file):
+     temp_file = tempfile.NamedTemporaryFile(delete=False)
+     temp_file.write(uploaded_file.getbuffer())
+     temp_file.close()
+     return temp_file.name

+ def summarize_pdf(pdf_file_path, num_clusters=10):
+     embeddings_model = OpenAIEmbeddings(model="text-embedding-3-small", api_key=openai_api_key)
+     llm = ChatOpenAI(model="gpt-4o-mini", api_key=openai_api_key, temperature=0.3)
+     prompt = ChatPromptTemplate.from_template(
+         """Could you please provide a concise and comprehensive summary of the given Contexts?
+         The summary should capture the main points and key details of the text while conveying the author's intended meaning accurately.
+         Please ensure that the summary is well-organized and easy to read, with clear headings and subheadings to guide the reader through each section.
+         The length of the summary should be appropriate to capture the main points and key details of the text, without including unnecessary information or becoming overly long.
+         example of summary:
+         ## Summary:
+         ## Key points:
+         Contexts: {topic}"""
+     )
+     output_parser = StrOutputParser()
+     chain = prompt | llm | output_parser
+
+     loader = PyMuPDFLoader(pdf_file_path)
      docs = loader.load()
      full_text = "\n".join(doc.page_content for doc in docs)
+     cleaned_full_text = clean_text(remove_references(full_text))
+     text_splitter = SpacyTextSplitter(chunk_size=500)
+     #text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=0, separators=["\n\n", "\n", ".", " "])
+     split_contents = text_splitter.split_text(cleaned_full_text)
+     embeddings = embeddings_model.embed_documents(split_contents)
+
+     kmeans = KMeans(n_clusters=num_clusters, init='k-means++', random_state=0).fit(embeddings)
+     closest_point_indices = [np.argmin(np.linalg.norm(embeddings - center, axis=1)) for center in kmeans.cluster_centers_]
+     extracted_contents = [split_contents[idx] for idx in closest_point_indices]
+
+     results = chain.invoke({"topic": ' '.join(extracted_contents)})
+
+     return generate_citations(results, extracted_contents)
+
+ def qa_pdf(pdf_file_path, query, num_clusters=5, similarity_threshold=0.6):
+     embeddings_model = OpenAIEmbeddings(model="text-embedding-3-small", api_key=openai_api_key)
+     llm = ChatOpenAI(model="gpt-4o-mini", api_key=openai_api_key, temperature=0.3)
+     prompt = ChatPromptTemplate.from_template(
+         """Please provide a detailed and accurate answer to the given question based on the provided contexts.
+         Ensure that the answer is comprehensive and directly addresses the query.
+         If necessary, include relevant examples or details from the text.
+         Question: {question}
+         Contexts: {contexts}"""
      )
+     output_parser = StrOutputParser()
+     chain = prompt | llm | output_parser
+
+     loader = PyMuPDFLoader(pdf_file_path)
+     docs = loader.load()
+     full_text = "\n".join(doc.page_content for doc in docs)
+     cleaned_full_text = clean_text(remove_references(full_text))
+     text_splitter = SpacyTextSplitter(chunk_size=500)
+
+     #text_splitter = RecursiveCharacterTextSplitter(chunk_size=800, chunk_overlap=0, separators=["\n\n", "\n", ".", " "])
+     split_contents = text_splitter.split_text(cleaned_full_text)
+     embeddings = embeddings_model.embed_documents(split_contents)
+
+     query_embedding = embeddings_model.embed_query(query)
+     similarity_scores = cosine_similarity([query_embedding], embeddings)[0]
+     top_indices = np.argsort(similarity_scores)[-num_clusters:]
+     relevant_contents = [split_contents[i] for i in top_indices]
+
+     results = chain.invoke({"question": query, "contexts": ' '.join(relevant_contents)})
+
+     return generate_citations(results, relevant_contents, similarity_threshold)
+
+ def generate_citations(text, contents, similarity_threshold=0.6):
+     embeddings_model = OpenAIEmbeddings(model="text-embedding-3-small", api_key=openai_api_key)
+     text_sentences = re.split(r'(?<!\w\.\w.)(?<![A-Z][a-z]\.)(?<=\.|\?)\s', text)
+     text_embeddings = embeddings_model.embed_documents(text_sentences)
+     content_embeddings = embeddings_model.embed_documents(contents)
+     similarity_matrix = cosine_similarity(text_embeddings, content_embeddings)
+
+     cited_text = text
+     relevant_sources = []
+     source_mapping = {}
+     sentence_to_source = {}
+
+     for i, sentence in enumerate(text_sentences):
+         if sentence in sentence_to_source:
+             continue
+         max_similarity = max(similarity_matrix[i])
+         if max_similarity >= similarity_threshold:
+             most_similar_idx = np.argmax(similarity_matrix[i])
+             if most_similar_idx not in source_mapping:
+                 source_mapping[most_similar_idx] = len(relevant_sources) + 1
+                 relevant_sources.append((most_similar_idx, contents[most_similar_idx]))
+             citation_idx = source_mapping[most_similar_idx]
+             citation = f"([Source {citation_idx}](#source-{citation_idx}))"
+             cited_sentence = re.sub(r'([.!?])$', f" {citation}\\1", sentence)
+             sentence_to_source[sentence] = citation_idx
+             cited_text = cited_text.replace(sentence, cited_sentence)
+
+     sources_list = "\n\n## Sources:\n"
+     for idx, (original_idx, content) in enumerate(relevant_sources):
+         sources_list += f"""
+ <details style="margin: 1px 0; padding: 5px; border: 1px solid #ccc; border-radius: 8px; background-color: #f9f9f9; transition: all 0.3s ease;">
+     <summary style="font-weight: bold; cursor: pointer; outline: none; padding: 5px 0; transition: color 0.3s ease;">Source {idx + 1}</summary>
+     <pre style="white-space: pre-wrap; word-wrap: break-word; margin: 1px 0; padding: 10px; background-color: #fff; border-radius: 5px; border: 1px solid #ddd; box-shadow: 0 2px 5px rgba(0, 0, 0, 0.1);">{content}</pre>
+ </details>
+ """
+
+     # Add dummy blanks after the last source
+     dummy_blanks = """
+ <div style="margin: 20px 0;"></div>
+ <div style="margin: 20px 0;"></div>
+ <div style="margin: 20px 0;"></div>
+ <div style="margin: 20px 0;"></div>
+ <div style="margin: 20px 0;"></div>
+ """
+
+     cited_text += sources_list + dummy_blanks
+     return cited_text

+ def infer_image_and_get_boxes(image, confidence_threshold=0.8):
+     results = model.predict(image)
+     return [
+         (int(box.xyxy[0][0]), int(box.xyxy[0][1]), int(box.xyxy[0][2]), int(box.xyxy[0][3]), int(box.cls[0]))
+         for result in results for box in result.boxes
+         if int(box.cls[0]) in {figure_class_index, table_class_index} and box.conf[0] > confidence_threshold
+     ]
+
+ def crop_images_from_boxes(image, boxes, scale_factor):
+     figures = []
+     tables = []
+     for (x1, y1, x2, y2, cls) in boxes:
+         cropped_img = image[int(y1 * scale_factor):int(y2 * scale_factor), int(x1 * scale_factor):int(x2 * scale_factor)]
+         if cls == figure_class_index:
+             figures.append(cropped_img)
+         elif cls == table_class_index:
+             tables.append(cropped_img)
+     return figures, tables
+
+ def process_pdf(pdf_file_path):
+     doc = fitz.open(pdf_file_path)
      all_figures = []
      all_tables = []
+     low_dpi = 50
+     high_dpi = 300
+     scale_factor = high_dpi / low_dpi
+     low_res_pixmaps = [page.get_pixmap(dpi=low_dpi) for page in doc]

+     for page_num, low_res_pix in enumerate(low_res_pixmaps):
          low_res_img = np.frombuffer(low_res_pix.samples, dtype=np.uint8).reshape(low_res_pix.height, low_res_pix.width, 3)
+         boxes = infer_image_and_get_boxes(low_res_img)

          if boxes:
+             high_res_pix = doc[page_num].get_pixmap(dpi=high_dpi)
              high_res_img = np.frombuffer(high_res_pix.samples, dtype=np.uint8).reshape(high_res_pix.height, high_res_pix.width, 3)
+             figures, tables = crop_images_from_boxes(high_res_img, boxes, scale_factor)
+             all_figures.extend(figures)
+             all_tables.extend(tables)

+     return all_figures, all_tables

+ def image_to_base64(img):
+     buffered = io.BytesIO()
+     img = Image.fromarray(img)
+     img.save(buffered, format="PNG")
+     return base64.b64encode(buffered.getvalue()).decode()
+
+ def on_btn_click():
+     del st.session_state.chat_history[:]
+
+ # Streamlit interface
+
+ # Custom CSS for the file uploader
+ uploadercss='''
  <style>
+ [data-testid='stFileUploader'] {
+     width: max-content;
  }
+ [data-testid='stFileUploader'] section {
+     padding: 0;
+     float: left;
  }
+ [data-testid='stFileUploader'] section > input + div {
+     display: none;
+ }
+ [data-testid='stFileUploader'] section + div {
+     float: right;
+     padding-top: 0;
  }
  </style>
+ '''
+
+ st.set_page_config(page_title="PDF Reading Assistant", page_icon="📄")
+
+ # Initialize chat history in session state if not already present
+ if 'chat_history' not in st.session_state:
+     st.session_state.chat_history = []
+
+ st.title("📄 PDF Reading Assistant")
+ st.markdown("### Extract tables, figures, summaries, and answers from your PDF files easily.")
+ chat_placeholder = st.empty()
+
+ # File uploader for PDF
+ uploaded_file = st.file_uploader("Upload a PDF", type="pdf")
+ st.markdown(uploadercss, unsafe_allow_html=True)
+ if uploaded_file:
+     file_path = save_uploaded_file(uploaded_file)
+
+     # Chat container where all messages will be displayed
+     chat_container = st.container()
+     user_input = st.chat_input("Ask a question about the pdf......", key="user_input")
+     with chat_container:
+         # Scrollable chat messages
+         for idx, chat in enumerate(st.session_state.chat_history):
+             if chat.get("user"):
+                 message(chat["user"], is_user=True, allow_html=True, key=f"user_{idx}", avatar_style="initials", seed="user")
+             if chat.get("bot"):
+                 message(chat["bot"], is_user=False, allow_html=True, key=f"bot_{idx}", seed="bot")
+
+     # Input area and buttons for user interaction
+     with st.form(key="chat_form", clear_on_submit=True, border=False):
+
+         col1, col2, col3 = st.columns([1, 1, 1])
+         with col1:
+             summary_button = st.form_submit_button("Generate Summary")
+         with col2:
+             extract_button = st.form_submit_button("Extract Tables and Figures")
+         with col3:
+             st.form_submit_button("Clear message", on_click=on_btn_click)
+
+     # Handle responses based on user input and button presses
+     if summary_button:
+         with st.spinner("Generating summary..."):
+             summary = summarize_pdf(file_path)
+             st.session_state.chat_history.append({"user": "Generate Summary", "bot": summary})
+         st.rerun()
+
+     if extract_button:
+         with st.spinner("Extracting tables and figures..."):
+             figures, tables = process_pdf(file_path)
+             if figures:
+                 st.session_state.chat_history.append({"user": "Figures"})
+
+                 for idx, figure in enumerate(figures):
+                     figure_base64 = image_to_base64(figure)
+                     result_html = f'<img src="data:image/png;base64,{figure_base64}" style="width:100%; display:block;" alt="Figure {idx+1}"/>'
+                     st.session_state.chat_history.append({"bot": f"Figure {idx+1} {result_html}"})
+             if tables:
+                 st.session_state.chat_history.append({"user": "Tables"})
+                 for idx, table in enumerate(tables):
+                     table_base64 = image_to_base64(table)
+                     result_html = f'<img src="data:image/png;base64,{table_base64}" style="width:100%; display:block;" alt="Table {idx+1}"/>'
+                     st.session_state.chat_history.append({"bot": f"Table {idx+1} {result_html}"})
+         st.rerun()
+
+     if user_input:
+         st.session_state.chat_history.append({"user": user_input, "bot": None})
+         with st.spinner("Processing..."):
+             answer = qa_pdf(file_path, user_input)
+             st.session_state.chat_history[-1]["bot"] = answer
+         st.rerun()
+
+     # Additional CSS and JavaScript to ensure the chat container is scrollable and scrolls to the bottom
+     st.markdown("""
+     <style>
+     #chat-container {
+         max-height: 500px;
+         overflow-y: auto;
+         padding: 1rem;
+         border: 1px solid #ddd;
+         border-radius: 8px;
+         background-color: #fefefe;
+         box-shadow: 0 2px 4px rgba(0, 0, 0, 0.1);
+         transition: background-color 0.3s ease;
+     }
+     #chat-container:hover {
+         background-color: #f9f9f9;
+     }
+     .stChatMessage {
+         padding: 0.75rem;
+         margin: 0.75rem 0;
+         border-radius: 8px;
+         box-shadow: 0 1px 3px rgba(0, 0, 0, 0.1);
+         transition: background-color 0.3s ease;
+     }
+     .stChatMessage--user {
+         background-color: #E3F2FD;
+     }
+     .stChatMessage--user:hover {
+         background-color: #BBDEFB;
+     }
+     .stChatMessage--bot {
+         background-color: #EDE7F6;
+     }
+     .stChatMessage--bot:hover {
+         background-color: #D1C4E9;
+     }
+     textarea {
+         width: 100%;
+         padding: 1rem;
+         border: 1px solid #ddd;
+         border-radius: 8px;
+         box-shadow: inset 0 1px 3px rgba(0, 0, 0, 0.1);
+         transition: border-color 0.3s ease, box-shadow 0.3s ease;
+     }
+     textarea:focus {
+         border-color: #4CAF50;
+         box-shadow: 0 0 5px rgba(76, 175, 80, 0.5);
+     }
+     .stButton > button {
+         width: 100%;
+         background-color: #4CAF50;
+         color: white;
+         border: none;
+         border-radius: 8px;
+         padding: 0.75rem;
+         font-size: 16px;
+         box-shadow: 0 2px 4px rgba(0, 0, 0, 0.1);
+         transition: background-color 0.3s ease, box-shadow 0.3s ease;
+     }
+     .stButton > button:hover {
+         background-color: #45A049;
+         box-shadow: 0 4px 8px rgba(0, 0, 0, 0.1);
+     }
+     </style>
+     <script>
+     const chatContainer = document.getElementById('chat-container');
+     chatContainer.scrollTop = chatContainer.scrollHeight;
+     </script>
+     """, unsafe_allow_html=True)