VishnuRamDebyez committed
Commit 1541f74 · verified · 1 Parent(s): 1a6b013

Update app.py

Files changed (1)
  1. app.py +66 -219
app.py CHANGED
@@ -1,6 +1,6 @@
 from fastapi import FastAPI, HTTPException
 import os
-from typing import List, Dict, Any
+from typing import List, Dict
 from dotenv import load_dotenv
 import logging
 from pathlib import Path
@@ -15,18 +15,15 @@ from qdrant_client.http.models import Distance, VectorParams
 from qdrant_client.models import PointIdsList

 from langgraph.graph import MessagesState, StateGraph
-from langchain_core.messages import SystemMessage, HumanMessage, ToolMessage
+from langchain_core.messages import SystemMessage, HumanMessage
 from langgraph.prebuilt import ToolNode
 from langgraph.graph import END
 from langgraph.prebuilt import tools_condition
 from langgraph.checkpoint.memory import MemorySaver

-# Configure logging
-logging.basicConfig(level=logging.INFO,
-                    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
+logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger(__name__)

-# Load environment variables
 load_dotenv()
 GOOGLE_API_KEY = os.getenv('GOOGLE_API_KEY')
 GROQ_API_KEY = os.getenv('GROQ_API_KEY')
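Review note on the hunk below: the loader loses its per-page source metadata, and the chunk overlap drops from 200 to 100 characters. A standalone sketch (illustrative, not part of the commit; assumes the langchain_text_splitters package that provides RecursiveCharacterTextSplitter) to sanity-check what the new splitter settings produce:

```python
# Illustrative sanity check (not in the commit): chunks are capped at
# 1,000 characters, and consecutive chunks share up to 100 characters.
from langchain_text_splitters import RecursiveCharacterTextSplitter

splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
sample = "All work and no play makes Jack a dull boy. " * 100
chunks = splitter.split_text(sample)
print(len(chunks), max(len(c) for c in chunks))  # several chunks, each <= 1000 chars
```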
@@ -44,73 +41,36 @@ class QASystem:
         self.embeddings = None
         self.client = None
         self.pdf_dir = "pdfss"
-        self.is_initialized = False

     def load_pdf_documents(self):
-        """Load and process PDF documents from the pdf directory"""
         documents = []
         pdf_dir = Path(self.pdf_dir)

         if not pdf_dir.exists():
             raise FileNotFoundError(f"PDF directory not found: {self.pdf_dir}")

-        pdf_files = list(pdf_dir.glob("*.pdf"))
-        if not pdf_files:
-            logger.warning(f"No PDF files found in directory: {self.pdf_dir}")
-            return []
-
-        logger.info(f"Found {len(pdf_files)} PDF files to process")
-
-        for pdf_path in pdf_files:
+        for pdf_path in pdf_dir.glob("*.pdf"):
             try:
-                logger.info(f"Processing PDF: {pdf_path}")
                 loader = PyPDFLoader(str(pdf_path))
-                pdf_documents = loader.load()
-
-                # Add source information to metadata
-                for doc in pdf_documents:
-                    if not hasattr(doc, 'metadata'):
-                        doc.metadata = {}
-                    doc.metadata['source'] = str(pdf_path.name)
-
-                documents.extend(pdf_documents)
-                logger.info(f"Loaded PDF: {pdf_path} - {len(pdf_documents)} pages/sections")
+                documents.extend(loader.load())
+                logger.info(f"Loaded PDF: {pdf_path}")
             except Exception as e:
                 logger.error(f"Error loading PDF {pdf_path}: {str(e)}")

-        if not documents:
-            logger.warning("No documents were loaded from PDFs. Check the PDF directory and file formats.")
-            return []
-
-        # Split documents into smaller chunks for better retrieval
         text_splitter = RecursiveCharacterTextSplitter(
             chunk_size=1000,
-            chunk_overlap=200
+            chunk_overlap=100
         )
         split_docs = text_splitter.split_documents(documents)
-        logger.info(f"Split {len(documents)} documents into {len(split_docs)} chunks")
-
-        # Verify content of the first few chunks
-        for i, doc in enumerate(split_docs[:3]):
-            if i >= len(split_docs):
-                break
-            logger.info(f"Sample chunk {i+1} content preview: {doc.page_content[:100]}...")
-
+        logger.info(f"Split documents into {len(split_docs)} chunks")
         return split_docs

     def initialize_system(self):
-        """Initialize the RAG system with vector store and LLM"""
         try:
-            logger.info("Initializing QA System...")
-
-            # Initialize Qdrant client
             self.client = QdrantClient(":memory:")
-            logger.info("Qdrant client initialized (in-memory)")

-            # Create or get collection
             try:
-                collection_info = self.client.get_collection("pdf_data")
-                logger.info(f"Using existing collection: pdf_data")
+                self.client.get_collection("pdf_data")
             except Exception:
                 self.client.create_collection(
                     collection_name="pdf_data",
@@ -118,32 +78,22 @@ class QASystem:
                 )
                 logger.info("Created new collection: pdf_data")

-            # Initialize embeddings model
             self.embeddings = GoogleGenerativeAIEmbeddings(
                 model="models/embedding-001",
                 google_api_key=GOOGLE_API_KEY
             )
-            logger.info("Google AI Embeddings initialized")

-            # Initialize vector store
             self.vector_store = QdrantVectorStore(
                 client=self.client,
                 collection_name="pdf_data",
                 embeddings=self.embeddings,
             )
-            logger.info("Qdrant vector store initialized")

-            # Load documents
             documents = self.load_pdf_documents()
-            if not documents:
-                logger.warning("No documents loaded. The system will continue but may not provide relevant responses.")
-
-            # Clear existing vectors if any
             if documents:
                 try:
-                    points = self.client.scroll(collection_name="pdf_data", limit=1000)[0]
+                    points = self.client.scroll(collection_name="pdf_data", limit=100)[0]
                     if points:
-                        logger.info(f"Clearing {len(points)} existing vectors from collection")
                         self.client.delete(
                             collection_name="pdf_data",
                             points_selector=PointIdsList(
@@ -153,201 +103,98 @@ class QASystem:
             except Exception as e:
                 logger.error(f"Error clearing vectors: {str(e)}")

-            # Add documents to vector store
-            logger.info(f"Adding {len(documents)} documents to vector store")
             self.vector_store.add_documents(documents)
-            logger.info(f"Successfully added documents to vector store")
-
-            # Verify vector store has documents
-            try:
-                count = len(self.client.scroll(collection_name="pdf_data", limit=1)[0])
-                logger.info(f"Vector store contains points: {count > 0}")
-            except Exception as e:
-                logger.error(f"Error verifying vector store: {str(e)}")
+            logger.info(f"Added {len(documents)} documents to vector store")

-            # Initialize LLM
             llm = ChatGroq(
                 model="llama3-8b-8192",
                 api_key=GROQ_API_KEY,
                 temperature=0.7
             )
-            logger.info("Groq LLM initialized")

-            # Create LangGraph
             graph_builder = StateGraph(MessagesState)
-            logger.info("Creating LangGraph for conversation flow")

-            # Define retrieval node (self reference for vector_store access)
-            vector_store_ref = self.vector_store
-
-            def retrieve_docs(state: MessagesState):
-                """Node that retrieves relevant documents from the vector store"""
-                # Get the most recent human message
-                human_messages = [m for m in state["messages"] if m.type == "human"]
-                if not human_messages:
-                    logger.warning("No human messages found in state")
-                    return {"messages": state["messages"]}
-
-                user_query = human_messages[-1].content
-                logger.info(f"Retrieving documents for query: '{user_query}'")
-
-                # Check if vector store exists
-                if not vector_store_ref:
-                    logger.error("Vector store not initialized or empty")
-                    return {"messages": state["messages"]}
-
-                # Query the vector store
-                try:
-                    retrieved_docs = vector_store_ref.similarity_search(user_query, k=3)
-
-                    if not retrieved_docs:
-                        logger.warning(f"No documents retrieved for query: '{user_query}'")
-                        return {"messages": state["messages"]}
-
-                    # Log what was actually retrieved
-                    for i, doc in enumerate(retrieved_docs):
-                        source = doc.metadata.get('source', 'Unknown') if hasattr(doc, 'metadata') else 'Unknown'
-                        content_preview = doc.page_content[:100] + "..." if len(doc.page_content) > 100 else doc.page_content
-                        logger.info(f"Retrieved doc {i+1} from {source}, preview: {content_preview}")
-
-                    # Create tool messages with more detailed content
-                    tool_messages = []
-                    for i, doc in enumerate(retrieved_docs):
-                        # Include source information if available
-                        source_info = f" (Source: {doc.metadata.get('source', 'Unknown')})" if hasattr(doc, 'metadata') else ""
-
-                        tool_messages.append(
-                            ToolMessage(
-                                content=f"Document {i+1}{source_info}: {doc.page_content}",
-                                tool_call_id=f"retrieval_{i}"
-                            )
-                        )
-
-                    logger.info(f"Created {len(tool_messages)} tool messages with retrieved content")
-                    return {"messages": state["messages"] + tool_messages}
-
-                except Exception as e:
-                    logger.error(f"Error retrieving documents: {str(e)}")
-                    return {"messages": state["messages"]}
+            def query_or_respond(state: MessagesState):
+                retrieved_docs = [m for m in state["messages"] if m.type == "tool"]
+
+                if retrieved_docs:
+                    context = ' '.join(m.content for m in retrieved_docs)
+                else:
+                    context = "mountain bicycle documentation knowledge"
+
+                system_prompt = (
+                    "You are an AI assistant embedded within the Interactive Electronic Technical Manual (IETM) for Mountain Cycles. "
+                    "Always provide accurate responses with references to provided data. "
+                    "If the user query is not technical-specific, still respond from an IETM perspective."
+                    f"\n\nContext:\n{context}"
+                )
+
+                messages = [SystemMessage(content=system_prompt)] + state["messages"]
+
+                logger.info(f"Sending to LLM: {[m.content for m in messages]}")  # Debugging log
+
+                response = llm.invoke(messages)
+                return {"messages": [response]}

-            # Generate response using retrieved documents
             def generate(state: MessagesState):
-                """Node that generates a response using the LLM and retrieved documents"""
-                # Extract retrieved documents (tool messages)
-                tool_messages = [m for m in state["messages"] if m.type == "tool"]
-
-                # Collect context from retrieved documents
-                if tool_messages:
-                    context = "\n\n".join([m.content for m in tool_messages])
-                    logger.info(f"Using context from {len(tool_messages)} retrieved documents")
-                else:
-                    context = "No specific mountain bicycle documentation available for this query."
-                    logger.warning("No relevant documents retrieved, using default context")
+                retrieved_docs = [m for m in state["messages"] if m.type == "tool"]
+
+                context = ' '.join(m.content for m in retrieved_docs) if retrieved_docs else "mountain bicycle documentation knowledge"

                 system_prompt = (
                     "You are an AI assistant embedded within the Interactive Electronic Technical Manual (IETM) for Mountain Cycles. "
-                    "Your primary role is to provide accurate technical information about mountain bicycles. "
-                    "Always base your responses on the provided documentation. "
-                    "If you don't find specific information in the provided context, clearly state that the information "
-                    "is not available in the current documentation instead of making up details. "
-                    "When responding, reference specific parts of the documentation."
-                    f"\n\nContext from mountain bicycle documentation:\n{context}"
+                    "Your responses MUST be accurate, concise (5 sentences max)."
+                    f"\n\nContext:\n{context}"
                 )

-                # Get all messages excluding tool messages to avoid redundancy
-                human_and_ai_messages = [m for m in state["messages"] if m.type != "tool"]
-
-                # Create the full message history for the LLM
-                messages = [SystemMessage(content=system_prompt)] + human_and_ai_messages
-
-                logger.info(f"Sending query to LLM with {len(messages)} messages")
-
-                # Generate the response
-                try:
-                    response = llm.invoke(messages)
-                    logger.info(f"LLM generated response successfully")
-                    return {"messages": state["messages"] + [response]}
-                except Exception as e:
-                    logger.error(f"Error generating response: {str(e)}")
-                    error_message = SystemMessage(content=f"Error generating response: {str(e)}")
-                    return {"messages": state["messages"] + [error_message]}
+                messages = [SystemMessage(content=system_prompt)] + state["messages"]
+
+                logger.info(f"Sending to LLM: {[m.content for m in messages]}")  # Debugging log
+
+                response = llm.invoke(messages)
+                return {"messages": [response]}
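Review note: with retrieve_docs deleted, nothing in the new graph ever emits messages of type "tool" (ToolNode and tools_condition are imported but never wired in), so retrieved_docs is always empty and both nodes fall back to the placeholder context; the vector store is populated but never queried. A sketch of a retrieval helper that could restore real context, reusing the similarity_search call from the previous revision (illustrative, not part of this commit):

```python
# Illustrative only (not in the commit): fetch real context instead of the
# placeholder string. similarity_search is the call the previous revision
# used; vector_store is assumed to be the QdrantVectorStore built above.
def retrieve_context(vector_store, query: str, k: int = 3) -> str:
    docs = vector_store.similarity_search(query, k=k)
    if not docs:
        return "mountain bicycle documentation knowledge"
    return "\n\n".join(doc.page_content for doc in docs)
```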

-            # Add nodes to the graph
-            graph_builder.add_node("retrieve_docs", retrieve_docs)
+
+            graph_builder.add_node("query_or_respond", query_or_respond)
             graph_builder.add_node("generate", generate)

-            # Set the flow of the graph
-            graph_builder.set_entry_point("retrieve_docs")
-            graph_builder.add_edge("retrieve_docs", "generate")
+            graph_builder.set_entry_point("query_or_respond")
+            graph_builder.add_edge("query_or_respond", "generate")
             graph_builder.add_edge("generate", END)

-            # Initialize memory
             self.memory = MemorySaver()
             self.graph = graph_builder.compile(checkpointer=self.memory)
-            logger.info("Graph compiled successfully")
-
-            self.is_initialized = True
             return True
 
         except Exception as e:
             logger.error(f"System initialization error: {str(e)}")
-            self.is_initialized = False
             return False

-    def process_query(self, query: str) -> Dict[str, Any]:
-        """Process a query and return a single final response"""
+    def process_query(self, query: str) -> List[Dict[str, str]]:
         try:
-            if not self.is_initialized:
-                logger.error("System not initialized. Cannot process query.")
-                return {
-                    'content': "Error: QA System not initialized properly",
-                    'type': 'error'
-                }
-
-            logger.info(f"Processing query: '{query}'")
-
-            # Generate a thread ID (use a more sophisticated method for production)
-            thread_id = "abc123"
-
-            # Use invoke to get only the final result
-            final_state = self.graph.invoke(
+            responses = []
+            for step in self.graph.stream(
                 {"messages": [HumanMessage(content=query)]},
-                config={"configurable": {"thread_id": thread_id}}
-            )
-
-            # Extract only the last AI message from the final state
-            ai_messages = [m for m in final_state["messages"] if m.type == "ai"]
-
-            if ai_messages:
-                logger.info("Successfully generated response")
-                # Return only the last AI message
-                return {
-                    'content': ai_messages[-1].content,
-                    'type': ai_messages[-1].type
-                }
-
-            logger.warning("No AI message generated in response")
-            return {
-                'content': "No response could be generated for your query. Please try a different question.",
-                'type': 'error'
-            }
-
+                stream_mode="values",
+                config={"configurable": {"thread_id": "abc123"}}
+            ):
+                if step["messages"]:
+                    responses.append({
+                        'content': step["messages"][-1].content,
+                        'type': step["messages"][-1].type
+                    })
+            return responses
+
         except Exception as e:
             logger.error(f"Query processing error: {str(e)}")
-            return {
-                'content': f"Error processing your query: {str(e)}",
-                'type': 'error'
-            }
+            return [{'content': f"Query processing error: {str(e)}", 'type': 'error'}]
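Review note: with stream_mode="values", graph.stream yields the full message list after every step, so process_query returns the echoed human message plus one AI message per node rather than a single answer. A hypothetical post-processing helper (not in the commit) for callers that only want the final reply:

```python
# Hypothetical helper (not in the commit): reduce the step list returned by
# process_query to the last AI message, i.e. the output of the generate node.
def final_answer(responses: list) -> str:
    ai = [r for r in responses if r.get('type') == 'ai']
    return ai[-1]['content'] if ai else "No response generated."
```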
 
-# Initialize the QA system
 qa_system = QASystem()
-initialization_success = qa_system.initialize_system()
+if qa_system.initialize_system():
+    logger.info("QA System Initialized Successfully")
+else:
+    raise RuntimeError("Failed to initialize QA System")

 @app.post("/query")
 async def query_api(query: str):
-    """API endpoint that returns a single response for a query"""
-    if not qa_system.is_initialized:
-        raise HTTPException(status_code=500, detail="QA System not initialized properly")
-
-    response = qa_system.process_query(query)
-    return {"response": response}
+    responses = qa_system.process_query(query)
+    return {"responses": responses}
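Review note: every request reuses the hard-coded thread_id "abc123", so with the MemorySaver checkpointer all API callers share one ever-growing conversation history. A hypothetical per-request variant (endpoint name and uuid scheme are illustrative, not part of this commit):

```python
# Hypothetical variant (not in the commit): give each request its own
# checkpointer thread so callers do not share conversation state.
import uuid

@app.post("/query_isolated")
async def query_api_isolated(query: str):
    config = {"configurable": {"thread_id": uuid.uuid4().hex}}
    state = qa_system.graph.invoke({"messages": [HumanMessage(content=query)]}, config=config)
    return {"responses": [{"content": m.content, "type": m.type} for m in state["messages"]]}
```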
 
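Usage note: because query is declared as a bare str parameter, FastAPI reads it from the query string, not a JSON body. A hedged sketch of calling the endpoint (assumes a default local run such as uvicorn app:app --port 8000):

```python
# Assumes the API is served locally on port 8000 (adjust host/port as needed).
import requests

resp = requests.post(
    "http://localhost:8000/query",
    params={"query": "How do I adjust the front suspension?"},
)
print(resp.json()["responses"][-1]["content"])  # the last streamed step is the final reply
```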