zliang committed · verified
Commit 89f2ae3 · 1 Parent(s): 0476da0

Update app.py

Files changed (1): app.py (+175, -332)
app.py CHANGED
@@ -1,7 +1,3 @@
-
-
-
-
  import os
  os.system("python -m spacy download en_core_web_sm")
  import io
@@ -17,21 +13,29 @@ from langchain_core.output_parsers import StrOutputParser
  from langchain_community.document_loaders import PyMuPDFLoader
  from langchain_openai import OpenAIEmbeddings
  from langchain_text_splitters import RecursiveCharacterTextSplitter
- from langchain_text_splitters import SpacyTextSplitter
  from langchain_core.prompts import ChatPromptTemplate
  from langchain_openai import ChatOpenAI
  import re
  from PIL import Image
- from streamlit_chat import message

- # Load the trained model

- model = YOLO("best.pt")
- openai_api_key = os.environ.get("openai_api_key")

- # Define the class indices for figures, tables, and text
- figure_class_index = 4
- table_class_index = 3

  # Utility functions
  def clean_text(text):
@@ -39,343 +43,182 @@ def clean_text(text):

  def remove_references(text):
      reference_patterns = [
-         r'\bReferences\b', r'\breferences\b', r'\bBibliography\b', r'\bCitations\b',
-         r'\bWorks Cited\b', r'\bReference\b', r'\breference\b'
      ]
-     lines = text.split('\n')
-     for i, line in enumerate(lines):
-         if any(re.search(pattern, line, re.IGNORECASE) for pattern in reference_patterns):
-             return '\n'.join(lines[:i])
-     return text
-
- def save_uploaded_file(uploaded_file):
-     temp_file = tempfile.NamedTemporaryFile(delete=False)
-     temp_file.write(uploaded_file.getbuffer())
-     temp_file.close()
-     return temp_file.name
-
- def summarize_pdf(pdf_file_path, num_clusters=10):
-     embeddings_model = OpenAIEmbeddings(model="text-embedding-3-small", api_key=openai_api_key)
-     llm = ChatOpenAI(model="gpt-4o-mini", api_key=openai_api_key, temperature=0.3)
-     prompt = ChatPromptTemplate.from_template(
-         """Could you please provide a concise and comprehensive summary of the given Contexts?
-         The summary should capture the main points and key details of the text while conveying the author's intended meaning accurately.
-         Please ensure that the summary is well-organized and easy to read, with clear headings and subheadings to guide the reader through each section.
-         The length of the summary should be appropriate to capture the main points and key details of the text, without including unnecessary information or becoming overly long.
-         example of summary:
-         ## Summary:
-         ## Key points:
-         Contexts: {topic}"""
-     )
-     output_parser = StrOutputParser()
-     chain = prompt | llm | output_parser

-     loader = PyMuPDFLoader(pdf_file_path)
      docs = loader.load()
      full_text = "\n".join(doc.page_content for doc in docs)
-     cleaned_full_text = clean_text(remove_references(full_text))
-     text_splitter = SpacyTextSplitter(chunk_size=500)
-     #text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=0, separators=["\n\n", "\n", ".", " "])
-     split_contents = text_splitter.split_text(cleaned_full_text)
-     embeddings = embeddings_model.embed_documents(split_contents)
-
-     kmeans = KMeans(n_clusters=num_clusters, init='k-means++', random_state=0).fit(embeddings)
-     closest_point_indices = [np.argmin(np.linalg.norm(embeddings - center, axis=1)) for center in kmeans.cluster_centers_]
-     extracted_contents = [split_contents[idx] for idx in closest_point_indices]
-
-     results = chain.invoke({"topic": ' '.join(extracted_contents)})
-
-     return generate_citations(results, extracted_contents)
-
- def qa_pdf(pdf_file_path, query, num_clusters=5, similarity_threshold=0.6):
-     embeddings_model = OpenAIEmbeddings(model="text-embedding-3-small", api_key=openai_api_key)
-     llm = ChatOpenAI(model="gpt-4o-mini", api_key=openai_api_key, temperature=0.3)
-     prompt = ChatPromptTemplate.from_template(
-         """Please provide a detailed and accurate answer to the given question based on the provided contexts.
-         Ensure that the answer is comprehensive and directly addresses the query.
-         If necessary, include relevant examples or details from the text.
-         Question: {question}
-         Contexts: {contexts}"""
      )
-     output_parser = StrOutputParser()
-     chain = prompt | llm | output_parser
-
-     loader = PyMuPDFLoader(pdf_file_path)
-     docs = loader.load()
-     full_text = "\n".join(doc.page_content for doc in docs)
-     cleaned_full_text = clean_text(remove_references(full_text))
-     text_splitter = SpacyTextSplitter(chunk_size=500)
-
-     #text_splitter = RecursiveCharacterTextSplitter(chunk_size=800, chunk_overlap=0, separators=["\n\n", "\n", ".", " "])
-     split_contents = text_splitter.split_text(cleaned_full_text)
-     embeddings = embeddings_model.embed_documents(split_contents)
-
-     query_embedding = embeddings_model.embed_query(query)
-     similarity_scores = cosine_similarity([query_embedding], embeddings)[0]
-     top_indices = np.argsort(similarity_scores)[-num_clusters:]
-     relevant_contents = [split_contents[i] for i in top_indices]
-
-     results = chain.invoke({"question": query, "contexts": ' '.join(relevant_contents)})
-
-     return generate_citations(results, relevant_contents, similarity_threshold)
-
- def generate_citations(text, contents, similarity_threshold=0.6):
-     embeddings_model = OpenAIEmbeddings(model="text-embedding-3-small", api_key=openai_api_key)
-     text_sentences = re.split(r'(?<!\w\.\w.)(?<![A-Z][a-z]\.)(?<=\.|\?)\s', text)
-     text_embeddings = embeddings_model.embed_documents(text_sentences)
-     content_embeddings = embeddings_model.embed_documents(contents)
-     similarity_matrix = cosine_similarity(text_embeddings, content_embeddings)
-
-     cited_text = text
-     relevant_sources = []
-     source_mapping = {}
-     sentence_to_source = {}
-
-     for i, sentence in enumerate(text_sentences):
-         if sentence in sentence_to_source:
-             continue
-         max_similarity = max(similarity_matrix[i])
-         if max_similarity >= similarity_threshold:
-             most_similar_idx = np.argmax(similarity_matrix[i])
-             if most_similar_idx not in source_mapping:
-                 source_mapping[most_similar_idx] = len(relevant_sources) + 1
-                 relevant_sources.append((most_similar_idx, contents[most_similar_idx]))
-             citation_idx = source_mapping[most_similar_idx]
-             citation = f"([Source {citation_idx}](#source-{citation_idx}))"
-             cited_sentence = re.sub(r'([.!?])$', f" {citation}\\1", sentence)
-             sentence_to_source[sentence] = citation_idx
-             cited_text = cited_text.replace(sentence, cited_sentence)
-
-     sources_list = "\n\n## Sources:\n"
-     for idx, (original_idx, content) in enumerate(relevant_sources):
-         sources_list += f"""
- <details style="margin: 1px 0; padding: 5px; border: 1px solid #ccc; border-radius: 8px; background-color: #f9f9f9; transition: all 0.3s ease;">
- <summary style="font-weight: bold; cursor: pointer; outline: none; padding: 5px 0; transition: color 0.3s ease;">Source {idx + 1}</summary>
- <pre style="white-space: pre-wrap; word-wrap: break-word; margin: 1px 0; padding: 10px; background-color: #fff; border-radius: 5px; border: 1px solid #ddd; box-shadow: 0 2px 5px rgba(0, 0, 0, 0.1);">{content}</pre>
- </details>
- """
-
-     # Add dummy blanks after the last source
-     dummy_blanks = """
- <div style="margin: 20px 0;"></div>
- <div style="margin: 20px 0;"></div>
- <div style="margin: 20px 0;"></div>
- <div style="margin: 20px 0;"></div>
- <div style="margin: 20px 0;"></div>
-
- """
-
-     cited_text += sources_list + dummy_blanks
-     return cited_text
-
- def infer_image_and_get_boxes(image, confidence_threshold=0.8):
-     results = model.predict(image)
-     return [
-         (int(box.xyxy[0][0]), int(box.xyxy[0][1]), int(box.xyxy[0][2]), int(box.xyxy[0][3]), int(box.cls[0]))
-         for result in results for box in result.boxes
-         if int(box.cls[0]) in {figure_class_index, table_class_index} and box.conf[0] > confidence_threshold
-     ]
-
- def crop_images_from_boxes(image, boxes, scale_factor):
-     figures = []
-     tables = []
-     for (x1, y1, x2, y2, cls) in boxes:
-         cropped_img = image[int(y1 * scale_factor):int(y2 * scale_factor), int(x1 * scale_factor):int(x2 * scale_factor)]
-         if cls == figure_class_index:
-             figures.append(cropped_img)
-         elif cls == table_class_index:
-             tables.append(cropped_img)
-     return figures, tables
 
- def process_pdf(pdf_file_path):
-     doc = fitz.open(pdf_file_path)
      all_figures = []
      all_tables = []
-     low_dpi = 50
-     high_dpi = 300
-     scale_factor = high_dpi / low_dpi
-     low_res_pixmaps = [page.get_pixmap(dpi=low_dpi) for page in doc]

-     for page_num, low_res_pix in enumerate(low_res_pixmaps):
          low_res_img = np.frombuffer(low_res_pix.samples, dtype=np.uint8).reshape(low_res_pix.height, low_res_pix.width, 3)
-         boxes = infer_image_and_get_boxes(low_res_img)

          if boxes:
-             high_res_pix = doc[page_num].get_pixmap(dpi=high_dpi)
              high_res_img = np.frombuffer(high_res_pix.samples, dtype=np.uint8).reshape(high_res_pix.height, high_res_pix.width, 3)
-             figures, tables = crop_images_from_boxes(high_res_img, boxes, scale_factor)
-             all_figures.extend(figures)
-             all_tables.extend(tables)

-     return all_figures, all_tables
-
- def image_to_base64(img):
-     buffered = io.BytesIO()
-     img = Image.fromarray(img)
-     img.save(buffered, format="PNG")
-     return base64.b64encode(buffered.getvalue()).decode()
-
- def on_btn_click():
-     del st.session_state.chat_history[:]
-
- # Streamlit interface

- # Custom CSS for the file uploader
- uploadercss='''
  <style>
- [data-testid='stFileUploader'] {
-     width: max-content;
- }
- [data-testid='stFileUploader'] section {
-     padding: 0;
-     float: left;
  }
- [data-testid='stFileUploader'] section > input + div {
-     display: none;
  }
- [data-testid='stFileUploader'] section + div {
-     float: right;
-     padding-top: 0;
  }
-
  </style>
- '''
-
- st.set_page_config(page_title="PDF Reading Assistant", page_icon="📄")
-
- # Initialize chat history in session state if not already present
- if 'chat_history' not in st.session_state:
-     st.session_state.chat_history = []
-
- st.title("📄 PDF Reading Assistant")
- st.markdown("### Extract tables, figures, summaries, and answers from your PDF files easily.")
- chat_placeholder = st.empty()
-
- # File uploader for PDF
- uploaded_file = st.file_uploader("Upload a PDF", type="pdf")
- st.markdown(uploadercss, unsafe_allow_html=True)
- if uploaded_file:
-     file_path = save_uploaded_file(uploaded_file)
-
-     # Chat container where all messages will be displayed
-     chat_container = st.container()
-     user_input = st.chat_input("Ask a question about the pdf......", key="user_input")
-     with chat_container:
-         # Scrollable chat messages
-         for idx, chat in enumerate(st.session_state.chat_history):
-             if chat.get("user"):
-                 message(chat["user"], is_user=True, allow_html=True, key=f"user_{idx}", avatar_style="initials", seed="user")
-             if chat.get("bot"):
-                 message(chat["bot"], is_user=False, allow_html=True, key=f"bot_{idx}",seed="bot")
-
-     # Input area and buttons for user interaction
-     with st.form(key="chat_form", clear_on_submit=True,border=False):
-
-         col1, col2, col3 = st.columns([1, 1, 1])
-         with col1:
-             summary_button = st.form_submit_button("Generate Summary")
-         with col2:
-             extract_button = st.form_submit_button("Extract Tables and Figures")
-         with col3:
-             st.form_submit_button("Clear message", on_click=on_btn_click)
-
-     # Handle responses based on user input and button presses
-     if summary_button:
-         with st.spinner("Generating summary..."):
-             summary = summarize_pdf(file_path)
-             st.session_state.chat_history.append({"user": "Generate Summary", "bot": summary})
-         st.rerun()
-
-     if extract_button:
-         with st.spinner("Extracting tables and figures..."):
-             figures, tables = process_pdf(file_path)
-             if figures:
-                 st.session_state.chat_history.append({"user": "Figures"})
-
-                 for idx, figure in enumerate(figures):
-                     figure_base64 = image_to_base64(figure)
-                     result_html = f'<img src="data:image/png;base64,{figure_base64}" style="width:100%; display:block;" alt="Figure {idx+1}"/>'
-                     st.session_state.chat_history.append({"bot": f"Figure {idx+1} {result_html}"})
-             if tables:
-                 st.session_state.chat_history.append({"user": "Tables"})
-                 for idx, table in enumerate(tables):
-                     table_base64 = image_to_base64(table)
-                     result_html = f'<img src="data:image/png;base64,{table_base64}" style="width:100%; display:block;" alt="Table {idx+1}"/>'
-                     st.session_state.chat_history.append({"bot": f"Table {idx+1} {result_html}"})
-         st.rerun()
-
-     if user_input:
-         st.session_state.chat_history.append({"user": user_input, "bot": None})
-         with st.spinner("Processing..."):
-             answer = qa_pdf(file_path, user_input)
-             st.session_state.chat_history[-1]["bot"] = answer
-         st.rerun()
-
- # Additional CSS and JavaScript to ensure the chat container is scrollable and scrolls to the bottom
- st.markdown("""
- <style>
- #chat-container {
-     max-height: 500px;
-     overflow-y: auto;
-     padding: 1rem;
-     border: 1px solid #ddd;
-     border-radius: 8px;
-     background-color: #fefefe;
-     box-shadow: 0 2px 4px rgba(0, 0, 0, 0.1);
-     transition: background-color 0.3s ease;
- }
- #chat-container:hover {
-     background-color: #f9f9f9;
- }
- .stChatMessage {
-     padding: 0.75rem;
-     margin: 0.75rem 0;
-     border-radius: 8px;
-     box-shadow: 0 1px 3px rgba(0, 0, 0, 0.1);
-     transition: background-color 0.3s ease;
- }
- .stChatMessage--user {
-     background-color: #E3F2FD;
- }
- .stChatMessage--user:hover {
-     background-color: #BBDEFB;
- }
- .stChatMessage--bot {
-     background-color: #EDE7F6;
- }
- .stChatMessage--bot:hover {
-     background-color: #D1C4E9;
- }
- textarea {
-     width: 100%;
-     padding: 1rem;
-     border: 1px solid #ddd;
-     border-radius: 8px;
-     box-shadow: inset 0 1px 3px rgba(0, 0, 0, 0.1);
-     transition: border-color 0.3s ease, box-shadow 0.3s ease;
- }
- textarea:focus {
-     border-color: #4CAF50;
-     box-shadow: 0 0 5px rgba(76, 175, 80, 0.5);
- }
- .stButton > button {
-     width: 100%;
-     background-color: #4CAF50;
-     color: white;
-     border: none;
-     border-radius: 8px;
-     padding: 0.75rem;
-     font-size: 16px;
-     box-shadow: 0 2px 4px rgba(0, 0, 0, 0.1);
-     transition: background-color 0.3s ease, box-shadow 0.3s ease;
- }
- .stButton > button:hover {
-     background-color: #45A049;
-     box-shadow: 0 4px 8px rgba(0, 0, 0, 0.1);
- }
- </style>
- <script>
- const chatContainer = document.getElementById('chat-container');
- chatContainer.scrollTop = chatContainer.scrollHeight;
- </script>
- """, unsafe_allow_html=True)

  import os
  os.system("python -m spacy download en_core_web_sm")
  import io

  from langchain_community.document_loaders import PyMuPDFLoader
  from langchain_openai import OpenAIEmbeddings
  from langchain_text_splitters import RecursiveCharacterTextSplitter
  from langchain_core.prompts import ChatPromptTemplate
  from langchain_openai import ChatOpenAI
  import re
  from PIL import Image

+ # Cached resources
+ @st.cache_resource
+ def load_models():
+     return {
+         "yolo": YOLO("best.pt"),
+         "embeddings": OpenAIEmbeddings(model="text-embedding-3-small"),
+         "llm": ChatOpenAI(model="gpt-4-turbo", temperature=0.3)
+     }

+ models = load_models()
+ openai_api_key = os.environ.get("OPENAI_API_KEY")

+ # Constants
+ FIGURE_CLASS_INDEX = 4
+ TABLE_CLASS_INDEX = 3
+ CHUNK_SIZE = 1000
+ CHUNK_OVERLAP = 200
+ NUM_CLUSTERS = 8

  # Utility functions
  def clean_text(text):

  def remove_references(text):
      reference_patterns = [
+         r'\bReferences\b', r'\breferences\b', r'\bBibliography\b',
+         r'\bCitations\b', r'\bWorks Cited\b'
      ]
+     return re.sub('|'.join(reference_patterns), '', text, flags=re.IGNORECASE)

+ @st.cache_data
+ def process_pdf(file_path):
+     """Process PDF once and cache results"""
+     loader = PyMuPDFLoader(file_path)
      docs = loader.load()
      full_text = "\n".join(doc.page_content for doc in docs)
+     cleaned_text = clean_text(remove_references(full_text))
+
+     text_splitter = RecursiveCharacterTextSplitter(
+         chunk_size=CHUNK_SIZE,
+         chunk_overlap=CHUNK_OVERLAP,
+         separators=["\n\n", "\n", ". ", "! ", "? ", " "]
      )
+     split_contents = text_splitter.split_text(cleaned_text)
+
+     return {
+         "text": cleaned_text,
+         "chunks": split_contents,
+         "embeddings": models["embeddings"].embed_documents(split_contents)
+     }

+ @st.cache_data
+ def extract_visuals(file_path):
+     """Extract figures and tables with caching"""
+     doc = fitz.open(file_path)
      all_figures = []
      all_tables = []

+     for page in doc:
+         low_res_pix = page.get_pixmap(dpi=50)
          low_res_img = np.frombuffer(low_res_pix.samples, dtype=np.uint8).reshape(low_res_pix.height, low_res_pix.width, 3)
+
+         results = models["yolo"].predict(low_res_img)
+         boxes = [
+             (int(box.xyxy[0][0]), int(box.xyxy[0][1]),
+              int(box.xyxy[0][2]), int(box.xyxy[0][3]), int(box.cls[0]))
+             for result in results for box in result.boxes
+             if box.conf[0] > 0.8 and int(box.cls[0]) in {FIGURE_CLASS_INDEX, TABLE_CLASS_INDEX}
+         ]

          if boxes:
+             high_res_pix = page.get_pixmap(dpi=300)
              high_res_img = np.frombuffer(high_res_pix.samples, dtype=np.uint8).reshape(high_res_pix.height, high_res_pix.width, 3)
+
+             for x1, y1, x2, y2, cls in boxes:
+                 img = high_res_img[int(y1*6):int(y2*6), int(x1*6):int(x2*6)]
+                 if cls == FIGURE_CLASS_INDEX:
+                     all_figures.append(img)
+                 else:
+                     all_tables.append(img)

+     return {"figures": all_figures, "tables": all_tables}
+
+ def generate_summary(chunks, embeddings):
+     """Generate summary using clustered chunks"""
+     kmeans = KMeans(n_clusters=NUM_CLUSTERS, init='k-means++').fit(embeddings)
+     cluster_indices = [np.argmin(np.linalg.norm(embeddings - center, axis=1))
+                        for center in kmeans.cluster_centers_]
+     selected_chunks = [chunks[i] for i in cluster_indices]
+
+     prompt = ChatPromptTemplate.from_template(
+         """Create a structured summary with key points from these context sections:
+         {contexts}
+         Format:
+         ## Summary
+         [concise overview]
+         ## Key Points
+         - [main point 1]
+         - [main point 2]
+         ..."""
+     )
+     chain = prompt | models["llm"] | StrOutputParser()
+     return chain.invoke({"contexts": '\n\n'.join(selected_chunks)})
+
+ def answer_question(question, chunks, embeddings):
+     """Answer question using semantic search"""
+     query_embedding = models["embeddings"].embed_query(question)
+     similarities = cosine_similarity([query_embedding], embeddings)[0]
+     top_indices = np.argsort(similarities)[-5:][::-1]
+     context = '\n'.join([chunks[i] for i in top_indices if similarities[i] > 0.6])
+
+     prompt = ChatPromptTemplate.from_template(
+         """Answer this question: {question}
+         Using only this context: {context}
+         - Be precise and include relevant details
+         - Cite sources as [Source 1], [Source 2], etc."""
+     )
+     chain = prompt | models["llm"] | StrOutputParser()
+     return chain.invoke({"question": question, "context": context})
+
+ # Streamlit UI
+ st.set_page_config(page_title="PDF Assistant", layout="wide")
+ st.title("📄 Smart PDF Assistant")
+
+ if "chat" not in st.session_state:
+     st.session_state.chat = []
+ if "processed_data" not in st.session_state:
+     st.session_state.processed_data = None
+
+ # File upload section
+ with st.sidebar:
+     uploaded_file = st.file_uploader("Upload PDF", type="pdf")
+     if uploaded_file:
+         with tempfile.NamedTemporaryFile(delete=False) as tmp:
+             tmp.write(uploaded_file.getbuffer())
+         st.session_state.processed_data = process_pdf(tmp.name)
+         visuals = extract_visuals(tmp.name)
+
+ # Chat interface
+ col1, col2 = st.columns([3, 1])
+ with col1:
+     st.subheader("Document Interaction")
+
+     for msg in st.session_state.chat:
+         with st.chat_message(msg["role"]):
+             if "image" in msg:
+                 st.image(msg["image"], caption=msg.get("caption"))
+             else:
+                 st.markdown(msg["content"])
+
+     if prompt := st.chat_input("Ask about the document..."):
+         st.session_state.chat.append({"role": "user", "content": prompt})
+         with st.spinner("Analyzing..."):
+             response = answer_question(
+                 prompt,
+                 st.session_state.processed_data["chunks"],
+                 st.session_state.processed_data["embeddings"]
+             )
+         st.session_state.chat.append({"role": "assistant", "content": response})
+         st.rerun()
+
+ with col2:
+     st.subheader("Document Insights")
+
+     if st.button("Generate Summary"):
+         with st.spinner("Summarizing..."):
+             summary = generate_summary(
+                 st.session_state.processed_data["chunks"],
+                 st.session_state.processed_data["embeddings"]
+             )
+         st.session_state.chat.append({
+             "role": "assistant",
+             "content": f"## Document Summary\n{summary}"
+         })
+         st.rerun()
+
+     if visuals["figures"]:
+         with st.expander(f"📷 Figures ({len(visuals['figures'])})"):
+             for idx, fig in enumerate(visuals["figures"], 1):
+                 st.image(fig, caption=f"Figure {idx}")
+
+     if visuals["tables"]:
+         with st.expander(f"📊 Tables ({len(visuals['tables'])})"):
+             for idx, tbl in enumerate(visuals["tables"], 1):
+                 st.image(tbl, caption=f"Table {idx}")

+ # Custom styling
+ st.markdown("""
  <style>
+ [data-testid=stSidebar] {
+     background: #fafafa;
+     border-right: 1px solid #eee;
  }
+ .stChatMessage {
+     padding: 1rem;
+     margin: 0.5rem 0;
+     border-radius: 10px;
+     box-shadow: 0 2px 5px rgba(0,0,0,0.1);
  }
+ [data-testid=stVerticalBlock] > div:has(>.stChatMessage) {
+     gap: 0.5rem;
  }
  </style>
+ """, unsafe_allow_html=True)