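"""Smart PDF Assistant.

Streamlit app: upload a PDF, then
- extract figures and tables with a custom YOLO detector (weights in best.pt),
- chunk and embed the text with OpenAI embeddings,
- answer questions via semantic search over the chunks,
- generate a cluster-based summary of the document.

Run with `streamlit run <this file>`; the OpenAI key is read from the
`openai_api_key` environment variable.
"""
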
import os
import re
import tempfile

import fitz  # PyMuPDF
import numpy as np
import streamlit as st
from sklearn.cluster import KMeans
from sklearn.metrics.pairwise import cosine_similarity
from ultralytics import YOLO

from langchain_community.document_loaders import PyMuPDFLoader
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
from langchain_text_splitters import RecursiveCharacterTextSplitter

# API key comes from the environment (e.g. a deployment secret)
openai_api_key = os.environ.get("openai_api_key")
# Cached resources
@st.cache_resource
def load_models():
    return {
        "yolo": YOLO("best.pt"),
        "embeddings": OpenAIEmbeddings(model="text-embedding-3-small", api_key=openai_api_key),
        "llm": ChatOpenAI(model="gpt-4-turbo", temperature=0.3, api_key=openai_api_key)
    }

models = load_models()


# Constants
FIGURE_CLASS_INDEX = 4  # class id for figures in the custom YOLO weights (best.pt)
TABLE_CLASS_INDEX = 3   # class id for tables in the custom YOLO weights (best.pt)
CHUNK_SIZE = 1000       # characters per text chunk
CHUNK_OVERLAP = 200     # characters shared between consecutive chunks
NUM_CLUSTERS = 8        # representative chunks sampled for the summary

# Utility functions
def clean_text(text):
    return re.sub(r'\s+', ' ', text).strip()

def remove_references(text):
    """Strip common reference-section headings (heading words only, case-insensitive)."""
    reference_patterns = [
        r'\bReferences\b', r'\bBibliography\b',
        r'\bCitations\b', r'\bWorks Cited\b'
    ]
    return re.sub('|'.join(reference_patterns), '', text, flags=re.IGNORECASE)

@st.cache_data
def process_pdf(file_path):
    """Process PDF once and cache results"""
    loader = PyMuPDFLoader(file_path)
    docs = loader.load()
    full_text = "\n".join(doc.page_content for doc in docs)
    cleaned_text = clean_text(remove_references(full_text))
    
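    # Overlapping chunks reduce the chance of cutting an idea across retrieval units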
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=CHUNK_SIZE,
        chunk_overlap=CHUNK_OVERLAP,
        separators=["\n\n", "\n", ". ", "! ", "? ", " "]
    )
    split_contents = text_splitter.split_text(cleaned_text)
    
    return {
        "text": cleaned_text,
        "chunks": split_contents,
        "embeddings": models["embeddings"].embed_documents(split_contents)
    }

@st.cache_data
def extract_visuals(file_path):
    """Extract figures and tables with caching"""
    doc = fitz.open(file_path)
    all_figures = []
    all_tables = []
    
    for page in doc:
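        # Detect on a cheap 50-dpi render; only pages with hits get a 300-dpi re-render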
        low_res_pix = page.get_pixmap(dpi=50)
        low_res_img = np.frombuffer(low_res_pix.samples, dtype=np.uint8).reshape(low_res_pix.height, low_res_pix.width, 3)
        
        results = models["yolo"].predict(low_res_img)
        boxes = [
            (int(box.xyxy[0][0]), int(box.xyxy[0][1]), 
             int(box.xyxy[0][2]), int(box.xyxy[0][3]), int(box.cls[0]))
            for result in results for box in result.boxes
            if box.conf[0] > 0.8 and int(box.cls[0]) in {FIGURE_CLASS_INDEX, TABLE_CLASS_INDEX}
        ]
        
        if boxes:
            high_res_pix = page.get_pixmap(dpi=300)
            high_res_img = np.frombuffer(high_res_pix.samples, dtype=np.uint8).reshape(high_res_pix.height, high_res_pix.width, 3)
            
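            # Scale 50-dpi box coordinates up to the 300-dpi render (300 / 50 = 6)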
            for x1, y1, x2, y2, cls in boxes:
                img = high_res_img[int(y1*6):int(y2*6), int(x1*6):int(x2*6)]
                if cls == FIGURE_CLASS_INDEX:
                    all_figures.append(img)
                else:
                    all_tables.append(img)
    
    return {"figures": all_figures, "tables": all_tables}

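# Summarization strategy: cluster the chunk embeddings and summarize one
# representative chunk per cluster instead of feeding the whole document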
def generate_summary(chunks, embeddings):
    """Generate summary using clustered chunks"""
    embeddings = np.asarray(embeddings)
    n_clusters = min(NUM_CLUSTERS, len(chunks))  # guard: short documents may yield fewer than 8 chunks
    kmeans = KMeans(n_clusters=n_clusters, init='k-means++').fit(embeddings)
    cluster_indices = [np.argmin(np.linalg.norm(embeddings - center, axis=1))
                       for center in kmeans.cluster_centers_]
    selected_chunks = [chunks[i] for i in cluster_indices]
    
    prompt = ChatPromptTemplate.from_template(
        """Create a structured summary with key points from these context sections:
        {contexts}
        Format:
        ## Summary
        [concise overview]
        ## Key Points
        - [main point 1]
        - [main point 2]
        ..."""
    )
    chain = prompt | models["llm"] | StrOutputParser()
    return chain.invoke({"contexts": '\n\n'.join(selected_chunks)})

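# Q&A strategy: embed the query once, rank every chunk against it by cosine
# similarity, and answer from the top matches only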
def answer_question(question, chunks, embeddings):
    """Answer question using semantic search"""
    query_embedding = models["embeddings"].embed_query(question)
    similarities = cosine_similarity([query_embedding], embeddings)[0]
    top_indices = np.argsort(similarities)[-5:][::-1]
    # Label each retrieved chunk so the model can actually cite it as [Source n]
    context = '\n\n'.join(f"[Source {rank}] {chunks[i]}"
                          for rank, i in enumerate(top_indices, 1)
                          if similarities[i] > 0.6)
    
    prompt = ChatPromptTemplate.from_template(
        """Answer this question: {question}
        Using only this context: {context}
        - Be precise and include relevant details
        - Cite sources as [Source 1], [Source 2], etc."""
    )
    chain = prompt | models["llm"] | StrOutputParser()
    return chain.invoke({"question": question, "context": context})

# Streamlit UI
#st.set_page_config(page_title="PDF Assistant", layout="wide")
st.title("📄 Smart PDF Assistant")

if "chat" not in st.session_state:
    st.session_state.chat = []
if "processed_data" not in st.session_state:
    st.session_state.processed_data = None
if "visuals" not in st.session_state:
    st.session_state.visuals = None

# File upload section
with st.sidebar:
    uploaded_file = st.file_uploader("Upload PDF", type="pdf")
    if uploaded_file:
        with tempfile.NamedTemporaryFile(delete=False, suffix=".pdf") as tmp:
            tmp.write(uploaded_file.getbuffer())
        # Parse only after the context manager has flushed and closed the file;
        # session_state keeps the results across Streamlit reruns
        st.session_state.processed_data = process_pdf(tmp.name)
        st.session_state.visuals = extract_visuals(tmp.name)

# Chat interface
col1, col2 = st.columns([3, 1])
with col1:
    st.subheader("Document Interaction")
    
    for msg in st.session_state.chat:
        with st.chat_message(msg["role"]):
            if "image" in msg:
                st.image(msg["image"], caption=msg.get("caption"))
            else:
                st.markdown(msg["content"])

    if prompt := st.chat_input("Ask about the document..."):
        if st.session_state.processed_data is None:
            st.warning("Upload a PDF before asking questions.")
        else:
            st.session_state.chat.append({"role": "user", "content": prompt})
            with st.spinner("Analyzing..."):
                response = answer_question(
                    prompt,
                    st.session_state.processed_data["chunks"],
                    st.session_state.processed_data["embeddings"]
                )
            st.session_state.chat.append({"role": "assistant", "content": response})
            st.rerun()

with col2:
    st.subheader("Document Insights")
    
    if st.button("Generate Summary"):
        if st.session_state.processed_data is None:
            st.warning("Upload a PDF before generating a summary.")
        else:
            with st.spinner("Summarizing..."):
                summary = generate_summary(
                    st.session_state.processed_data["chunks"],
                    st.session_state.processed_data["embeddings"]
                )
                st.session_state.chat.append({
                    "role": "assistant",
                    "content": f"## Document Summary\n{summary}"
                })
            st.rerun()
    
    visuals = st.session_state.visuals
    if visuals and visuals["figures"]:
        with st.expander(f"📷 Figures ({len(visuals['figures'])})"):
            for idx, fig in enumerate(visuals["figures"], 1):
                st.image(fig, caption=f"Figure {idx}")

    if visuals and visuals["tables"]:
        with st.expander(f"📊 Tables ({len(visuals['tables'])})"):
            for idx, tbl in enumerate(visuals["tables"], 1):
                st.image(tbl, caption=f"Table {idx}")

# Custom styling
st.markdown("""
<style>
    [data-testid=stSidebar] {
        background: #fafafa;
        border-right: 1px solid #eee;
    }
    .stChatMessage {
        padding: 1rem;
        margin: 0.5rem 0;
        border-radius: 10px;
        box-shadow: 0 2px 5px rgba(0,0,0,0.1);
    }
    [data-testid=stVerticalBlock] > div:has(>.stChatMessage) {
        gap: 0.5rem;
    }
</style>
""", unsafe_allow_html=True)