from fastapi import FastAPI, HTTPException
import os
from typing import List, Dict
from dotenv import load_dotenv
import logging
from pathlib import Path
from langchain_community.document_loaders import PyPDFLoader
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import Qdrant as QdrantVectorStore
from langchain_google_genai import GoogleGenerativeAIEmbeddings
from langchain_groq import ChatGroq
from qdrant_client import QdrantClient
from qdrant_client.http.models import Distance, VectorParams
from qdrant_client.models import PointIdsList
from langgraph.graph import MessagesState, StateGraph, END
from langgraph.checkpoint.memory import MemorySaver
from langchain_core.messages import SystemMessage, HumanMessage

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Load environment variables
load_dotenv()
GOOGLE_API_KEY = os.getenv('GOOGLE_API_KEY')
GROQ_API_KEY = os.getenv('GROQ_API_KEY')

if not GOOGLE_API_KEY or not GROQ_API_KEY:
    raise ValueError("API keys not set in environment variables")

app = FastAPI()


class QASystem:
    def __init__(self):
        self.vector_store = None
        self.graph = None
        self.memory = MemorySaver()  # LangGraph checkpointer for per-thread conversation history
        self.embeddings = None
        self.client = None
        self.pdf_dir = "pdfss"

    def load_pdf_documents(self):
        documents = []
        pdf_dir = Path(self.pdf_dir)
        if not pdf_dir.exists():
            raise FileNotFoundError(f"PDF directory not found: {self.pdf_dir}")
        for pdf_path in pdf_dir.glob("*.pdf"):
            try:
                loader = PyPDFLoader(str(pdf_path))
                documents.extend(loader.load())
                logger.info(f"Loaded PDF: {pdf_path}")
            except Exception as e:
                logger.error(f"Error loading PDF {pdf_path}: {str(e)}")
        text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
        split_docs = text_splitter.split_documents(documents)
        logger.info(f"Split documents into {len(split_docs)} chunks")
        return split_docs

    def initialize_system(self):
        try:
            # Qdrant setup
            self.client = QdrantClient(":memory:")
            try:
                self.client.get_collection("pdf_data")
            except Exception:
                self.client.create_collection(
                    collection_name="pdf_data",
                    vectors_config=VectorParams(size=768, distance=Distance.COSINE),
                )
                logger.info("Created new collection: pdf_data")

            # Embeddings and vector store
            self.embeddings = GoogleGenerativeAIEmbeddings(
                model="models/embedding-001",
                google_api_key=GOOGLE_API_KEY
            )
            self.vector_store = QdrantVectorStore(
                client=self.client,
                collection_name="pdf_data",
                embeddings=self.embeddings,
            )

            # Load and add documents (clearing any existing points first)
            documents = self.load_pdf_documents()
            if documents:
                points = self.client.scroll(collection_name="pdf_data", limit=100)[0]
                if points:
                    self.client.delete(
                        collection_name="pdf_data",
                        points_selector=PointIdsList(points=[p.id for p in points])
                    )
                self.vector_store.add_documents(documents)
                logger.info(f"Added {len(documents)} documents to vector store")

            # LLM setup
            llm = ChatGroq(
                model="llama3-8b-8192",
                api_key=GROQ_API_KEY,
                temperature=0.7
            )

            # Graph building
            graph_builder = StateGraph(MessagesState)

            # === RETRIEVAL NODE for context fetching from Qdrant ===
            def retrieve_documents(state: MessagesState):
                query = [m.content for m in state["messages"] if m.type == "human"][-1]
                results = self.vector_store.similarity_search(query, k=4)
                context = "\n\n".join([doc.page_content for doc in results])
                # Tag the message so the generator node can identify retrieved context
                return {"messages": [SystemMessage(content=context, name="retrieval")]}

            # === GENERATOR NODE that uses full memory (chat history) ===
            def generate_response(state: MessagesState):
                # With the MemorySaver checkpointer, state["messages"] already contains the
                # full conversation history for the current thread, so no manual lookup is needed
                all_messages = state["messages"]
                logger.info(f"Generating over {len(all_messages)} messages in this thread")
                # Extract context from the retrieved-document messages added by the retrieval node
                retrieved_docs = [m for m in all_messages if getattr(m, "name", None) == "retrieval"]
                context = ' '.join(m.content for m in retrieved_docs) if retrieved_docs else "mountain bicycle documentation knowledge"
                # Compose system prompt
                system_prompt = (
                    "You are an AI assistant embedded within the Interactive Electronic Technical Manual (IETM) for Mountain Cycles. "
                    "Your responses MUST be accurate and concise (5 sentences max). "
                    "If you don't know the answer, say 'I don't know based on available data.'\n\n"
                    f"Context:\n{context}"
                )
                final_messages = [SystemMessage(content=system_prompt)] + all_messages
                response = llm.invoke(final_messages)
                # The checkpointer persists the updated state automatically when the node returns
                return {"messages": [response]}

            # Add graph nodes
            graph_builder.add_node("retrieval", retrieve_documents)
            graph_builder.add_node("generate", generate_response)

            # Graph edges
            graph_builder.set_entry_point("retrieval")
            graph_builder.add_edge("retrieval", "generate")
            graph_builder.add_edge("generate", END)

            # Compile graph with memory
            self.graph = graph_builder.compile(checkpointer=self.memory)
            return True
        except Exception as e:
            logger.error(f"System initialization error: {str(e)}")
            return False

    # === Query Processor with Memory ===
    def process_query(self, query: str, user_id: str) -> List[Dict[str, str]]:
        try:
            responses = []
            for step in self.graph.stream(
                {"messages": [HumanMessage(content=query)]},
                stream_mode="values",
                config={"configurable": {"thread_id": user_id}}  # thread ID for per-user memory
            ):
                if step["messages"]:
                    responses.append({
                        'content': step["messages"][-1].content,
                        'type': step["messages"][-1].type
                    })
            return responses
        except Exception as e:
            logger.error(f"Query processing error: {str(e)}")
            return [{'content': f"Query processing error: {str(e)}", 'type': 'error'}]


# === Initialize QA System ===
qa_system = QASystem()
if qa_system.initialize_system():
    logger.info("QA System Initialized Successfully")
else:
    raise RuntimeError("Failed to initialize QA System")


# === FastAPI Route ===
@app.post("/query")
async def query_api(query: str, user_id: str):  # user_id selects the thread for session-specific memory
    responses = qa_system.process_query(query, user_id)
    return {"responses": responses}
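
# A minimal local-run sketch (an assumption, not part of the original module): it presumes
# uvicorn is installed and that the endpoint's `query` and `user_id` arrive as query
# parameters, e.g. POST /query?query=How+do+I+adjust+the+brakes%3F&user_id=alice
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)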