zliang commited on
Commit
c6a9f47
·
verified ·
1 Parent(s): 5599ea4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +241 -302
app.py CHANGED
@@ -1,46 +1,42 @@
1
 
2
 
3
-
4
-
5
  import os
6
- os.system("python -m spacy download en_core_web_sm")
7
  import io
8
  import base64
9
- import streamlit as st
10
  import numpy as np
11
  import fitz # PyMuPDF
12
  import tempfile
13
- from ultralytics import YOLO
14
  from sklearn.cluster import KMeans
15
  from sklearn.metrics.pairwise import cosine_similarity
 
 
 
16
  from langchain_core.output_parsers import StrOutputParser
17
  from langchain_community.document_loaders import PyMuPDFLoader
18
- from langchain_openai import OpenAIEmbeddings
19
- from langchain_text_splitters import RecursiveCharacterTextSplitter
20
  from langchain_text_splitters import SpacyTextSplitter
21
  from langchain_core.prompts import ChatPromptTemplate
22
- from langchain_openai import ChatOpenAI
23
- import re
24
- from PIL import Image
25
- from streamlit_chat import message
26
-
27
- # Load the trained model
28
 
 
 
29
  model = YOLO("best.pt")
30
  openai_api_key = os.environ.get("openai_api_key")
31
-
32
- # Define the class indices for figures, tables, and text
33
- figure_class_index = 4
34
- table_class_index = 3
35
 
36
  # Utility functions
 
37
  def clean_text(text):
38
  return re.sub(r'\s+', ' ', text).strip()
39
 
40
  def remove_references(text):
41
  reference_patterns = [
42
- r'\bReferences\b', r'\breferences\b', r'\bBibliography\b', r'\bCitations\b',
43
- r'\bWorks Cited\b', r'\bReference\b', r'\breference\b'
44
  ]
45
  lines = text.split('\n')
46
  for i, line in enumerate(lines):
@@ -48,332 +44,275 @@ def remove_references(text):
48
  return '\n'.join(lines[:i])
49
  return text
50
 
51
- def save_uploaded_file(uploaded_file):
52
- temp_file = tempfile.NamedTemporaryFile(delete=False)
53
- temp_file.write(uploaded_file.getbuffer())
54
- temp_file.close()
55
- return temp_file.name
56
-
57
- def summarize_pdf(pdf_file_path, num_clusters=10):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
58
  embeddings_model = OpenAIEmbeddings(model="text-embedding-3-small", api_key=openai_api_key)
59
- llm = ChatOpenAI(model="gpt-4o-mini", api_key=openai_api_key, temperature=0.3)
 
60
  prompt = ChatPromptTemplate.from_template(
61
- """Could you please provide a concise and comprehensive summary of the given Contexts?
62
- The summary should capture the main points and key details of the text while conveying the author's intended meaning accurately.
63
- Please ensure that the summary is well-organized and easy to read, with clear headings and subheadings to guide the reader through each section.
64
- The length of the summary should be appropriate to capture the main points and key details of the text, without including unnecessary information or becoming overly long.
65
- example of summary:
66
- ## Summary:
67
- ## Key points:
68
- Contexts: {topic}"""
69
  )
70
- output_parser = StrOutputParser()
71
- chain = prompt | llm | output_parser
72
-
73
- loader = PyMuPDFLoader(pdf_file_path)
74
  docs = loader.load()
75
  full_text = "\n".join(doc.page_content for doc in docs)
76
  cleaned_full_text = clean_text(remove_references(full_text))
 
77
  text_splitter = SpacyTextSplitter(chunk_size=500)
78
- #text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=0, separators=["\n\n", "\n", ".", " "])
79
  split_contents = text_splitter.split_text(cleaned_full_text)
 
80
  embeddings = embeddings_model.embed_documents(split_contents)
 
 
 
 
 
 
81
 
82
- kmeans = KMeans(n_clusters=num_clusters, init='k-means++', random_state=0).fit(embeddings)
83
- closest_point_indices = [np.argmin(np.linalg.norm(embeddings - center, axis=1)) for center in kmeans.cluster_centers_]
84
- extracted_contents = [split_contents[idx] for idx in closest_point_indices]
85
-
86
- results = chain.invoke({"topic": ' '.join(extracted_contents)})
87
-
88
- return generate_citations(results, extracted_contents)
89
-
90
- def qa_pdf(pdf_file_path, query, num_clusters=5, similarity_threshold=0.6):
91
  embeddings_model = OpenAIEmbeddings(model="text-embedding-3-small", api_key=openai_api_key)
92
- llm = ChatOpenAI(model="gpt-4o-mini", api_key=openai_api_key, temperature=0.3)
 
93
  prompt = ChatPromptTemplate.from_template(
94
- """Please provide a detailed and accurate answer to the given question based on the provided contexts.
95
- Ensure that the answer is comprehensive and directly addresses the query.
96
- If necessary, include relevant examples or details from the text.
97
- Question: {question}
98
- Contexts: {contexts}"""
 
 
99
  )
100
- output_parser = StrOutputParser()
101
- chain = prompt | llm | output_parser
102
-
103
- loader = PyMuPDFLoader(pdf_file_path)
104
  docs = loader.load()
105
  full_text = "\n".join(doc.page_content for doc in docs)
106
  cleaned_full_text = clean_text(remove_references(full_text))
 
107
  text_splitter = SpacyTextSplitter(chunk_size=500)
108
-
109
- #text_splitter = RecursiveCharacterTextSplitter(chunk_size=800, chunk_overlap=0, separators=["\n\n", "\n", ".", " "])
110
  split_contents = text_splitter.split_text(cleaned_full_text)
111
- embeddings = embeddings_model.embed_documents(split_contents)
112
-
113
  query_embedding = embeddings_model.embed_query(query)
114
- similarity_scores = cosine_similarity([query_embedding], embeddings)[0]
115
- top_indices = np.argsort(similarity_scores)[-num_clusters:]
116
- relevant_contents = [split_contents[i] for i in top_indices]
117
-
118
- results = chain.invoke({"question": query, "contexts": ' '.join(relevant_contents)})
119
-
120
- return generate_citations(results, relevant_contents, similarity_threshold)
121
-
122
- def generate_citations(text, contents, similarity_threshold=0.6):
123
- embeddings_model = OpenAIEmbeddings(model="text-embedding-3-small", api_key=openai_api_key)
124
- text_sentences = re.split(r'(?<!\w\.\w.)(?<![A-Z][a-z]\.)(?<=\.|\?)\s', text)
125
- text_embeddings = embeddings_model.embed_documents(text_sentences)
126
- content_embeddings = embeddings_model.embed_documents(contents)
127
- similarity_matrix = cosine_similarity(text_embeddings, content_embeddings)
128
-
129
- cited_text = text
130
- relevant_sources = []
131
- source_mapping = {}
132
- sentence_to_source = {}
133
-
134
- for i, sentence in enumerate(text_sentences):
135
- if sentence in sentence_to_source:
136
- continue
137
- max_similarity = max(similarity_matrix[i])
138
- if max_similarity >= similarity_threshold:
139
- most_similar_idx = np.argmax(similarity_matrix[i])
140
- if most_similar_idx not in source_mapping:
141
- source_mapping[most_similar_idx] = len(relevant_sources) + 1
142
- relevant_sources.append((most_similar_idx, contents[most_similar_idx]))
143
- citation_idx = source_mapping[most_similar_idx]
144
- citation = f"([Source {citation_idx}](#source-{citation_idx}))"
145
- cited_sentence = re.sub(r'([.!?])$', f" {citation}\\1", sentence)
146
- sentence_to_source[sentence] = citation_idx
147
- cited_text = cited_text.replace(sentence, cited_sentence)
148
-
149
- sources_list = "\n\n## Sources:\n"
150
- for idx, (original_idx, content) in enumerate(relevant_sources):
151
- sources_list += f"""
152
- <details style="margin: 1px 0; padding: 5px; border: 1px solid #ccc; border-radius: 8px; background-color: #f9f9f9; transition: all 0.3s ease;">
153
- <summary style="font-weight: bold; cursor: pointer; outline: none; padding: 5px 0; transition: color 0.3s ease;">Source {idx + 1}</summary>
154
- <pre style="white-space: pre-wrap; word-wrap: break-word; margin: 1px 0; padding: 10px; background-color: #fff; border-radius: 5px; border: 1px solid #ddd; box-shadow: 0 2px 5px rgba(0, 0, 0, 0.1);">{content}</pre>
155
- </details>
156
- """
157
-
158
- # Add dummy blanks after the last source
159
- dummy_blanks = """
160
- <div style="margin: 20px 0;"></div>
161
- <div style="margin: 20px 0;"></div>
162
- <div style="margin: 20px 0;"></div>
163
- <div style="margin: 20px 0;"></div>
164
- <div style="margin: 20px 0;"></div>
165
- """
166
-
167
- cited_text += sources_list + dummy_blanks
168
- return cited_text
169
-
170
- def infer_image_and_get_boxes(image, confidence_threshold=0.8):
171
- results = model.predict(image)
172
- return [
173
- (int(box.xyxy[0][0]), int(box.xyxy[0][1]), int(box.xyxy[0][2]), int(box.xyxy[0][3]), int(box.cls[0]))
174
- for result in results for box in result.boxes
175
- if int(box.cls[0]) in {figure_class_index, table_class_index} and box.conf[0] > confidence_threshold
176
- ]
177
-
178
- def crop_images_from_boxes(image, boxes, scale_factor):
179
- figures = []
180
- tables = []
181
- for (x1, y1, x2, y2, cls) in boxes:
182
- cropped_img = image[int(y1 * scale_factor):int(y2 * scale_factor), int(x1 * scale_factor):int(x2 * scale_factor)]
183
- if cls == figure_class_index:
184
- figures.append(cropped_img)
185
- elif cls == table_class_index:
186
- tables.append(cropped_img)
187
- return figures, tables
188
-
189
- def process_pdf(pdf_file_path):
190
- doc = fitz.open(pdf_file_path)
191
- all_figures = []
192
- all_tables = []
193
- low_dpi = 50
194
- high_dpi = 300
195
- scale_factor = high_dpi / low_dpi
196
- low_res_pixmaps = [page.get_pixmap(dpi=low_dpi) for page in doc]
197
 
198
- for page_num, low_res_pix in enumerate(low_res_pixmaps):
199
- low_res_img = np.frombuffer(low_res_pix.samples, dtype=np.uint8).reshape(low_res_pix.height, low_res_pix.width, 3)
200
- boxes = infer_image_and_get_boxes(low_res_img)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
201
 
202
  if boxes:
203
- high_res_pix = doc[page_num].get_pixmap(dpi=high_dpi)
204
- high_res_img = np.frombuffer(high_res_pix.samples, dtype=np.uint8).reshape(high_res_pix.height, high_res_pix.width, 3)
205
- figures, tables = crop_images_from_boxes(high_res_img, boxes, scale_factor)
206
- all_figures.extend(figures)
207
- all_tables.extend(tables)
 
 
 
 
 
208
 
209
  return all_figures, all_tables
210
 
211
  def image_to_base64(img):
212
  buffered = io.BytesIO()
213
- img = Image.fromarray(img)
214
- img.save(buffered, format="PNG")
 
215
  return base64.b64encode(buffered.getvalue()).decode()
216
 
217
- def on_btn_click():
218
- del st.session_state.chat_history[:]
 
 
 
 
 
219
 
220
- # Streamlit interface
221
-
222
- # Custom CSS for the file uploader
223
- uploadercss='''
224
- <style>
225
- [data-testid='stFileUploader'] {
226
- width: max-content;
227
- }
228
- [data-testid='stFileUploader'] section {
229
- padding: 0;
230
- float: left;
231
- }
232
- [data-testid='stFileUploader'] section > input + div {
233
- display: none;
234
- }
235
- [data-testid='stFileUploader'] section + div {
236
- float: right;
237
- padding-top: 0;
238
- }
239
- </style>
240
- '''
241
-
242
- st.set_page_config(page_title="PDF Reading Assistant", page_icon="📄")
243
-
244
- # Initialize chat history in session state if not already present
245
  if 'chat_history' not in st.session_state:
246
  st.session_state.chat_history = []
 
 
247
 
248
- st.title("📄 PDF Reading Assistant")
249
- st.markdown("### Extract tables, figures, summaries, and answers from your PDF files easily.")
250
- chat_placeholder = st.empty()
 
 
 
 
 
 
 
 
 
251
 
252
- # File uploader for PDF
253
- uploaded_file = st.file_uploader("Upload a PDF", type="pdf")
254
- st.markdown(uploadercss, unsafe_allow_html=True)
255
- if uploaded_file:
256
- file_path = save_uploaded_file(uploaded_file)
 
 
 
 
 
257
 
258
- # Chat container where all messages will be displayed
 
 
 
 
259
  chat_container = st.container()
260
- user_input = st.chat_input("Ask a question about the pdf......", key="user_input")
261
  with chat_container:
262
- # Scrollable chat messages
263
  for idx, chat in enumerate(st.session_state.chat_history):
 
264
  if chat.get("user"):
265
- message(chat["user"], is_user=True, allow_html=True, key=f"user_{idx}", avatar_style="initials", seed="user")
 
266
  if chat.get("bot"):
267
- message(chat["bot"], is_user=False, allow_html=True, key=f"bot_{idx}",seed="bot")
268
-
269
- # Input area and buttons for user interaction
270
- with st.form(key="chat_form", clear_on_submit=True,border=False):
271
-
272
- col1, col2, col3 = st.columns([1, 1, 1])
273
- with col1:
274
- summary_button = st.form_submit_button("Generate Summary")
275
- with col2:
276
- extract_button = st.form_submit_button("Extract Tables and Figures")
277
- with col3:
278
- st.form_submit_button("Clear message", on_click=on_btn_click)
279
-
280
- # Handle responses based on user input and button presses
281
- if summary_button:
282
- with st.spinner("Generating summary..."):
283
  summary = summarize_pdf(file_path)
284
- st.session_state.chat_history.append({"user": "Generate Summary", "bot": summary})
285
- st.rerun()
286
-
287
- if extract_button:
288
- with st.spinner("Extracting tables and figures..."):
 
 
 
 
289
  figures, tables = process_pdf(file_path)
290
  if figures:
291
- st.session_state.chat_history.append({"user": "Figures"})
292
-
293
- for idx, figure in enumerate(figures):
294
- figure_base64 = image_to_base64(figure)
295
- result_html = f'<img src="data:image/png;base64,{figure_base64}" style="width:100%; display:block;" alt="Figure {idx+1}"/>'
296
- st.session_state.chat_history.append({"bot": f"Figure {idx+1} {result_html}"})
 
297
  if tables:
298
- st.session_state.chat_history.append({"user": "Tables"})
299
- for idx, table in enumerate(tables):
300
- table_base64 = image_to_base64(table)
301
- result_html = f'<img src="data:image/png;base64,{table_base64}" style="width:100%; display:block;" alt="Table {idx+1}"/>'
302
- st.session_state.chat_history.append({"bot": f"Table {idx+1} {result_html}"})
303
- st.rerun()
 
 
 
 
 
 
 
 
 
 
304
 
305
- if user_input:
306
- st.session_state.chat_history.append({"user": user_input, "bot": None})
307
- with st.spinner("Processing..."):
308
- answer = qa_pdf(file_path, user_input)
309
- st.session_state.chat_history[-1]["bot"] = answer
310
- st.rerun()
311
-
312
- # Additional CSS and JavaScript to ensure the chat container is scrollable and scrolls to the bottom
313
  st.markdown("""
314
- <style>
315
- #chat-container {
316
- max-height: 500px;
317
- overflow-y: auto;
318
- padding: 1rem;
319
- border: 1px solid #ddd;
320
- border-radius: 8px;
321
- background-color: #fefefe;
322
- box-shadow: 0 2px 4px rgba(0, 0, 0, 0.1);
323
- transition: background-color 0.3s ease;
324
- }
325
- #chat-container:hover {
326
- background-color: #f9f9f9;
327
- }
328
- .stChatMessage {
329
- padding: 0.75rem;
330
- margin: 0.75rem 0;
331
- border-radius: 8px;
332
- box-shadow: 0 1px 3px rgba(0, 0, 0, 0.1);
333
- transition: background-color 0.3s ease;
334
- }
335
- .stChatMessage--user {
336
- background-color: #E3F2FD;
337
- }
338
- .stChatMessage--user:hover {
339
- background-color: #BBDEFB;
340
- }
341
- .stChatMessage--bot {
342
- background-color: #EDE7F6;
343
- }
344
- .stChatMessage--bot:hover {
345
- background-color: #D1C4E9;
346
- }
347
- textarea {
348
- width: 100%;
349
- padding: 1rem;
350
- border: 1px solid #ddd;
351
- border-radius: 8px;
352
- box-shadow: inset 0 1px 3px rgba(0, 0, 0, 0.1);
353
- transition: border-color 0.3s ease, box-shadow 0.3s ease;
354
- }
355
- textarea:focus {
356
- border-color: #4CAF50;
357
- box-shadow: 0 0 5px rgba(76, 175, 80, 0.5);
358
- }
359
- .stButton > button {
360
- width: 100%;
361
- background-color: #4CAF50;
362
- color: white;
363
- border: none;
364
- border-radius: 8px;
365
- padding: 0.75rem;
366
- font-size: 16px;
367
- box-shadow: 0 2px 4px rgba(0, 0, 0, 0.1);
368
- transition: background-color 0.3s ease, box-shadow 0.3s ease;
369
- }
370
- .stButton > button:hover {
371
- background-color: #45A049;
372
- box-shadow: 0 4px 8px rgba(0, 0, 0, 0.1);
373
- }
374
- </style>
375
- <script>
376
- const chatContainer = document.getElementById('chat-container');
377
- chatContainer.scrollTop = chatContainer.scrollHeight;
378
- </script>
379
- """, unsafe_allow_html=True)
 
1
 
2
 
 
 
3
import base64
import functools
import io
import os
import re
import tempfile
import time

import fitz  # PyMuPDF
import numpy as np
import streamlit as st
from langchain_community.document_loaders import PyMuPDFLoader
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_openai import OpenAIEmbeddings, ChatOpenAI
from langchain_text_splitters import SpacyTextSplitter
from PIL import Image
from sklearn.cluster import KMeans
from sklearn.metrics.pairwise import cosine_similarity
from streamlit import runtime
from streamlit.runtime.scriptrunner import get_script_run_ctx
from streamlit_chat import message
from ultralytics import YOLO
 
 
 
 
24
 
25
# Initialize models and environment
import importlib.util

# Streamlit re-executes this whole script on every user interaction, so an
# unconditional os.system() here used to shell out (and potentially
# re-download the spaCy model) on every rerun. Only download when missing.
if importlib.util.find_spec("en_core_web_sm") is None:
    os.system("python -m spacy download en_core_web_sm")

model = YOLO("best.pt")  # trained layout-detection model (figures/tables)
openai_api_key = os.environ.get("openai_api_key")
MAX_FILE_SIZE = 50 * 1024 * 1024  # 50MB upload cap
 
 
 
30
 
31
  # Utility functions
32
def clean_text(text):
    """Collapse every run of whitespace in *text* to a single space and trim.

    The previous ``@st.cache_data(ttl=3600)`` wrapper was removed: caching
    required hashing and pickling the full input string on every call, which
    costs more than the single regex substitution it was meant to save.
    """
    return re.sub(r'\s+', ' ', text).strip()
35
 
36
  def remove_references(text):
37
  reference_patterns = [
38
+ r'\bReferences\b', r'\breferences\b', r'\bBibliography\b',
39
+ r'\bCitations\b', r'\bWorks Cited\b', r'\bReference\b'
40
  ]
41
  lines = text.split('\n')
42
  for i, line in enumerate(lines):
 
44
  return '\n'.join(lines[:i])
45
  return text
46
 
47
def handle_errors(func):
    """Decorator: surface any exception from *func* in the chat UI.

    On failure the error text is appended to the chat history and the script
    is rerun so the message renders; the call then effectively yields None.

    ``functools.wraps`` is essential here: ``@st.cache_data`` is stacked
    *outside* this decorator and derives its cache identity from the wrapped
    function. Without wraps(), every decorated function exposes the same
    ``wrapper`` name/code, so same-signature functions (e.g. summarize_pdf
    and process_pdf, both taking one path) could collide in the cache.
    """
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        try:
            return func(*args, **kwargs)
        except Exception as e:
            st.session_state.chat_history.append({
                "bot": f"❌ An error occurred: {str(e)}"
            })
            st.rerun()
    return wrapper
57
+
58
def show_progress(message, steps=100, delay=0.02):
    """Animate a purely cosmetic progress bar labelled with *message*.

    The bar fills over ``steps * delay`` seconds (defaults preserve the
    original 2-second animation) and is then cleared. NOTE(review): this is
    UI theater — it does not track real work and adds fixed latency; callers
    can pass smaller ``steps``/``delay`` to shorten the artificial wait.
    """
    progress_bar = st.progress(0)
    status_text = st.empty()
    for i in range(steps):
        time.sleep(delay)
        pct = int((i + 1) * 100 / steps)
        progress_bar.progress(pct)
        status_text.text(f"{message}... {pct}%")
    progress_bar.empty()
    status_text.empty()
67
+
68
def scroll_to_bottom():
    """Inject a JS snippet that scrolls Streamlit's main pane to its end."""
    # Guard clause: only inject when running inside a live Streamlit session.
    if get_script_run_ctx() is None or not runtime.exists():
        return
    js = """
    <script>
        function scrollToBottom() {
            window.parent.document.querySelector('section.main').scrollTo(0, window.parent.document.querySelector('section.main').scrollHeight);
        }
        setTimeout(scrollToBottom, 100);
    </script>
    """
    st.components.v1.html(js, height=0)
80
+
81
+ # Core processing functions
82
@st.cache_data(show_spinner=False, ttl=3600)
@handle_errors
def summarize_pdf(_pdf_file_path, num_clusters=10):
    """Summarize a PDF by clustering chunk embeddings.

    Splits the cleaned document into ~500-char chunks, embeds them, k-means
    clusters the embeddings, and summarizes the chunk nearest each centroid.

    Args:
        _pdf_file_path: path to the PDF (leading underscore skips st.cache_data hashing).
        num_clusters: desired number of representative chunks.
    Returns:
        The LLM-generated summary string.
    """
    embeddings_model = OpenAIEmbeddings(model="text-embedding-3-small", api_key=openai_api_key)
    llm = ChatOpenAI(model="gpt-4", api_key=openai_api_key, temperature=0.3)

    prompt = ChatPromptTemplate.from_template(
        """Generate a comprehensive summary with these elements:
        1. Key findings and conclusions
        2. Main methodologies used
        3. Important data points
        4. Limitations mentioned
        Context: {topic}"""
    )

    loader = PyMuPDFLoader(_pdf_file_path)
    docs = loader.load()
    full_text = "\n".join(doc.page_content for doc in docs)
    cleaned_full_text = clean_text(remove_references(full_text))

    text_splitter = SpacyTextSplitter(chunk_size=500)
    split_contents = text_splitter.split_text(cleaned_full_text)

    embeddings = embeddings_model.embed_documents(split_contents)
    # Guard: KMeans raises ValueError when n_clusters > n_samples, which
    # happens for short documents that split into fewer than num_clusters
    # chunks — clamp to the actual chunk count.
    n_clusters = max(1, min(num_clusters, len(split_contents)))
    kmeans = KMeans(n_clusters=n_clusters, random_state=0).fit(embeddings)
    closest_indices = [np.argmin(np.linalg.norm(embeddings - center, axis=1))
                       for center in kmeans.cluster_centers_]

    chain = prompt | llm | StrOutputParser()
    return chain.invoke({"topic": ' '.join([split_contents[idx] for idx in closest_indices])})
112
 
113
@st.cache_data(show_spinner=False, ttl=3600)
@handle_errors
def qa_pdf(_pdf_file_path, query, num_clusters=5):
    """Answer *query* from the PDF.

    Ranks every ~500-char chunk by cosine similarity to the query embedding
    and feeds the top ``num_clusters`` chunks to the LLM as context.
    """
    embeddings_model = OpenAIEmbeddings(model="text-embedding-3-small", api_key=openai_api_key)
    llm = ChatOpenAI(model="gpt-4", api_key=openai_api_key, temperature=0.3)

    prompt = ChatPromptTemplate.from_template(
        """Answer this question: {question}
        Using only this context: {context}
        Format your answer with:
        - Clear section headings
        - Bullet points for lists
        - Bold key terms
        - Citations from the text"""
    )

    # Load and normalize the document text.
    document_pages = PyMuPDFLoader(_pdf_file_path).load()
    raw_text = "\n".join(page.page_content for page in document_pages)
    normalized_text = clean_text(remove_references(raw_text))

    chunks = SpacyTextSplitter(chunk_size=500).split_text(normalized_text)

    # Score each chunk against the query and keep the best matches.
    query_embedding = embeddings_model.embed_query(query)
    similarities = cosine_similarity([query_embedding],
                                     embeddings_model.embed_documents(chunks))[0]
    top_indices = np.argsort(similarities)[-num_clusters:]

    chain = prompt | llm | StrOutputParser()
    return chain.invoke({
        "question": query,
        "context": ' '.join([chunks[i] for i in top_indices])
    })
147
+
148
@st.cache_data(show_spinner=False, ttl=3600)
@handle_errors
def process_pdf(_pdf_file_path):
    """Detect and crop figures (class 4) and tables (class 3) from a PDF.

    Each page is rendered at 50 DPI for fast YOLO inference; pages with hits
    are re-rendered at 300 DPI so the crops stay high resolution.

    Returns:
        (figures, tables): two lists of HxWx3 uint8 numpy arrays.
    """
    doc = fitz.open(_pdf_file_path)
    all_figures, all_tables = [], []
    scale_factor = 300 / 50  # High-res to low-res ratio

    try:
        for page in doc:
            # Force RGB with no alpha so the sample buffer is always exactly
            # 3 channels; grayscale/CMYK/alpha PDFs would otherwise make the
            # reshape below raise ValueError.
            low_res = page.get_pixmap(dpi=50, colorspace=fitz.csRGB, alpha=False)
            low_res_img = np.frombuffer(low_res.samples, dtype=np.uint8).reshape(low_res.height, low_res.width, 3)

            results = model.predict(low_res_img)
            boxes = [
                (int(box.xyxy[0][0]), int(box.xyxy[0][1]),
                 int(box.xyxy[0][2]), int(box.xyxy[0][3]), int(box.cls[0]))
                for result in results for box in result.boxes
                if box.conf[0] > 0.8 and int(box.cls[0]) in {3, 4}
            ]

            if boxes:
                high_res = page.get_pixmap(dpi=300, colorspace=fitz.csRGB, alpha=False)
                high_res_img = np.frombuffer(high_res.samples, dtype=np.uint8).reshape(high_res.height, high_res.width, 3)

                for (x1, y1, x2, y2, cls) in boxes:
                    cropped = high_res_img[int(y1 * scale_factor):int(y2 * scale_factor),
                                           int(x1 * scale_factor):int(x2 * scale_factor)]
                    if cls == 4:  # figure
                        all_figures.append(cropped)
                    else:  # table (cls == 3, the only other class admitted)
                        all_tables.append(cropped)
    finally:
        # Release the underlying file handle even if inference fails.
        doc.close()

    return all_figures, all_tables
180
 
181
def image_to_base64(img):
    """Encode a numpy image array as a base64 JPEG string (capped at 800x800)."""
    pil_img = Image.fromarray(img).convert("RGB")
    pil_img.thumbnail((800, 800))  # Optimize image size
    encoded = io.BytesIO()
    pil_img.save(encoded, format="JPEG", quality=85)
    return base64.b64encode(encoded.getvalue()).decode()
187
 
188
# Streamlit UI
st.set_page_config(
    page_title="PDF Assistant",
    page_icon="📄",
    layout="wide",
    initial_sidebar_state="expanded"
)

# Per-session state: the running chat transcript and the active upload name.
st.session_state.setdefault('chat_history', [])
st.session_state.setdefault('current_file', None)

st.title("📄 Smart PDF Analyzer")
st.markdown("""
<div style="border-left: 4px solid #4CAF50; padding-left: 1rem; margin: 1rem 0;">
<p style="color: #666; font-size: 0.95rem;">✨ Upload a PDF to:
<ul style="color: #666; font-size: 0.95rem;">
<li>Generate structured summaries</li>
<li>Extract visual content</li>
<li>Ask contextual questions</li>
</ul>
</p>
</div>
""", unsafe_allow_html=True)
213
 
214
# File picker; switching files wipes the existing conversation.
uploaded_file = st.file_uploader(
    "Choose PDF file",
    type="pdf",
    help="Max file size: 50MB",
    on_change=lambda: setattr(st.session_state, 'chat_history', [])
)

# Reject oversized uploads before any processing happens.
if uploaded_file is not None and uploaded_file.size > MAX_FILE_SIZE:
    st.error("File size exceeds 50MB limit")
    st.stop()
224
 
225
if uploaded_file:
    # Persist the upload to disk ONCE per uploaded file. The previous code
    # created a brand-new NamedTemporaryFile(delete=False) on every Streamlit
    # rerun (every click / keystroke), leaking an unclosed handle and an
    # undeleted temp file each time. Key the cached path on the file name,
    # which on_change-cleared chat history already tracks.
    if st.session_state.current_file != uploaded_file.name or 'file_path' not in st.session_state:
        with tempfile.NamedTemporaryFile(delete=False, suffix=".pdf") as tmp:
            tmp.write(uploaded_file.getbuffer())
            st.session_state.file_path = tmp.name
        st.session_state.current_file = uploaded_file.name
    file_path = st.session_state.file_path

    # Render the running transcript: user messages on the right, bot on the left.
    chat_container = st.container()
    with chat_container:
        for idx, chat in enumerate(st.session_state.chat_history):
            col1, col2 = st.columns([1, 4])
            if chat.get("user"):
                with col2:
                    message(chat["user"], is_user=True, key=f"user_{idx}")
            if chat.get("bot"):
                with col1:
                    message(chat["bot"], key=f"bot_{idx}", allow_html=True)
        scroll_to_bottom()

    # Action row: free-form question plus the two one-click analyses.
    with st.container():
        col1, col2, col3 = st.columns([3, 2, 2])
        with col1:
            user_input = st.chat_input("Ask about the document...")
        with col2:
            if st.button("📝 Generate Summary", use_container_width=True):
                with st.spinner("Analyzing document structure..."):
                    show_progress("Generating summary")
                    summary = summarize_pdf(file_path)
                    st.session_state.chat_history.append({
                        "user": "Summary request",
                        "bot": f"## Document Summary\n{summary}"
                    })
                st.rerun()
        with col3:
            if st.button("🖼️ Extract Visuals", use_container_width=True):
                with st.spinner("Identifying figures and tables..."):
                    show_progress("Extracting visuals")
                    figures, tables = process_pdf(file_path)
                    if figures:
                        st.session_state.chat_history.append({
                            "bot": f"Found {len(figures)} figures:"
                        })
                        for fig in figures:
                            st.session_state.chat_history.append({
                                "bot": f'<img src="data:image/jpeg;base64,{image_to_base64(fig)}" style="max-width: 100%;">'
                            })
                    if tables:
                        st.session_state.chat_history.append({
                            "bot": f"Found {len(tables)} tables:"
                        })
                        for tab in tables:
                            st.session_state.chat_history.append({
                                "bot": f'<img src="data:image/jpeg;base64,{image_to_base64(tab)}" style="max-width: 100%;">'
                            })
                st.rerun()

    # Free-form Q&A: append the question first, fill in the answer, rerun to render.
    if user_input:
        st.session_state.chat_history.append({"user": user_input})
        with st.spinner("Analyzing query..."):
            show_progress("Generating answer")
            answer = qa_pdf(file_path, user_input)
            st.session_state.chat_history[-1]["bot"] = f"## Answer\n{answer}"
        st.rerun()
286
 
 
 
 
 
 
 
 
 
287
# Global CSS overrides: hover-lift chat cards, gradient action buttons, and a
# dashed green border around the file uploader.
st.markdown("""
<style>
.stChatMessage {
    padding: 1.25rem;
    margin: 1rem 0;
    border-radius: 12px;
    box-shadow: 0 2px 8px rgba(0,0,0,0.1);
    transition: transform 0.2s ease;
}
.stChatMessage:hover {
    transform: translateY(-2px);
}
.stButton>button {
    background: linear-gradient(45deg, #4CAF50, #45a049);
    color: white;
    border: none;
    border-radius: 8px;
    padding: 12px 24px;
    font-size: 16px;
    transition: all 0.3s ease;
}
.stButton>button:hover {
    box-shadow: 0 4px 12px rgba(76,175,80,0.3);
    transform: translateY(-1px);
}
[data-testid="stFileUploader"] {
    border: 2px dashed #4CAF50;
    border-radius: 12px;
    padding: 2rem;
}
</style>
""", unsafe_allow_html=True)