import os
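# Download the small English spaCy model at startup; spaCy itself is never
# imported in this file, so this presumably serves a transitive dependency.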
os.system("python -m spacy download en_core_web_sm")
import io
import base64
import streamlit as st
import numpy as np
import fitz # PyMuPDF
import tempfile
from ultralytics import YOLO
from sklearn.cluster import KMeans
from sklearn.metrics.pairwise import cosine_similarity
from langchain_core.output_parsers import StrOutputParser
from langchain_community.document_loaders import PyMuPDFLoader
from langchain_openai import OpenAIEmbeddings
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_core.prompts import ChatPromptTemplate
from langchain_openai import ChatOpenAI
import re
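# NOTE: the key is read from an environment variable literally named
# "openai_api_key" (lowercase); the deployment secret must use that exact name.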
openai_api_key = os.environ.get("openai_api_key")
# Cached resources
@st.cache_resource
def load_models():
    """Load the YOLO layout detector and the OpenAI clients once per session."""
    return {
        "yolo": YOLO("best.pt"),  # local layout-detection weights
        "embeddings": OpenAIEmbeddings(model="text-embedding-3-small", api_key=openai_api_key),
        "llm": ChatOpenAI(model="gpt-4-turbo", temperature=0.3, api_key=openai_api_key),
    }
models = load_models()
# Constants
FIGURE_CLASS_INDEX = 4  # class id for figures in the YOLO label map
TABLE_CLASS_INDEX = 3   # class id for tables in the YOLO label map
CHUNK_SIZE = 1000       # characters per text chunk
CHUNK_OVERLAP = 200     # characters shared between consecutive chunks
NUM_CLUSTERS = 8        # k for the KMeans chunk selection in summaries
# Utility functions
def clean_text(text):
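    """Collapse all whitespace runs into single spaces and trim the ends."""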
return re.sub(r'\s+', ' ', text).strip()
def remove_references(text):
    """Strip common reference-section headings (the headings only, not the entries)."""
    reference_patterns = [
        r'\bReferences\b', r'\bBibliography\b',
        r'\bCitations\b', r'\bWorks Cited\b'
    ]
    # flags=re.IGNORECASE makes a separate lowercase pattern unnecessary.
    return re.sub('|'.join(reference_patterns), '', text, flags=re.IGNORECASE)
@st.cache_data
def process_pdf(file_path):
"""Process PDF once and cache results"""
loader = PyMuPDFLoader(file_path)
docs = loader.load()
full_text = "\n".join(doc.page_content for doc in docs)
cleaned_text = clean_text(remove_references(full_text))
text_splitter = RecursiveCharacterTextSplitter(
chunk_size=CHUNK_SIZE,
chunk_overlap=CHUNK_OVERLAP,
separators=["\n\n", "\n", ". ", "! ", "? ", " "]
)
split_contents = text_splitter.split_text(cleaned_text)
return {
"text": cleaned_text,
"chunks": split_contents,
"embeddings": models["embeddings"].embed_documents(split_contents)
}
@st.cache_data
def extract_visuals(file_path):
"""Extract figures and tables with caching"""
doc = fitz.open(file_path)
all_figures = []
all_tables = []
    for page in doc:
        # First pass: render at low resolution so YOLO inference stays cheap.
        low_res_pix = page.get_pixmap(dpi=50)
        low_res_img = np.frombuffer(low_res_pix.samples, dtype=np.uint8).reshape(
            low_res_pix.height, low_res_pix.width, 3)
        results = models["yolo"].predict(low_res_img)
        boxes = [
            (int(box.xyxy[0][0]), int(box.xyxy[0][1]),
             int(box.xyxy[0][2]), int(box.xyxy[0][3]), int(box.cls[0]))
            for result in results for box in result.boxes
            if box.conf[0] > 0.8 and int(box.cls[0]) in {FIGURE_CLASS_INDEX, TABLE_CLASS_INDEX}
        ]
        if boxes:
            # Second pass: re-render only pages with detections at high
            # resolution, then scale the 50-dpi boxes up by 300 / 50 = 6.
            high_res_pix = page.get_pixmap(dpi=300)
            high_res_img = np.frombuffer(high_res_pix.samples, dtype=np.uint8).reshape(
                high_res_pix.height, high_res_pix.width, 3)
            scale = 300 // 50
            for x1, y1, x2, y2, cls in boxes:
                img = high_res_img[y1 * scale:y2 * scale, x1 * scale:x2 * scale]
                if cls == FIGURE_CLASS_INDEX:
                    all_figures.append(img)
                else:
                    all_tables.append(img)
return {"figures": all_figures, "tables": all_tables}
def generate_summary(chunks, embeddings):
    """Generate summary using clustered chunks"""
    embeddings = np.array(embeddings)
    # Guard against documents that split into fewer chunks than NUM_CLUSTERS,
    # then pick the chunk nearest each centroid as a representative.
    kmeans = KMeans(n_clusters=min(NUM_CLUSTERS, len(chunks)), init='k-means++').fit(embeddings)
    cluster_indices = [np.argmin(np.linalg.norm(embeddings - center, axis=1))
                       for center in kmeans.cluster_centers_]
    selected_chunks = [chunks[i] for i in cluster_indices]
prompt = ChatPromptTemplate.from_template(
"""Create a structured summary with key points from these context sections:
{contexts}
Format:
## Summary
[concise overview]
## Key Points
- [main point 1]
- [main point 2]
..."""
)
chain = prompt | models["llm"] | StrOutputParser()
return chain.invoke({"contexts": '\n\n'.join(selected_chunks)})
def answer_question(question, chunks, embeddings):
    """Answer question using semantic search"""
    query_embedding = models["embeddings"].embed_query(question)
    similarities = cosine_similarity([query_embedding], embeddings)[0]
    # Keep the five most similar chunks above a minimum similarity, labelled
    # so the model can cite them as [Source n] per the prompt below.
    top_indices = np.argsort(similarities)[-5:][::-1]
    relevant = [i for i in top_indices if similarities[i] > 0.6]
    if not relevant:
        return "I couldn't find a passage in the document relevant to that question."
    context = '\n'.join(f"[Source {rank}] {chunks[i]}"
                        for rank, i in enumerate(relevant, 1))
prompt = ChatPromptTemplate.from_template(
"""Answer this question: {question}
Using only this context: {context}
- Be precise and include relevant details
- Cite sources as [Source 1], [Source 2], etc."""
)
chain = prompt | models["llm"] | StrOutputParser()
return chain.invoke({"question": question, "context": context})
# Streamlit UI
#st.set_page_config(page_title="PDF Assistant", layout="wide")
st.title("📄 Smart PDF Assistant")
if "chat" not in st.session_state:
st.session_state.chat = []
if "processed_data" not in st.session_state:
st.session_state.processed_data = None
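# processed_data holds {"text", "chunks", "embeddings"} and visuals (set on
# upload) holds {"figures", "tables"}; both live in session state so they
# survive Streamlit's rerun-on-interaction model.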
# File upload section
with st.sidebar:
uploaded_file = st.file_uploader("Upload PDF", type="pdf")
    if uploaded_file:
        # Close the temp file before reading it back so the bytes are flushed.
        with tempfile.NamedTemporaryFile(delete=False, suffix=".pdf") as tmp:
            tmp.write(uploaded_file.getbuffer())
            tmp_path = tmp.name
        st.session_state.processed_data = process_pdf(tmp_path)
        st.session_state.visuals = extract_visuals(tmp_path)
# Chat interface
col1, col2 = st.columns([3, 1])
with col1:
st.subheader("Document Interaction")
for msg in st.session_state.chat:
with st.chat_message(msg["role"]):
if "image" in msg:
st.image(msg["image"], caption=msg.get("caption"))
else:
st.markdown(msg["content"])
    if prompt := st.chat_input("Ask about the document..."):
        if st.session_state.processed_data is None:
            st.warning("Please upload a PDF first.")
        else:
            st.session_state.chat.append({"role": "user", "content": prompt})
            with st.spinner("Analyzing..."):
                response = answer_question(
                    prompt,
                    st.session_state.processed_data["chunks"],
                    st.session_state.processed_data["embeddings"]
                )
            st.session_state.chat.append({"role": "assistant", "content": response})
            st.rerun()
with col2:
st.subheader("Document Insights")
    if st.button("Generate Summary"):
        if st.session_state.processed_data is None:
            st.warning("Please upload a PDF first.")
        else:
            with st.spinner("Summarizing..."):
                summary = generate_summary(
                    st.session_state.processed_data["chunks"],
                    st.session_state.processed_data["embeddings"]
                )
            st.session_state.chat.append({
                "role": "assistant",
                "content": f"## Document Summary\n{summary}"
            })
            st.rerun()
    visuals = st.session_state.get("visuals")
    if visuals and visuals["figures"]:
        with st.expander(f"📷 Figures ({len(visuals['figures'])})"):
            for idx, fig in enumerate(visuals["figures"], 1):
                st.image(fig, caption=f"Figure {idx}")
    if visuals and visuals["tables"]:
        with st.expander(f"📊 Tables ({len(visuals['tables'])})"):
            for idx, tbl in enumerate(visuals["tables"], 1):
                st.image(tbl, caption=f"Table {idx}")
# Custom styling
st.markdown("""
<style>
[data-testid=stSidebar] {
background: #fafafa;
border-right: 1px solid #eee;
}
.stChatMessage {
padding: 1rem;
margin: 0.5rem 0;
border-radius: 10px;
box-shadow: 0 2px 5px rgba(0,0,0,0.1);
}
[data-testid=stVerticalBlock] > div:has(>.stChatMessage) {
gap: 0.5rem;
}
</style>
""", unsafe_allow_html=True)