import os

# Download the spaCy English model at startup (spaCy itself is not imported in
# this script).
os.system("python -m spacy download en_core_web_sm")

import base64
import io
import re
import tempfile

import fitz  # PyMuPDF
import numpy as np
import streamlit as st
from PIL import Image
from sklearn.cluster import KMeans
from sklearn.metrics.pairwise import cosine_similarity
from ultralytics import YOLO

from langchain_community.document_loaders import PyMuPDFLoader
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
from langchain_text_splitters import RecursiveCharacterTextSplitter

openai_api_key = os.environ.get("openai_api_key")
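
# A typical way to provide the key, matching the lowercase variable name used
# above (shell syntax assumed):
#
#   export openai_api_key="sk-..."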


@st.cache_resource
def load_models():
    """Load the heavy models once and share them across Streamlit reruns."""
    return {
        "yolo": YOLO("best.pt"),
        "embeddings": OpenAIEmbeddings(model="text-embedding-3-small", api_key=openai_api_key),
        "llm": ChatOpenAI(model="gpt-4-turbo", temperature=0.3, api_key=openai_api_key),
    }


models = load_models()


# Class indices produced by the YOLO model in best.pt, plus chunking and
# clustering parameters for the text pipeline.
FIGURE_CLASS_INDEX = 4
TABLE_CLASS_INDEX = 3
CHUNK_SIZE = 1000
CHUNK_OVERLAP = 200
NUM_CLUSTERS = 8


def clean_text(text):
    """Collapse runs of whitespace into single spaces."""
    return re.sub(r'\s+', ' ', text).strip()


def remove_references(text):
    """Strip reference-section headings (the heading words only, not the entries)."""
    reference_patterns = [
        r'\bReferences\b', r'\breferences\b', r'\bBibliography\b',
        r'\bCitations\b', r'\bWorks Cited\b'
    ]
    return re.sub('|'.join(reference_patterns), '', text, flags=re.IGNORECASE)


@st.cache_data
def process_pdf(file_path):
    """Process PDF once and cache results"""
    loader = PyMuPDFLoader(file_path)
    docs = loader.load()
    full_text = "\n".join(doc.page_content for doc in docs)
    cleaned_text = clean_text(remove_references(full_text))

    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=CHUNK_SIZE,
        chunk_overlap=CHUNK_OVERLAP,
        separators=["\n\n", "\n", ". ", "! ", "? ", " "]
    )
    split_contents = text_splitter.split_text(cleaned_text)

    return {
        "text": cleaned_text,
        "chunks": split_contents,
        "embeddings": models["embeddings"].embed_documents(split_contents)
    }
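
# st.cache_data keys its cache on the function arguments (the temp-file path
# here), so the embedding work is only skipped when the same path is passed in
# again.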


@st.cache_data
def extract_visuals(file_path):
    """Extract figures and tables with caching"""
    doc = fitz.open(file_path)
    all_figures = []
    all_tables = []

    for page in doc:
        # Run detection on a low-resolution render (50 dpi) to keep YOLO fast.
        low_res_pix = page.get_pixmap(dpi=50)
        low_res_img = np.frombuffer(low_res_pix.samples, dtype=np.uint8).reshape(
            low_res_pix.height, low_res_pix.width, 3
        )

        results = models["yolo"].predict(low_res_img)
        boxes = [
            (int(box.xyxy[0][0]), int(box.xyxy[0][1]),
             int(box.xyxy[0][2]), int(box.xyxy[0][3]), int(box.cls[0]))
            for result in results for box in result.boxes
            if box.conf[0] > 0.8 and int(box.cls[0]) in {FIGURE_CLASS_INDEX, TABLE_CLASS_INDEX}
        ]

        if boxes:
            # Re-render only pages with detections at 300 dpi and crop there;
            # box coordinates scale by 300 / 50 = 6.
            high_res_pix = page.get_pixmap(dpi=300)
            high_res_img = np.frombuffer(high_res_pix.samples, dtype=np.uint8).reshape(
                high_res_pix.height, high_res_pix.width, 3
            )

            for x1, y1, x2, y2, cls in boxes:
                img = high_res_img[y1 * 6:y2 * 6, x1 * 6:x2 * 6]
                if cls == FIGURE_CLASS_INDEX:
                    all_figures.append(img)
                else:
                    all_tables.append(img)

    return {"figures": all_figures, "tables": all_tables}


def generate_summary(chunks, embeddings):
    """Generate summary using clustered chunks"""
    embeddings = np.asarray(embeddings)
    # Guard against documents that split into fewer than NUM_CLUSTERS chunks.
    n_clusters = min(NUM_CLUSTERS, len(chunks))
    kmeans = KMeans(n_clusters=n_clusters, init='k-means++').fit(embeddings)
    # Use the chunk closest to each centroid as that cluster's representative.
    cluster_indices = [np.argmin(np.linalg.norm(embeddings - center, axis=1))
                       for center in kmeans.cluster_centers_]
    selected_chunks = [chunks[i] for i in cluster_indices]

    prompt = ChatPromptTemplate.from_template(
        """Create a structured summary with key points from these context sections:
{contexts}
Format:
## Summary
[concise overview]
## Key Points
- [main point 1]
- [main point 2]
..."""
    )
    chain = prompt | models["llm"] | StrOutputParser()
    return chain.invoke({"contexts": '\n\n'.join(selected_chunks)})
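
# Picking the chunk nearest each centroid yields a handful of representative
# passages spread across the document, so the summary prompt gets broad
# coverage without sending every chunk to the model.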


def answer_question(question, chunks, embeddings):
    """Answer question using semantic search"""
    query_embedding = models["embeddings"].embed_query(question)
    similarities = cosine_similarity([query_embedding], embeddings)[0]
    # Keep the five most similar chunks (best first) and drop weak matches.
    top_indices = np.argsort(similarities)[-5:][::-1]
    context = '\n'.join([chunks[i] for i in top_indices if similarities[i] > 0.6])

    prompt = ChatPromptTemplate.from_template(
        """Answer this question: {question}
Using only this context: {context}
- Be precise and include relevant details
- Cite sources as [Source 1], [Source 2], etc."""
    )
    chain = prompt | models["llm"] | StrOutputParser()
    return chain.invoke({"question": question, "context": context})
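
# Note: if no chunk clears the 0.6 similarity cutoff, `context` ends up empty
# and the model is asked to answer without any supporting passages.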


st.title("📄 Smart PDF Assistant")

if "chat" not in st.session_state:
    st.session_state.chat = []
if "processed_data" not in st.session_state:
    st.session_state.processed_data = None
if "visuals" not in st.session_state:
    st.session_state.visuals = None


with st.sidebar:
    uploaded_file = st.file_uploader("Upload PDF", type="pdf")
    # Only (re)process when a new file arrives; a fresh temp path is created on
    # every rerun, so unconditional processing would defeat the caches above.
    if uploaded_file and uploaded_file.name != st.session_state.get("active_file"):
        with tempfile.NamedTemporaryFile(delete=False, suffix=".pdf") as tmp:
            tmp.write(uploaded_file.getbuffer())
        st.session_state.processed_data = process_pdf(tmp.name)
        st.session_state.visuals = extract_visuals(tmp.name)
        st.session_state.active_file = uploaded_file.name


col1, col2 = st.columns([3, 1])

with col1:
    st.subheader("Document Interaction")

    for msg in st.session_state.chat:
        with st.chat_message(msg["role"]):
            if "image" in msg:
                st.image(msg["image"], caption=msg.get("caption"))
            else:
                st.markdown(msg["content"])

    if prompt := st.chat_input("Ask about the document..."):
        if st.session_state.processed_data is None:
            st.warning("Please upload a PDF first.")
        else:
            st.session_state.chat.append({"role": "user", "content": prompt})
            with st.spinner("Analyzing..."):
                response = answer_question(
                    prompt,
                    st.session_state.processed_data["chunks"],
                    st.session_state.processed_data["embeddings"]
                )
            st.session_state.chat.append({"role": "assistant", "content": response})
            st.rerun()


with col2:
    st.subheader("Document Insights")

    if st.button("Generate Summary"):
        if st.session_state.processed_data is None:
            st.warning("Please upload a PDF first.")
        else:
            with st.spinner("Summarizing..."):
                summary = generate_summary(
                    st.session_state.processed_data["chunks"],
                    st.session_state.processed_data["embeddings"]
                )
            st.session_state.chat.append({
                "role": "assistant",
                "content": f"## Document Summary\n{summary}"
            })
            st.rerun()

    visuals = st.session_state.visuals
    if visuals and visuals["figures"]:
        with st.expander(f"📷 Figures ({len(visuals['figures'])})"):
            for idx, fig in enumerate(visuals["figures"], 1):
                st.image(fig, caption=f"Figure {idx}")

    if visuals and visuals["tables"]:
        with st.expander(f"📊 Tables ({len(visuals['tables'])})"):
            for idx, tbl in enumerate(visuals["tables"], 1):
                st.image(tbl, caption=f"Table {idx}")
st.markdown(""" |
|
<style> |
|
[data-testid=stSidebar] { |
|
background: #fafafa; |
|
border-right: 1px solid #eee; |
|
} |
|
.stChatMessage { |
|
padding: 1rem; |
|
margin: 0.5rem 0; |
|
border-radius: 10px; |
|
box-shadow: 0 2px 5px rgba(0,0,0,0.1); |
|
} |
|
[data-testid=stVerticalBlock] > div:has(>.stChatMessage) { |
|
gap: 0.5rem; |
|
} |
|
</style> |
|
""", unsafe_allow_html=True) |