import os
import time
import io
import base64
import re
from functools import wraps
import numpy as np
import fitz # PyMuPDF
import tempfile
from PIL import Image
from sklearn.cluster import KMeans
from sklearn.metrics.pairwise import cosine_similarity
from ultralytics import YOLO
import streamlit as st
from streamlit_chat import message
from langchain_core.output_parsers import StrOutputParser
from langchain_community.document_loaders import PyMuPDFLoader
from langchain_openai import OpenAIEmbeddings, ChatOpenAI
from langchain_text_splitters import SpacyTextSplitter
from langchain_core.prompts import ChatPromptTemplate
from streamlit.runtime.scriptrunner import get_script_run_ctx
from streamlit import runtime
# Initialize models and environment
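# SpacyTextSplitter below needs spaCy's small English pipeline; fetch it at startup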
os.system("python -m spacy download en_core_web_sm")
model = YOLO("best.pt")
openai_api_key = os.environ.get("openai_api_key")
MAX_FILE_SIZE = 50 * 1024 * 1024 # 50MB
# Utility functions
@st.cache_data(show_spinner=False, ttl=3600)
def clean_text(text):
return re.sub(r'\s+', ' ', text).strip()
def remove_references(text):
    reference_patterns = [
        r'\breferences?\b', r'\bbibliography\b',
        r'\bcitations\b', r'\bworks cited\b'
    ]
lines = text.split('\n')
for i, line in enumerate(lines):
if any(re.search(pattern, line, re.IGNORECASE) for pattern in reference_patterns):
return '\n'.join(lines[:i])
return text
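# Decorator: surface any exception as a chat message instead of crashing the script run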
def handle_errors(func):
    @wraps(func)  # keep the wrapped function's name/docstring intact
    def wrapper(*args, **kwargs):
try:
return func(*args, **kwargs)
except Exception as e:
st.session_state.chat_history.append({
"bot": f"❌ An error occurred: {str(e)}"
})
st.rerun()
return wrapper
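# Inject a small JS snippet that keeps the chat view scrolled to the newest message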
def scroll_to_bottom():
ctx = get_script_run_ctx()
if ctx and runtime.exists():
js = """
<script>
function scrollToBottom() {
window.parent.document.querySelector('section.main').scrollTo(0, window.parent.document.querySelector('section.main').scrollHeight);
}
setTimeout(scrollToBottom, 100);
</script>
"""
st.components.v1.html(js, height=0)
# Core processing functions
@st.cache_data(show_spinner=False, ttl=3600)
@handle_errors
def summarize_pdf(_pdf_file_path, num_clusters=10):
llm = ChatOpenAI(model="gpt-4", api_key=openai_api_key, temperature=0.3)
# Load PDF with page numbers
loader = PyMuPDFLoader(_pdf_file_path)
docs = loader.load()
# Create chunks with page metadata
text_splitter = SpacyTextSplitter(chunk_size=500)
chunks_with_metadata = []
for doc in docs:
chunks = text_splitter.split_text(doc.page_content)
for chunk in chunks:
chunks_with_metadata.append({
"text": clean_text(chunk),
"page": doc.metadata["page"] + 1 # Convert to 1-based numbering
})
    # Prompt with citation instructions; {{summary_content}} below is brace-escaped so
    # LangChain does not treat it as a template input variable
prompt = ChatPromptTemplate.from_template(
"""Generate a comprehensive summary with inline citations using [Source X] format.
Include these elements:
1. Key findings and conclusions
2. Main methodologies used
3. Important data points
4. Limitations mentioned
Structure your response as:
## Comprehensive Summary
{{summary_content}}
Contexts: {topic}"""
)
# Generate summary
chain = prompt | llm | StrOutputParser()
    # Number each chunk in the context so the model's [Source X] citations line up
    # with the source list rendered by generate_interactive_citations
    raw_summary = chain.invoke({
        "topic": '\n\n'.join([f"Source {i+1} (Page {chunk['page']}): {chunk['text']}"
                              for i, chunk in enumerate(chunks_with_metadata)])
    })
return generate_interactive_citations(raw_summary, chunks_with_metadata)
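# Build the HTML payload: a numbered source list with anchors plus clickable [Source X] links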
def generate_interactive_citations(summary_text, source_chunks):
# Create source entries with page numbers and full text
sources_html = """<div style="margin-top: 2rem; padding-top: 1rem; border-top: 1px solid #e0e0e0;">
<h3 style="color: #2c3e50;">πŸ“– Source References</h3>"""
source_mapping = {}
for idx, chunk in enumerate(source_chunks):
source_id = f"source-{idx+1}"
source_mapping[idx+1] = {
"id": source_id,
"page": chunk["page"],
"text": chunk["text"]
}
sources_html += f"""
<div id="{source_id}" style="margin: 1rem 0; padding: 1rem;
border: 1px solid #e0e0e0; border-radius: 8px;
background-color: #f8f9fa; transition: all 0.3s ease;">
<div style="display: flex; justify-content: space-between; align-items: center;">
<div style="font-weight: 600; color: #4CAF50;">Source {idx+1}</div>
<div style="font-size: 0.9em; color: #666;">Page {chunk['page']}</div>
</div>
<div style="margin-top: 0.5rem; color: #444; font-size: 0.95em;">
{chunk["text"]}
</div>
</div>
"""
sources_html += "</div>"
# Add click interactions
interaction_js = """
<script>
document.querySelectorAll('.citation-link').forEach(item => {
item.addEventListener('click', function(e) {
e.preventDefault();
const sourceId = this.getAttribute('data-source');
const sourceDiv = document.getElementById(sourceId);
// Highlight animation
sourceDiv.style.transform = 'scale(1.02)';
sourceDiv.style.boxShadow = '0 4px 12px rgba(76,175,80,0.2)';
setTimeout(() => {
sourceDiv.style.transform = 'none';
sourceDiv.style.boxShadow = 'none';
}, 500);
// Smooth scroll
sourceDiv.scrollIntoView({behavior: 'smooth', block: 'start'});
});
});
</script>
"""
# Replace citations with interactive links
cited_summary = re.sub(r'\[Source (\d+)\]',
lambda m: f'<a class="citation-link" data-source="source-{m.group(1)}" '
f'style="cursor: pointer; color: #4CAF50; text-decoration: none; '
f'border-bottom: 1px dashed #4CAF50;">[Source {m.group(1)}]</a>',
summary_text)
return f"""
<div style="margin-bottom: 3rem;">
{cited_summary}
{sources_html}
</div>
{interaction_js}
"""
@st.cache_data(show_spinner=False, ttl=3600)
@handle_errors
def qa_pdf(_pdf_file_path, query, num_clusters=5):
embeddings_model = OpenAIEmbeddings(model="text-embedding-3-small", api_key=openai_api_key)
llm = ChatOpenAI(model="gpt-4", api_key=openai_api_key, temperature=0.3)
# Load PDF with page numbers
loader = PyMuPDFLoader(_pdf_file_path)
docs = loader.load()
# Create chunks with page metadata
text_splitter = SpacyTextSplitter(chunk_size=500)
chunks_with_metadata = []
for doc in docs:
chunks = text_splitter.split_text(doc.page_content)
for chunk in chunks:
chunks_with_metadata.append({
"text": clean_text(chunk),
"page": doc.metadata["page"] + 1
})
# Find relevant chunks
embeddings = embeddings_model.embed_documents([chunk["text"] for chunk in chunks_with_metadata])
query_embedding = embeddings_model.embed_query(query)
similarities = cosine_similarity([query_embedding], embeddings)[0]
top_indices = np.argsort(similarities)[-num_clusters:]
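    # np.argsort is ascending, so the last num_clusters indices are the most similar chunks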
# Prepare prompt with citation instructions
prompt = ChatPromptTemplate.from_template(
"""Answer this question with inline citations using [Source X] format:
{question}
Use these verified sources:
{context}
Structure your answer with:
- Clear section headings
- Bullet points for lists
- Citations for all factual claims"""
)
chain = prompt | llm | StrOutputParser()
    # Renumber the selected chunks 1..k so citations match the source list shown below
    raw_answer = chain.invoke({
        "question": query,
        "context": '\n\n'.join([f"Source {rank+1} (Page {chunks_with_metadata[i]['page']}): {chunks_with_metadata[i]['text']}"
                                for rank, i in enumerate(top_indices)])
    })
return generate_interactive_citations(raw_answer, [chunks_with_metadata[i] for i in top_indices])
# NOTE: process_pdf and the related figure-extraction UI from the full implementation
# are omitted from this listing; a hedged sketch of image_to_base64 follows.
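# A minimal sketch of the omitted image_to_base64 helper (an assumption about its
# behavior, not the original implementation): serialize a PIL image to an in-memory
# PNG and return the base64 string for use in data: URIs.
def image_to_base64(img: Image.Image) -> str:
    buffer = io.BytesIO()
    img.save(buffer, format="PNG")  # write the image into the buffer as PNG bytes
    return base64.b64encode(buffer.getvalue()).decode("utf-8")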
# Streamlit UI Configuration
st.set_page_config(
page_title="PDF Research Assistant",
page_icon="πŸ“„",
layout="wide",
initial_sidebar_state="expanded"
)
# Custom CSS Styles
st.markdown("""
<style>
.citation-link {
transition: all 0.2s ease;
font-weight: 500;
}
.citation-link:hover {
color: #45a049 !important;
border-bottom-color: #45a049 !important;
}
.stChatMessage {
border-radius: 12px;
box-shadow: 0 4px 12px rgba(0,0,0,0.08);
margin: 1.5rem 0;
padding: 1.5rem;
}
.stButton>button {
background: linear-gradient(135deg, #4CAF50, #45a049);
transition: transform 0.2s ease, box-shadow 0.2s ease;
}
.stButton>button:hover {
transform: translateY(-1px);
box-shadow: 0 4px 12px rgba(76,175,80,0.3);
}
[data-testid="stFileUploader"] {
border: 2px dashed #4CAF50;
border-radius: 12px;
background: #f8fff8;
}
</style>
""", unsafe_allow_html=True)
# Session state initialization
if 'chat_history' not in st.session_state:
st.session_state.chat_history = []
if 'current_file' not in st.session_state:
st.session_state.current_file = None
# Main UI
st.title("πŸ“„ Academic PDF Analyzer")
st.markdown("""
<div style="border-left: 4px solid #4CAF50; padding-left: 1.5rem; margin: 2rem 0;">
    <p style="color: #2c3e50; font-size: 1.1rem;">πŸ” Upload research papers to:</p>
    <ul style="color: #2c3e50; font-size: 1rem;">
        <li>Generate citation-backed summaries</li>
        <li>Trace claims to original sources</li>
        <li>Extract data tables and figures</li>
        <li>Ask questions with verifiable references</li>
    </ul>
</div>
""", unsafe_allow_html=True)
# File uploader
uploaded_file = st.file_uploader(
"Upload research PDF",
type="pdf",
help="Maximum file size: 50MB",
on_change=lambda: setattr(st.session_state, 'chat_history', [])
)
if uploaded_file and uploaded_file.size > MAX_FILE_SIZE:
st.error("File size exceeds 50MB limit")
st.stop()
# Document processing
if uploaded_file:
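    # Streamlit keeps uploads in memory; write to a temp file so PyMuPDF can open it by path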
file_path = tempfile.NamedTemporaryFile(delete=False).name
with open(file_path, "wb") as f:
        f.write(uploaded_file.getbuffer())
# Chat interface
chat_container = st.container()
with chat_container:
for idx, chat in enumerate(st.session_state.chat_history):
col1, col2 = st.columns([1, 4])
if chat.get("user"):
with col2:
message(chat["user"], is_user=True, key=f"user_{idx}")
if chat.get("bot"):
with col1:
message(chat["bot"], key=f"bot_{idx}", allow_html=True)
scroll_to_bottom()
# Interaction controls
with st.container():
col1, col2, col3 = st.columns([3, 2, 2])
with col1:
user_input = st.chat_input("Ask a research question...")
with col2:
if st.button("πŸ“„ Generate Summary", use_container_width=True):
with st.spinner("Analyzing document structure..."):
summary = summarize_pdf(file_path)
st.session_state.chat_history.append({
"bot": f"## Research Summary\n{summary}"
})
st.rerun()
with col3:
if st.button("πŸ”„ Clear Session", use_container_width=True):
st.session_state.chat_history = []
st.rerun()
# Handle user questions
if user_input:
st.session_state.chat_history.append({"user": user_input})
with st.spinner("Verifying sources..."):
answer = qa_pdf(file_path, user_input)
st.session_state.chat_history[-1]["bot"] = f"## Research Answer\n{answer}"
st.rerun()