from dotenv import load_dotenv
from typing_extensions import List, TypedDict
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
from langchain_core.documents import Document
from langchain_core.prompts import ChatPromptTemplate
from langchain_qdrant import QdrantVectorStore
from langchain_huggingface import HuggingFaceEmbeddings
from langgraph.graph import START, StateGraph
# NOTE(review): re-imports ChatPromptTemplate (shadows the langchain_core import above).
from langchain.prompts import ChatPromptTemplate
from langchain_community.document_loaders import DirectoryLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from qdrant_client import QdrantClient
from qdrant_client.http.models import Distance, VectorParams

# Necessary for dependencies for DirectoryLoader
import nltk

nltk.download('punkt_tab')
nltk.download('averaged_perceptron_tagger_eng')

# Chunk configuration
CHUNK_SIZE = 1000
CHUNK_OVERLAP = CHUNK_SIZE // 2

# Directory of scraped, cleaned .txt files used to populate a new/empty collection.
DATA_DIR = "data/scraped/clean"

# Number of chunks the retriever returns per question.
RETRIEVAL_K = 5

# RAG prompt template
RAG_PROMPT = """\
You are a helpful assistant who helps Shopify merchants automate their businesses.

Your goal is to provide a helpful response to the merchant's question in straight forward, non technical language.
Try to be brief and to the point, but explain technical jargon.

You must only use the provided context, and cannot use your own knowledge.

### Question
{question}

### Context
{context}
"""


class RagGraph:
    """Retrieval-augmented question answering backed by a Qdrant collection.

    On construction the collection is created (and filled from ``DATA_DIR``)
    if needed, a retriever is set up, and a two-step LangGraph pipeline
    (retrieve -> generate) is compiled into ``self.graph``.
    """

    def __init__(self, qdrant_client, use_finetuned_embeddings=False):
        """Set up models, the vector store, and the graph.

        Args:
            qdrant_client: A ``qdrant_client.QdrantClient`` used for all
                collection operations.
            use_finetuned_embeddings: Use the fine-tuned HuggingFace embedding
                model and its own collection instead of OpenAI embeddings.
        """
        self.llm = ChatOpenAI(model="gpt-4-turbo-preview", streaming=True)

        # The collection name and vector size are tied to the embedding model.
        # The size must match the model's output dimension or inserts fail.
        if use_finetuned_embeddings:
            self.collection_name = "rag_collection_finetuned"
            self.embeddings_model = HuggingFaceEmbeddings(
                model_name="thomfoolery/AIE5-MidTerm-finetuned-embeddings"
            )
            dimension_size = 1024
        else:
            self.collection_name = "rag_collection"
            self.embeddings_model = OpenAIEmbeddings(model="text-embedding-3-small")
            dimension_size = 1536

        self.text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=CHUNK_SIZE, chunk_overlap=CHUNK_OVERLAP
        )
        self.rag_prompt = ChatPromptTemplate.from_template(RAG_PROMPT)
        self.qdrant_client = qdrant_client

        does_collection_exist = self.qdrant_client.collection_exists(
            collection_name=self.collection_name
        )
        print(f"Collection {self.collection_name} exists: {does_collection_exist}")

        # The collection must exist before QdrantVectorStore is constructed on it.
        if not does_collection_exist:
            self.qdrant_client.create_collection(
                collection_name=self.collection_name,
                vectors_config=VectorParams(size=dimension_size, distance=Distance.COSINE),
            )

        self.vector_store = QdrantVectorStore(
            client=self.qdrant_client,
            collection_name=self.collection_name,
            embedding=self.embeddings_model,
        )

        # Also ingest into an existing-but-empty collection (e.g. a previous run
        # created it and then failed while loading); otherwise retrieval would
        # silently return no context on every later run.
        if not does_collection_exist or self._is_collection_empty():
            self._ingest_documents()

        self.vector_db_retriever = self.vector_store.as_retriever(
            search_kwargs={"k": RETRIEVAL_K}
        )
        self.graph = None
        self.create()

    def _is_collection_empty(self):
        """Return True when the collection holds no points."""
        result = self.qdrant_client.count(collection_name=self.collection_name, exact=True)
        return result.count == 0

    def _ingest_documents(self):
        """Load ``DATA_DIR/*.txt``, split into chunks, and add them to the vector store."""
        loader = DirectoryLoader(DATA_DIR, glob="*.txt")
        documents = self.text_splitter.split_documents(loader.load())
        if not documents:
            print(f"No documents found in {DATA_DIR}; collection {self.collection_name} is empty.")
            return
        self.vector_store.add_documents(documents=documents)

    def _build_messages(self, question, documents):
        """Format the RAG prompt from *question* and the page contents of *documents*."""
        context = "\n\n".join(doc.page_content for doc in documents)
        return self.rag_prompt.format_messages(question=question, context=context)

    def create(self):
        """Build and compile the RAG graph (retrieve -> stream) into ``self.graph``."""

        class State(TypedDict):
            """State for the conversation."""
            question: str
            context: List[Document]
            answer: str

        def retrieve(state):
            """Fetch the chunks most similar to the question."""
            return {"context": self.vector_db_retriever.invoke(state["question"])}

        async def stream(state, config):
            """Generate the answer.

            Tokens reach callers through ``stream_mode="messages"``. ``config`` is
            forwarded so the LLM callbacks are attached on Python < 3.11 async.
            """
            messages = self._build_messages(state["question"], state["context"])
            response = await self.llm.ainvoke(messages, config)
            return {"answer": response.content}

        graph_builder = StateGraph(State).add_sequence([retrieve, stream])
        graph_builder.add_edge(START, "retrieve")
        self.graph = graph_builder.compile()

    def run(self, question):
        """Invoke RAG response without streaming.

        Returns:
            dict with ``"response"`` (the model's answer text) and
            ``"context"`` (the retrieved Documents).
        """
        chunks = self.vector_db_retriever.invoke(question)
        response = self.llm.invoke(self._build_messages(question, chunks))
        return {"response": response.content, "context": chunks}

    async def stream(self, question, msg):
        """Stream the graph's answer to *question* into *msg*, token by token.

        *msg* must provide async ``stream_token(str)`` and ``send()`` methods
        (presumably a Chainlit ``Message`` -- confirm against callers).
        """
        async for _event_name, (message_chunk, _metadata) in self.graph.astream(
            {"question": question, "context": []},
            stream_mode=["messages"],
        ):
            if message_chunk.content:
                await msg.stream_token(message_chunk.content)
        await msg.send()


# Run RAG with CLI (no streaming)
def main():
    """Answer a sample question from the command line."""
    load_dotenv()
    # In-memory Qdrant: the collection is rebuilt from DATA_DIR on every CLI run.
    rag_graph = RagGraph(QdrantClient(":memory:"))
    response = rag_graph.run("What is Shopify Flow?")
    print(response["response"])


if __name__ == "__main__":
    main()