# Hugging Face Spaces status at time of capture: Runtime error
############################################################################################################################# | |
# Filename : app.py | |
# Description: A Streamlit application to showcase how RAG works. | |
# Author : Georgios Ioannou | |
# | |
# Copyright © 2024 by Georgios Ioannou | |
############################################################################################################################# | |
# Import libraries. | |
import os | |
import streamlit as st | |
from dotenv import load_dotenv, find_dotenv | |
from huggingface_hub import InferenceClient | |
from langchain.prompts import PromptTemplate | |
from langchain.schema import Document | |
from langchain.schema.runnable import RunnablePassthrough, RunnableLambda | |
from langchain_community.embeddings import HuggingFaceInferenceAPIEmbeddings | |
from langchain_community.vectorstores import MongoDBAtlasVectorSearch | |
from pymongo import MongoClient | |
from pymongo.collection import Collection | |
from typing import Dict, Any | |
############################################################################################################################# | |
class RAGQuestionAnswering:
    """
    Purpose
    -------
    Retrieval-Augmented Generation (RAG) question answering: the user's question is
    embedded, similar chunks are retrieved from a MongoDB Atlas vector index, and a
    Hugging Face chat model is asked to answer using that retrieved context.
    Notes
    -----
    Streamlit re-executes the whole script on every interaction, so instances of this
    class may be constructed on every rerun. The network-backed resources (MongoDB
    client, embedding client, vector store) are therefore obtained from
    ``st.cache_resource`` factories. They are built once per process and reused,
    instead of opening a new (never closed) connection on every rerun.
    """

    # ------------------------------------------------------------------------------------------------------------------
    # Cached resource factories (shared across Streamlit reruns and sessions).
    # ------------------------------------------------------------------------------------------------------------------
    @staticmethod
    @st.cache_resource
    def _connect_collection(
        mongo_uri: str, db_name: str, collection_name: str
    ) -> Collection:
        """
        Parameters
        ----------
        **mongo_uri:** str - MongoDB connection string
        **db_name:** str - Database name
        **collection_name:** str - Collection name
        Output
        ------
        Collection: Handle to ``db_name.collection_name``
        Purpose
        -------
        Creates one MongoClient per distinct (uri, db, collection) and caches it.
        """
        return MongoClient(mongo_uri)[db_name][collection_name]

    @staticmethod
    @st.cache_resource
    def _load_embedding_model(
        hf_token: str, model_name: str
    ) -> HuggingFaceInferenceAPIEmbeddings:
        """
        Parameters
        ----------
        **hf_token:** str - Hugging Face API token
        **model_name:** str - Hugging Face model id used for embeddings
        Output
        ------
        HuggingFaceInferenceAPIEmbeddings: Cached embedding client
        Purpose
        -------
        Builds the embedding client once per (token, model) and caches it.
        """
        return HuggingFaceInferenceAPIEmbeddings(api_key=hf_token, model_name=model_name)

    @staticmethod
    @st.cache_resource
    def _connect_vector_search(
        mongo_uri: str,
        namespace: str,
        index_name: str,
        _embedding_model: HuggingFaceInferenceAPIEmbeddings,
    ) -> MongoDBAtlasVectorSearch:
        """
        Parameters
        ----------
        **mongo_uri:** str - MongoDB connection string
        **namespace:** str - "<db>.<collection>" namespace holding the vectors
        **index_name:** str - Name of the Atlas vector search index
        **_embedding_model:** HuggingFaceInferenceAPIEmbeddings - Embedding model
        Output
        ------
        MongoDBAtlasVectorSearch: Cached vector store
        Purpose
        -------
        Builds the vector store (and its internal MongoClient) once per process.
        Notes
        -----
        The leading underscore on ``_embedding_model`` tells Streamlit not to hash it,
        because the client object is not hashable. The cache is keyed on the
        connection parameters only.
        """
        return MongoDBAtlasVectorSearch.from_connection_string(
            connection_string=mongo_uri,
            namespace=namespace,
            embedding=_embedding_model,
            index_name=index_name,
        )

    # ------------------------------------------------------------------------------------------------------------------
    # Setup.
    # ------------------------------------------------------------------------------------------------------------------
    def __init__(self):
        """
        Parameters
        ----------
        None
        Output
        ------
        None
        Purpose
        -------
        Initializes the RAG Question Answering system by setting up configuration
        and loading environment variables.
        Assumptions
        -----------
        - Expects .env file with MONGO_URI and HF_TOKEN
        - Requires proper MongoDB setup with vector search index
        - Needs connection to Hugging Face API
        Notes
        -----
        This is the main class that handles all RAG operations
        """
        self.load_environment()
        self.setup_mongodb()
        self.setup_embedding_model()
        self.setup_vector_search()
        self.setup_rag_chain()

    def load_environment(self) -> None:
        """
        Parameters
        ----------
        None
        Output
        ------
        None
        Purpose
        -------
        Loads environment variables from .env file and sets up configuration constants.
        Assumptions
        -----------
        Expects a .env file with MONGO_URI and HF_TOKEN defined
        Notes
        -----
        Will stop the application if required environment variables are missing
        """
        load_dotenv(find_dotenv())
        self.MONGO_URI = os.getenv("MONGO_URI")
        self.HF_TOKEN = os.getenv("HF_TOKEN")
        if not self.MONGO_URI or not self.HF_TOKEN:
            st.error("Please ensure MONGO_URI and HF_TOKEN are set in your .env file")
            st.stop()
        # MongoDB configuration.
        self.DB_NAME = "txts"
        self.COLLECTION_NAME = "txts_collection"
        self.VECTOR_SEARCH_INDEX = "vector_index"

    def setup_mongodb(self) -> None:
        """
        Parameters
        ----------
        None
        Output
        ------
        None
        Purpose
        -------
        Initializes the MongoDB connection and sets up the collection.
        Assumptions
        -----------
        - Valid MongoDB URI is available
        - Database and collection exist in MongoDB Atlas
        Notes
        -----
        Uses st.cache_resource (via _connect_collection) so the MongoClient is created
        once per process rather than on every Streamlit rerun.
        """
        self.mongodb_collection = self._connect_collection(
            self.MONGO_URI, self.DB_NAME, self.COLLECTION_NAME
        )

    def setup_embedding_model(self) -> None:
        """
        Parameters
        ----------
        None
        Output
        ------
        None
        Purpose
        -------
        Initializes the embedding model for vector search.
        Assumptions
        -----------
        - Valid Hugging Face API token
        - Internet connection to access the model
        Notes
        -----
        Uses the all-mpnet-base-v2 model from sentence-transformers, cached with
        st.cache_resource.
        """
        self.embedding_model = self._load_embedding_model(
            self.HF_TOKEN, "sentence-transformers/all-mpnet-base-v2"
        )

    def setup_vector_search(self) -> None:
        """
        Parameters
        ----------
        None
        Output
        ------
        None
        Purpose
        -------
        Sets up the vector search functionality using MongoDB Atlas.
        Assumptions
        -----------
        - MongoDB Atlas vector search index is properly configured
        - Valid embedding model is initialized
        Notes
        -----
        The vector store is cached with st.cache_resource; the retriever itself is a
        cheap wrapper and is rebuilt per instance.
        """
        self.vector_search = self._connect_vector_search(
            self.MONGO_URI,
            f"{self.DB_NAME}.{self.COLLECTION_NAME}",
            self.VECTOR_SEARCH_INDEX,
            self.embedding_model,
        )
        # NOTE(review): with search_type="similarity", LangChain's retriever forwards
        # search_kwargs straight to the store's similarity_search(). The retriever only
        # enforces score_threshold itself when search_type="similarity_score_threshold".
        # Confirm whether the installed MongoDBAtlasVectorSearch applies it here.
        self.retriever = self.vector_search.as_retriever(
            search_type="similarity", search_kwargs={"k": 10, "score_threshold": 0.85}
        )

    # ------------------------------------------------------------------------------------------------------------------
    # Chain.
    # ------------------------------------------------------------------------------------------------------------------
    def format_docs(self, docs: list[Document]) -> str:
        """
        Parameters
        ----------
        **docs:** list[Document] - List of documents to be formatted
        Output
        ------
        str: Formatted string containing concatenated document content
        Purpose
        -------
        Formats the retrieved documents into a single string for processing
        Assumptions
        -----------
        Documents have page_content attribute
        Notes
        -----
        Joins documents with double newlines for better readability
        """
        return "\n\n".join(doc.page_content for doc in docs)

    def generate_response(self, input_dict: Dict[str, Any]) -> str:
        """
        Parameters
        ----------
        **input_dict:** Dict[str, Any] - Dictionary containing context and question
        Output
        ------
        str: Generated response from the model ("" if the model returned no content)
        Purpose
        -------
        Generates a response using the Hugging Face model based on context and question
        Assumptions
        -----------
        - Valid Hugging Face API token
        - Input dictionary contains 'context' and 'question' keys
        Notes
        -----
        Uses Qwen2.5-1.5B-Instruct model with controlled temperature
        """
        hf_client = InferenceClient(api_key=self.HF_TOKEN)
        # The filled template (instructions + context + question) is the system
        # message; the raw question is repeated as the user turn.
        formatted_prompt = self.prompt.format(**input_dict)
        response = hf_client.chat.completions.create(
            model="Qwen/Qwen2.5-1.5B-Instruct",
            messages=[
                {"role": "system", "content": formatted_prompt},
                {"role": "user", "content": input_dict["question"]},
            ],
            max_tokens=1000,
            temperature=0.2,
        )
        # message.content may be None; keep the declared `-> str` contract.
        return response.choices[0].message.content or ""

    def setup_rag_chain(self) -> None:
        """
        Parameters
        ----------
        None
        Output
        ------
        None
        Purpose
        -------
        Sets up the RAG chain for processing questions and generating answers
        Assumptions
        -----------
        Retriever and response generator are properly initialized
        Notes
        -----
        Creates a chain that combines retrieval and response generation
        """
        self.prompt = PromptTemplate.from_template(
            """Use the following pieces of context to answer the question at the end.
START OF CONTEXT:
{context}
END OF CONTEXT:
START OF QUESTION:
{question}
END OF QUESTION:
If you do not know the answer, just say that you do not know.
NEVER assume things.
"""
        )
        # The input question fans out to the retriever (-> formatted context) and is
        # passed through unchanged as "question"; the dict then feeds generate_response.
        self.rag_chain = {
            "context": self.retriever | RunnableLambda(self.format_docs),
            "question": RunnablePassthrough(),
        } | RunnableLambda(self.generate_response)

    def process_question(self, question: str) -> str:
        """
        Parameters
        ----------
        **question:** str - The user's question to be answered
        Output
        ------
        str: The generated answer to the question
        Purpose
        -------
        Processes a user question through the RAG chain and returns an answer
        Assumptions
        -----------
        - Question is a non-empty string
        - RAG chain is properly initialized
        Notes
        -----
        Main interface for question-answering functionality
        """
        return self.rag_chain.invoke(question)
############################################################################################################################# | |
def setup_streamlit_ui() -> None:
    """
    Parameters
    ----------
    None
    Output
    ------
    None
    Purpose
    -------
    Sets up the Streamlit user interface with proper styling and layout
    Assumptions
    -----------
    - CSS file exists at ./static/styles/style.css (UTF-8 encoded)
    - Image file exists at ./static/images/ctp.png
    - Paths are resolved relative to the current working directory
    Notes
    -----
    Handles all UI-related setup and styling. st.set_page_config is called first
    because Streamlit requires it to precede other page commands.
    """
    st.set_page_config(page_title="RAG Question Answering", page_icon="🤖")

    # Load CSS. Decode explicitly as UTF-8 so the result does not depend on the
    # platform's default locale encoding.
    with open("./static/styles/style.css", encoding="utf-8") as css_file:
        st.markdown(f"<style>{css_file.read()}</style>", unsafe_allow_html=True)

    # Title and subtitles.
    st.markdown(
        '<h1 align="center" style="font-family: monospace; font-size: 2.1rem; margin-top: -4rem">RAG Question Answering</h1>',
        unsafe_allow_html=True,
    )
    st.markdown(
        '<h3 align="center" style="font-family: monospace; font-size: 1.5rem; margin-top: -2rem">Using Zoom Closed Captioning From The Lectures</h3>',
        unsafe_allow_html=True,
    )
    st.markdown(
        '<h2 align="center" style="font-family: monospace; font-size: 1.5rem; margin-top: 0rem">CUNY Tech Prep Tutorial 5</h2>',
        unsafe_allow_html=True,
    )

    # Display logo in the middle of three equal columns so it appears centered.
    _, center_column, _ = st.columns(3)
    with center_column:
        st.image("./static/images/ctp.png")
############################################################################################################################# | |
def main():
    """
    Parameters
    ----------
    None
    Output
    ------
    None
    Purpose
    -------
    Main function that runs the Streamlit application
    Assumptions
    -----------
    All required environment variables and files are present
    Notes
    -----
    Entry point for the application. Whitespace-only questions are rejected, and
    failures while generating an answer are reported in the page instead of
    aborting the rest of the render.
    """
    # Setup UI.
    setup_streamlit_ui()

    # Initialize RAG system.
    rag_system = RAGQuestionAnswering()

    # Create input elements.
    query = st.text_input("Question:", key="question_input")

    # Handle submission.
    if st.button("Submit", type="primary"):
        # Treat whitespace-only input as empty so a blank question never reaches the models.
        question = query.strip() if query else ""
        if question:
            response = None
            with st.spinner("Generating response..."):
                try:
                    response = rag_system.process_question(question)
                except Exception as err:  # UI boundary: report and keep rendering the page.
                    st.error(f"Failed to generate a response: {err}")
            if response is not None:
                st.text_area("Answer:", value=response, height=200, disabled=True)
        else:
            st.warning("Please enter a question.")

    # Add GitHub link.
    st.markdown(
        """
<p align="center" style="font-family: monospace; color: #FAF9F6; font-size: 1rem;">
<b>Check out our <a href="https://github.com/GeorgiosIoannouCoder/" style="color: #FAF9F6;">GitHub repository</a></b>
</p>
""",
        unsafe_allow_html=True,
    )
############################################################################################################################# | |
# Script entry point: `streamlit run app.py` executes this module as "__main__"
# on every rerun, so main() rebuilds the page each time.
if __name__ == "__main__":
    main()