import streamlit as st
import os
from typing import List, Tuple, Optional
from pinecone import Pinecone
from langchain_pinecone import PineconeVectorStore
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_openai import ChatOpenAI
from langchain_core.prompts import PromptTemplate
from dotenv import load_dotenv
from RAG import RAG
import logging
from image_scraper import DigitalCommonwealthScraper
import shutil
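
# NOTE: RAG and DigitalCommonwealthScraper are local modules that are not shown here.
# From their usage below, this app assumes (roughly) that:
#   - RAG(query, llm, vectorstore) returns a (response_text, source_documents) tuple;
#   - DigitalCommonwealthScraper.extract_images(url) returns a list of dicts with at
#     least an 'alt' key, and download_images(images) saves the files locally (under
#     'downloaded_images') and returns their paths.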
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Page configuration
st.set_page_config(
    page_title="Boston Public Library Chatbot",
    page_icon="🤖",
    layout="wide"
)
def initialize_models() -> None:
    """Initialize the language model, embeddings, and vector store in session state."""
    try:
        load_dotenv()

        if "llm" not in st.session_state:
            # Initialize OpenAI chat model
            st.session_state.llm = ChatOpenAI(
                model="gpt-3.5-turbo",
                temperature=0,
                timeout=60,  # reasonable request timeout
                max_retries=2
            )

        if "embeddings" not in st.session_state:
            # Initialize sentence-transformer embeddings
            st.session_state.embeddings = HuggingFaceEmbeddings(
                model_name="sentence-transformers/all-mpnet-base-v2"
                # model_name="sentence-transformers/all-MiniLM-L6-v2"
            )

        if "pinecone" not in st.session_state:
            pinecone_api_key = os.getenv("PINECONE_API_KEY")
            INDEX_NAME = 'bpl-test'

            # Initialize the Pinecone-backed vector store
            pc = Pinecone(api_key=pinecone_api_key)
            index = pc.Index(INDEX_NAME)
            st.session_state.pinecone = PineconeVectorStore(index=index, embedding=st.session_state.embeddings)

        if "vectorstore" not in st.session_state:
            # st.session_state.vectorstore = CloudSQLVectorStore(embedding=st.session_state.embeddings)
            st.session_state.vectorstore = st.session_state.pinecone

    except Exception as e:
        logger.error(f"Error initializing models: {str(e)}")
        st.error(f"Failed to initialize models: {str(e)}")
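
# Environment assumptions (a sketch, not an exhaustive list): ChatOpenAI reads
# OPENAI_API_KEY from the environment by default, and the code above reads
# PINECONE_API_KEY explicitly, so a minimal .env would look like:
#   OPENAI_API_KEY=sk-...
#   PINECONE_API_KEY=...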
def process_message(
    query: str,
    llm: ChatOpenAI,
    vectorstore: PineconeVectorStore,
) -> Tuple[str, List]:
    """Process the user message using the RAG system."""
    try:
        response, sources = RAG(
            query=query,
            llm=llm,
            vectorstore=vectorstore,
        )
        return response, sources
    except Exception as e:
        logger.error(f"Error in process_message: {str(e)}")
        return f"Error processing message: {str(e)}", []
def display_sources(sources: List) -> None:
    """Display sources with minimal output: content preview, source, URL, and image/audio if available."""
    if not sources:
        st.info("No sources available for this response.")
        return

    st.subheader("Sources")

    for doc in sources:
        try:
            metadata = doc.metadata
            source = metadata.get("source", "Unknown Source")
            title = metadata.get("title_info_primary_tsi", "Unknown Title")
            format_type = metadata.get("format", "").lower()
            is_audio = "audio" in format_type

            expander_title = f"🔊 {title}" if is_audio else title

            with st.expander(expander_title):
                # Content preview
                if hasattr(doc, 'page_content'):
                    st.markdown(f"**Content:** {doc.page_content[:300]} ...")

                # URL building
                doc_url = metadata.get("URL", "").strip()
                if not doc_url and source:
                    doc_url = f"https://www.digitalcommonwealth.org/search/{source}"

                st.markdown(f"**Source ID:** {source}")
                st.markdown(f"**Format:** {format_type if format_type else 'Not specified'}")
                st.markdown(f"**URL:** {doc_url}")

                # 🔊 Try to show audio if it's an audio entry and there's a media file
                if is_audio:
                    # Try to find a playable media file if the metadata ever includes audio URLs;
                    # for now, just show a placeholder
                    st.info("This is an audio entry.")
                    # Optionally:
                    # st.audio("https://example.com/audio-file.mp3")  # replace with real audio URL
                else:
                    # 🖼️ Show image if it's not audio
                    scraper = DigitalCommonwealthScraper()
                    images = scraper.extract_images(doc_url)
                    images = images[:1]

                    if images:
                        # Clear any previously downloaded images before fetching new ones
                        output_dir = 'downloaded_images'
                        if os.path.exists(output_dir):
                            shutil.rmtree(output_dir)

                        downloaded_files = scraper.download_images(images)
                        st.image(downloaded_files, width=400, caption=[
                            img.get('alt', 'Image') for img in images
                        ])

        except Exception as e:
            logger.warning(f"[display_sources] Error displaying document: {e}")
            st.error("Error displaying one of the sources.")
def main():
    st.title("Digital Commonwealth RAG 🤖")
    INDEX_NAME = 'bpl-rag'

    # Initialize session state
    if "messages" not in st.session_state:
        st.session_state.messages = []
    if "show_settings" not in st.session_state:
        st.session_state.show_settings = False
    if "num_sources" not in st.session_state:
        st.session_state.num_sources = 10

    initialize_models()

    # 🔵 Settings button
    open_settings = st.button("⚙️ Settings")
    if open_settings:
        st.session_state.show_settings = True

    if st.session_state.show_settings:
        with st.container():
            st.markdown("---")
            st.markdown("### ⚙️ Settings")

            num_sources = st.number_input(
                "Number of Sources to Display",
                min_value=1,
                max_value=100,
                value=st.session_state.num_sources,
                step=1,
            )
            st.session_state.num_sources = num_sources

            close_settings = st.button("❌ Close Settings")
            if close_settings:
                st.session_state.show_settings = False

            st.markdown("---")
    # Show chat history
    for message in st.session_state.messages:
        with st.chat_message(message["role"]):
            st.markdown(message["content"])

    # ⬇️ Chat input box, always stuck to the bottom of the page
    user_input = st.chat_input("Type your question here...")

    if user_input:
        with st.chat_message("user"):
            st.markdown(user_input)
        st.session_state.messages.append({"role": "user", "content": user_input})

        with st.chat_message("assistant"):
            with st.spinner("Thinking... Please be patient..."):
                response, sources = process_message(
                    query=user_input,
                    llm=st.session_state.llm,
                    vectorstore=st.session_state.vectorstore
                )

                if isinstance(response, str):
                    st.markdown(response)
                    st.session_state.messages.append({
                        "role": "assistant",
                        "content": response
                    })
                    display_sources(sources[:int(st.session_state.num_sources)])
                else:
                    st.error("Received an invalid response format")
    # Footer (optional, appears above the chat input)
    st.markdown("---")
    st.markdown(
        "Built with Langchain + Streamlit + Pinecone",
        help="Natural Language Querying for Digital Commonwealth"
    )
    st.markdown(
        "The Digital Commonwealth site provides access to photographs, manuscripts, books, "
        "audio recordings, and other materials of historical interest that have been digitized "
        "and made available by members of Digital Commonwealth."
    )


if __name__ == "__main__":
    main()