import base64
import datetime
import timeit

import streamlit as st
from dotenv import load_dotenv

from utils import (
    build_embedding_model,
    build_llm,
    load_conversational_retrievel_chain,
    load_retriver,
    load_vectorstore,
)
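# ---------------------------------------------------------------------------
# For reference, a minimal sketch of what the utils helpers are assumed to
# look like, inferred from how they are called below and from the LangChain
# stack this project originally imported (LlamaCpp, HuggingFaceEmbeddings,
# Chroma, BM25Retriever/EnsembleRetriever). The real definitions live in
# utils.py; every name, path, and parameter here is illustrative only.
#
#   from langchain.embeddings import HuggingFaceEmbeddings
#   from langchain.llms import LlamaCpp
#   from langchain.vectorstores import Chroma
#   from langchain.chains import RetrievalQA
#
#   def build_llm():
#       # model_path is hypothetical
#       return LlamaCpp(model_path="models/llama-2-7b-chat.gguf", temperature=0.1)
#
#   def build_embedding_model():
#       # model_name is hypothetical
#       return HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
#
#   def load_vectorstore(persist_directory, embeddings):
#       return Chroma(persist_directory=persist_directory, embedding_function=embeddings)
#
#   def load_retriver(chroma_vectorstore):
#       # The original imports suggest this may actually be an EnsembleRetriever
#       # combining BM25 with the dense retriever.
#       return chroma_vectorstore.as_retriever(search_kwargs={"k": 4})
#
#   def load_conversational_retrievel_chain(retriever, llm):
#       return RetrievalQA.from_chain_type(llm=llm, retriever=retriever)
# ---------------------------------------------------------------------------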

# Load environment variables (e.g. API keys) from a local .env file.
load_dotenv()

# st.set_page_config must be the first Streamlit command in the script, so it
# runs before any of the model/vectorstore setup below.
st.set_page_config(
    page_title="OPM Search Bot",
    layout="wide",
    initial_sidebar_state="expanded",
)

persist_directory = "Database/PDF_HTML_CHROMA_DB"
all_docs_pkl_directory = "Database/text_chunks_html_pdf.pkl"

# Build the heavyweight resources once and cache them in session_state so
# they survive Streamlit's script reruns.
if "llm" not in st.session_state:
    st.session_state["llm"] = build_llm()

if "embeddings" not in st.session_state:
    st.session_state["embeddings"] = build_embedding_model()

if "vector_db" not in st.session_state:
    st.session_state["vector_db"] = load_vectorstore(
        persist_directory=persist_directory,
        embeddings=st.session_state["embeddings"],
    )

if "retriever" not in st.session_state:
    st.session_state["retriever"] = load_retriver(chroma_vectorstore=st.session_state["vector_db"])

if "conversation_chain" not in st.session_state:
    st.session_state["conversation_chain"] = load_conversational_retrievel_chain(
        retriever=st.session_state["retriever"],
        llm=st.session_state["llm"],
    )

st.markdown(
    """
    <style>
    .block-container {
        padding-top: 2.2rem;
    }
    </style>
    """,
    unsafe_allow_html=True,
)

title1 = """
<p style="font-size: 26px;text-align: right; color: #0C3453; font-weight: bold">OPM Retirement Services Assistant</p>
"""

# Embed the logo as a base64 data URL so it can be inlined in the HTML header.
with open("opm_logo.png", "rb") as logo_file:
    data_url = base64.b64encode(logo_file.read()).decode("utf-8")

st.markdown(
    f"""
    <div style="background-color: white; padding: 15px; border-radius: 10px;">
        <div style="display: flex; justify-content: space-between;">
            <div>
                <img src="data:image/png;base64,{data_url}" style="max-width: 100%;" alt="OPM Logo" />
            </div>
            <div style="flex: 1; padding: 15px;">
                {title1}
            </div>
        </div>
    </div>
    """,
    unsafe_allow_html=True,
)
st.write("")

st.write(
    '<p style="color: #B0B0B0;margin: 0;">OPM is here to help you transition from serving the American people to enjoying your retirement. This retirement services assistant reflects our commitment to supporting new and existing retirees throughout the retirement journey. The assistant is trained on 1,500+ documents related to OPM retirement services and can answer your questions in a conversational style. Just ask away!</p>',
    unsafe_allow_html=True,
)

st.markdown("""---""")

text_html = """
<p style="font-size: 24px; text-align: center; color: blue; margin: 0;">
    Type your question below in plain, conversational language.
</p>
<p style="font-size: 18px; text-align: center; color: blue; margin: 0;">
    Sample questions:<br>
    Will I get paid for my unused annual leave?<br>
    What annuity estimates do I need?<br>
    What are interim benefits?
</p>
"""

st.write(text_html, unsafe_allow_html=True)

with st.sidebar:
    st.subheader("")

if st.session_state["vector_db"] and st.session_state["llm"]:

    if "messages" not in st.session_state:
        st.session_state.messages = [
            {"role": "assistant", "content": "How may I assist you today?", "Source": ""}
        ]

    # Replay the conversation so far; answers with sources get an expander.
    for message in st.session_state.messages:
        with st.chat_message(message["role"]):
            st.write(message["content"])
            if message["Source"] == "":
                st.write("")
            else:
                with st.expander("source"):
                    for idx, item in enumerate(message["Source"]):
                        st.markdown(item["Page"])
                        st.markdown(item["Source"])
                        st.markdown(item["page_content"])
                        st.write("---")

    if "stored_session" not in st.session_state:
        st.session_state["stored_session"] = []

    if "expanders" not in st.session_state:
        st.session_state["expanders"] = []

    def add_chat_expander(chat_history):
        """Archive a finished conversation, tagged with the time it ended."""
        current_timestamp = datetime.datetime.now()
        timestamp_string = current_timestamp.strftime("%Y-%m-%d %H:%M:%S")
        st.session_state["expanders"].append({"timestamp": timestamp_string, "chat_history": chat_history})

    def clear_chat_history():
        """Archive the current chat to the sidebar and start a new conversation."""
        stored_session = []
        for dict_message in st.session_state.messages:
            if dict_message["role"] == "user":
                string_dialogue = "User: " + dict_message["content"] + "\n\n"
            else:
                string_dialogue = "Assistant: " + dict_message["content"] + "\n\n"
            st.session_state["stored_session"].append(string_dialogue)
            stored_session.append(string_dialogue)

        add_chat_expander(stored_session)
        st.session_state.messages = [
            {"role": "assistant", "content": "How may I assist you today?", "Source": ""}
        ]

    st.sidebar.button('New chat', on_click=clear_chat_history, use_container_width=True)
    st.sidebar.text("")
    st.sidebar.write(
        '<p style="font-size: 16px;text-align: center; color: #727477; font-weight: bold">Chat history</p>',
        unsafe_allow_html=True,
    )

    # Each archived conversation gets its own sidebar expander.
    for expander_info in st.session_state["expanders"]:
        with st.sidebar.expander("Conversation ended at:" + "\n\n" + expander_info["timestamp"]):
            for message in expander_info["chat_history"]:
                if message.startswith("User:"):
                    st.write(f'<span style="color: #EF6A6A;">{message}</span>', unsafe_allow_html=True)
                elif message.startswith("Assistant:"):
                    st.write(f'<span style="color: #F7BD45;">{message}</span>', unsafe_allow_html=True)
                else:
                    st.write(message)

    def generate_llm_response(conversation_chain, prompt_input):
        """Run the retrieval chain on the user's question and return the answer text."""
        res = conversation_chain(prompt_input)
        return res['result']
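    # Note: reading res['result'] assumes the chain built in utils is a
    # RetrievalQA-style chain; a ConversationalRetrievalChain would expose the
    # answer under the 'answer' key instead.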

    if prompt := st.chat_input(disabled=not st.session_state["vector_db"]):
        st.session_state.messages.append({"role": "user", "content": prompt, "Source": ""})
        with st.chat_message("user"):
            st.write(prompt)

    # Generate a reply whenever the last message is from the user.
    if st.session_state.messages[-1]["role"] != "assistant":
        with st.chat_message("assistant"):
            with st.spinner("Searching..."):
                start = timeit.default_timer()
                response = generate_llm_response(
                    conversation_chain=st.session_state["conversation_chain"],
                    prompt_input=prompt,
                )
                # Replay the completed answer character by character to mimic
                # token streaming in the UI.
                placeholder = st.empty()
                full_response = ''
                for item in response:
                    full_response += item
                    placeholder.markdown(full_response)
                if response:
                    st.text("-------------------------------------")
                    # Surface the retrieved documents that back the answer.
                    docs = st.session_state["retriever"].get_relevant_documents(prompt)
                    source_doc_list = [doc.dict() for doc in docs]
                    merged_source_doc = []
                    with st.expander("source"):
                        for idx, item in enumerate(source_doc_list):
                            source_doc = {
                                "Page": f"Source {idx + 1}",
                                "Source": f"**Source:** {item['metadata']['source'].split('/')[-1]}",
                                "page_content": item["page_content"],
                            }
                            merged_source_doc.append(source_doc)
                            st.markdown(f"Source {idx + 1}")
                            st.markdown(f"**Source:** {item['metadata']['source'].split('/')[-1]}")
                            st.markdown(item["page_content"])
                            st.write("---")
                    message = {"role": "assistant", "content": full_response, "Source": merged_source_doc}
                    st.session_state.messages.append(message)
                    st.markdown("👍 👎 Create Ticket")

                # Log the end-to-end latency for this turn.
                end = timeit.default_timer()
                print(f"Time to retrieve response: {end - start}")
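
# To try the app locally (assuming this file is saved as app.py and the Chroma
# store already exists under Database/PDF_HTML_CHROMA_DB):
#
#     streamlit run app.py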