from fastapi import FastAPI, HTTPException
import os
from typing import List, Dict
from dotenv import load_dotenv
import logging
from pathlib import Path

from langchain_community.document_loaders import PyPDFLoader
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import Qdrant as QdrantVectorStore
from langchain_google_genai import GoogleGenerativeAIEmbeddings
from langchain_groq import ChatGroq
from qdrant_client import QdrantClient
from qdrant_client.http.models import Distance, VectorParams
from qdrant_client.models import PointIdsList

from langgraph.graph import MessagesState, StateGraph, END
from langgraph.checkpoint.memory import MemorySaver
from langchain_core.messages import SystemMessage, HumanMessage
from langchain_core.runnables import RunnableConfig

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Load environment variables
load_dotenv()
GOOGLE_API_KEY = os.getenv('GOOGLE_API_KEY')
GROQ_API_KEY = os.getenv('GROQ_API_KEY')

if not GOOGLE_API_KEY or not GROQ_API_KEY:
    raise ValueError("API keys not set in environment variables")

app = FastAPI()

class QASystem:
    def __init__(self):
        self.vector_store = None
        self.graph = None
        self.memory = MemorySaver()  # LangGraph memory saver for conversation history
        self.embeddings = None
        self.client = None
        self.pdf_dir = "pdfss"

    def load_pdf_documents(self):
        documents = []
        pdf_dir = Path(self.pdf_dir)

        if not pdf_dir.exists():
            raise FileNotFoundError(f"PDF directory not found: {self.pdf_dir}")

        for pdf_path in pdf_dir.glob("*.pdf"):
            try:
                loader = PyPDFLoader(str(pdf_path))
                documents.extend(loader.load())
                logger.info(f"Loaded PDF: {pdf_path}")
            except Exception as e:
                logger.error(f"Error loading PDF {pdf_path}: {str(e)}")

        text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
        split_docs = text_splitter.split_documents(documents)
        logger.info(f"Split documents into {len(split_docs)} chunks")
        return split_docs

    def initialize_system(self):
        try:
            # Qdrant setup
            self.client = QdrantClient(":memory:")
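            # ":memory:" keeps the whole index in RAM and is rebuilt on every
            # restart; point QdrantClient at a persistent server (e.g.
            # QdrantClient(url=...)) if the index should survive restarts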

            try:
                self.client.get_collection("pdf_data")
            except Exception:
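                # embedding-001 produces 768-dimensional vectors, hence size=768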
                self.client.create_collection(
                    collection_name="pdf_data",
                    vectors_config=VectorParams(size=768, distance=Distance.COSINE),
                )
                logger.info("Created new collection: pdf_data")

            # Embeddings and vector store
            self.embeddings = GoogleGenerativeAIEmbeddings(
                model="models/embedding-001", google_api_key=GOOGLE_API_KEY
            )

            self.vector_store = QdrantVectorStore(
                client=self.client,
                collection_name="pdf_data",
                embeddings=self.embeddings,
            )

            # Load and add documents
            documents = self.load_pdf_documents()
            if documents:
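                # Clear any stale points first; a no-op for a fresh in-memory
                # collection (note only the first 100 points are fetched by this scroll)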
                points = self.client.scroll(collection_name="pdf_data", limit=100)[0]
                if points:
                    self.client.delete(
                        collection_name="pdf_data",
                        points_selector=PointIdsList(points=[p.id for p in points])
                    )
                self.vector_store.add_documents(documents)
                logger.info(f"Added {len(documents)} documents to vector store")

            # LLM setup
            llm = ChatGroq(
                model="llama3-8b-8192",
                api_key=GROQ_API_KEY,
                temperature=0.7
            )
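            # llama3-8b-8192 is a Groq-hosted Llama 3 8B; any other
            # ChatGroq-supported model name can be substituted here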

            # Graph building
            graph_builder = StateGraph(MessagesState)

            # === RETRIEVAL NODE for context fetching from Qdrant ===
            def retrieve_documents(state: MessagesState):
                # Use the most recent human message as the search query
                query = [m.content for m in state["messages"] if m.type == "human"][-1]
                results = self.vector_store.similarity_search(query, k=4)
                context = "\n\n".join(doc.page_content for doc in results)
                # Tag the message by name so the generator can pick it out later
                return {"messages": [SystemMessage(content=context, name="retrieval")]}

            # === GENERATOR NODE that uses full memory (chat history) ===
            def generate_response(state: MessagesState, config: RunnableConfig):
                # The MemorySaver checkpointer compiled into the graph restores
                # prior turns for this thread, so state["messages"] already
                # holds the full conversation history
                thread_id = config["configurable"].get("thread_id", "default")
                all_messages = state["messages"]

                logger.info(f"[Thread {thread_id}] History: {[m.content for m in all_messages]}")

                # Extract context from the retrieval messages tagged upstream
                retrieved_docs = [m for m in all_messages if getattr(m, "name", None) == "retrieval"]
                context = ' '.join(m.content for m in retrieved_docs) if retrieved_docs else "mountain bicycle documentation knowledge"

                # Compose system prompt
                system_prompt = (
                    "You are an AI assistant embedded within the Interactive Electronic Technical Manual (IETM) for Mountain Cycles. "
                    "Your responses MUST be accurate and concise (5 sentences max). "
                    "If you don't know the answer, say 'I don't know based on available data.'\n\n"
                    f"Context:\n{context}"
                )

                final_messages = [SystemMessage(content=system_prompt)] + all_messages
                response = llm.invoke(final_messages)

                # The checkpointer persists the updated state automatically,
                # so no manual save is needed here
                return {"messages": [response]}

            # Add graph nodes
            graph_builder.add_node("retrieval", retrieve_documents)
            graph_builder.add_node("generate", generate_response)

            # Graph edges
            graph_builder.set_entry_point("retrieval")
            graph_builder.add_edge("retrieval", "generate")
            graph_builder.add_edge("generate", END)

            # Compile graph with memory
            self.graph = graph_builder.compile(checkpointer=self.memory)
            return True

        except Exception as e:
            logger.error(f"System initialization error: {str(e)}")
            return False

    # === Query Processor with Memory ===
    def process_query(self, query: str, user_id: str) -> List[Dict[str, str]]:
        try:
            responses = []
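            # stream_mode="values" yields the full state after each step, so the
            # captured messages are the user's query, the retrieved context,
            # and finally the model's reply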
            for step in self.graph.stream(
                {"messages": [HumanMessage(content=query)]},
                stream_mode="values",
                config={"configurable": {"thread_id": user_id}}  # thread ID for user memory
            ):
                if step["messages"]:
                    responses.append({
                        'content': step["messages"][-1].content,
                        'type': step["messages"][-1].type
                    })
            return responses
        except Exception as e:
            logger.error(f"Query processing error: {str(e)}")
            return [{'content': f"Query processing error: {str(e)}", 'type': 'error'}]

# === Initialize QA System ===
qa_system = QASystem()
if qa_system.initialize_system():
    logger.info("QA System Initialized Successfully")
else:
    raise RuntimeError("Failed to initialize QA System")

# === FastAPI Route ===
@app.post("/query")
def query_api(query: str, user_id: str):  # user_id keys session-specific memory
    # Plain def (not async) so FastAPI runs the blocking LLM call in a threadpool
    responses = qa_system.process_query(query, user_id)
    return {"responses": responses}
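
# Example usage (a sketch; assumes this file is saved as main.py and uvicorn is installed):
#   uvicorn main:app --reload
#   curl -X POST "http://localhost:8000/query?query=How+do+I+adjust+the+brakes&user_id=alice"
# Both parameters travel as query strings because the route declares plain str
# arguments; repeated calls with the same user_id reuse that thread's memory.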