from fastapi import FastAPI, HTTPException
import os
from typing import List, Dict
from dotenv import load_dotenv
import logging
from pathlib import Path
from langchain_community.document_loaders import PyPDFLoader
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import Qdrant as QdrantVectorStore
from langchain_google_genai import GoogleGenerativeAIEmbeddings
from langchain_groq import ChatGroq
from qdrant_client import QdrantClient
from qdrant_client.http.models import Distance, VectorParams
from qdrant_client.models import PointIdsList
from langgraph.graph import MessagesState, StateGraph, END
from langgraph.checkpoint.memory import MemorySaver
from langchain_core.messages import SystemMessage, HumanMessage
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Load environment variables
load_dotenv()
GOOGLE_API_KEY = os.getenv('GOOGLE_API_KEY')
GROQ_API_KEY = os.getenv('GROQ_API_KEY')

if not GOOGLE_API_KEY or not GROQ_API_KEY:
    raise ValueError("API keys not set in environment variables")

app = FastAPI()
class QASystem:
    def __init__(self):
        self.vector_store = None
        self.graph = None
        self.memory = MemorySaver()  # LangGraph memory saver for conversation history
        self.embeddings = None
        self.client = None
        self.pdf_dir = "pdfss"
    def load_pdf_documents(self):
        documents = []
        pdf_dir = Path(self.pdf_dir)
        if not pdf_dir.exists():
            raise FileNotFoundError(f"PDF directory not found: {self.pdf_dir}")
        for pdf_path in pdf_dir.glob("*.pdf"):
            try:
                loader = PyPDFLoader(str(pdf_path))
                documents.extend(loader.load())
                logger.info(f"Loaded PDF: {pdf_path}")
            except Exception as e:
                logger.error(f"Error loading PDF {pdf_path}: {str(e)}")
        text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
        split_docs = text_splitter.split_documents(documents)
        logger.info(f"Split documents into {len(split_docs)} chunks")
        return split_docs
    def initialize_system(self):
        try:
            # Qdrant setup
            self.client = QdrantClient(":memory:")
            try:
                self.client.get_collection("pdf_data")
            except Exception:
                self.client.create_collection(
                    collection_name="pdf_data",
                    vectors_config=VectorParams(size=768, distance=Distance.COSINE),
                )
                logger.info("Created new collection: pdf_data")

            # Embeddings and vector store
            self.embeddings = GoogleGenerativeAIEmbeddings(
                model="models/embedding-001", google_api_key=GOOGLE_API_KEY
            )
            self.vector_store = QdrantVectorStore(
                client=self.client,
                collection_name="pdf_data",
                embeddings=self.embeddings,
            )

            # Load and add documents
            documents = self.load_pdf_documents()
            if documents:
                points = self.client.scroll(collection_name="pdf_data", limit=100)[0]
                if points:
                    self.client.delete(
                        collection_name="pdf_data",
                        points_selector=PointIdsList(points=[p.id for p in points]),
                    )
                self.vector_store.add_documents(documents)
                logger.info(f"Added {len(documents)} documents to vector store")

            # LLM setup
            llm = ChatGroq(
                model="llama3-8b-8192",
                api_key=GROQ_API_KEY,
                temperature=0.7,
            )
            # Graph building
            graph_builder = StateGraph(MessagesState)

            # === RETRIEVAL NODE for context fetching from Qdrant ===
            def retrieve_documents(state: MessagesState):
                # Use the latest human message as the search query
                query = [m.content for m in state["messages"] if m.type == "human"][-1]
                results = self.vector_store.similarity_search(query, k=4)
                context = "\n\n".join([doc.page_content for doc in results])
                # Tag the message so the generator can pick the retrieved context back out
                return {"messages": [SystemMessage(content=context, name="retrieval")]}
            # === GENERATOR NODE that uses the full chat history ===
            def generate_response(state: MessagesState, config):
                thread_id = config.get("configurable", {}).get("thread_id", "default")
                # Because the graph is compiled with a checkpointer, state["messages"]
                # already holds the full history for this thread; no manual loading
                # or saving of memory is needed here.
                all_messages = state["messages"]
                logger.info(f"[Thread {thread_id}] History: {[m.content for m in all_messages]}")
                # Extract context from the messages produced by the retrieval node
                retrieved_docs = [m for m in all_messages if m.type == "system" and m.name == "retrieval"]
                context = ' '.join(m.content for m in retrieved_docs) if retrieved_docs else "mountain bicycle documentation knowledge"
                # Compose system prompt
                system_prompt = (
                    "You are an AI assistant embedded within the Interactive Electronic Technical Manual (IETM) for Mountain Cycles. "
                    "Your responses MUST be accurate and concise (5 sentences max). "
                    "If you don't know the answer, say 'I don't know based on available data.'\n\n"
                    f"Context:\n{context}"
                )
                final_messages = [SystemMessage(content=system_prompt)] + all_messages
                response = llm.invoke(final_messages)
                # The checkpointer persists the updated message list automatically
                return {"messages": [response]}
            # Add graph nodes
            graph_builder.add_node("retrieval", retrieve_documents)
            graph_builder.add_node("generate", generate_response)

            # Graph edges
            graph_builder.set_entry_point("retrieval")
            graph_builder.add_edge("retrieval", "generate")
            graph_builder.add_edge("generate", END)

            # Compile graph with memory
            self.graph = graph_builder.compile(checkpointer=self.memory)
            return True
        except Exception as e:
            logger.error(f"System initialization error: {str(e)}")
            return False
    # === Query Processor with Memory ===
    def process_query(self, query: str, user_id: str) -> List[Dict[str, str]]:
        try:
            responses = []
            for step in self.graph.stream(
                {"messages": [HumanMessage(content=query)]},
                stream_mode="values",
                config={"configurable": {"thread_id": user_id}},  # thread ID for per-user memory
            ):
                if step["messages"]:
                    responses.append({
                        'content': step["messages"][-1].content,
                        'type': step["messages"][-1].type,
                    })
            return responses
        except Exception as e:
            logger.error(f"Query processing error: {str(e)}")
            return [{'content': f"Query processing error: {str(e)}", 'type': 'error'}]
# === Initialize QA System ===
qa_system = QASystem()
if qa_system.initialize_system():
    logger.info("QA System Initialized Successfully")
else:
    raise RuntimeError("Failed to initialize QA System")

# === FastAPI Route ===
@app.get("/query")  # route path "/query" is an assumption; adjust to your API layout
async def query_api(query: str, user_id: str):  # user_id selects the session-specific memory thread
    responses = qa_system.process_query(query, user_id)
    return {"responses": responses}