VishnuRamDebyez committed
Commit dc0ba34 · verified · 1 Parent(s): fe9fc71

Update app.py

Files changed (1):
  app.py +146 -34
app.py CHANGED (old version shown first, removed lines marked with "-"; updated version follows, added lines marked with "+")
@@ -1,6 +1,6 @@
  from fastapi import FastAPI, HTTPException
  import os
- from typing import List, Dict
  from dotenv import load_dotenv
  import logging
  from pathlib import Path
@@ -21,9 +21,12 @@ from langgraph.graph import END
  from langgraph.prebuilt import tools_condition
  from langgraph.checkpoint.memory import MemorySaver

- logging.basicConfig(level=logging.INFO)
  logger = logging.getLogger(__name__)

  load_dotenv()
  GOOGLE_API_KEY = os.getenv('GOOGLE_API_KEY')
  GROQ_API_KEY = os.getenv('GROQ_API_KEY')
@@ -41,36 +44,73 @@ class QASystem:
          self.embeddings = None
          self.client = None
          self.pdf_dir = "pdfss"

      def load_pdf_documents(self):
          documents = []
          pdf_dir = Path(self.pdf_dir)

          if not pdf_dir.exists():
              raise FileNotFoundError(f"PDF directory not found: {self.pdf_dir}")

-         for pdf_path in pdf_dir.glob("*.pdf"):
              try:
                  loader = PyPDFLoader(str(pdf_path))
-                 documents.extend(loader.load())
-                 logger.info(f"Loaded PDF: {pdf_path}")
              except Exception as e:
                  logger.error(f"Error loading PDF {pdf_path}: {str(e)}")

          text_splitter = RecursiveCharacterTextSplitter(
              chunk_size=1000,
-             chunk_overlap=100
          )
          split_docs = text_splitter.split_documents(documents)
-         logger.info(f"Split documents into {len(split_docs)} chunks")
          return split_docs

      def initialize_system(self):
          try:
              self.client = QdrantClient(":memory:")

              try:
-                 self.client.get_collection("pdf_data")
              except Exception:
                  self.client.create_collection(
                      collection_name="pdf_data",
@@ -78,22 +118,32 @@ class QASystem:
                  )
                  logger.info("Created new collection: pdf_data")

              self.embeddings = GoogleGenerativeAIEmbeddings(
                  model="models/embedding-001",
                  google_api_key=GOOGLE_API_KEY
              )

              self.vector_store = QdrantVectorStore(
                  client=self.client,
                  collection_name="pdf_data",
                  embeddings=self.embeddings,
              )

              documents = self.load_pdf_documents()
              if documents:
                  try:
-                     points = self.client.scroll(collection_name="pdf_data", limit=100)[0]
                      if points:
                          self.client.delete(
                              collection_name="pdf_data",
                              points_selector=PointIdsList(
@@ -103,65 +153,104 @@ class QASystem:
                  except Exception as e:
                      logger.error(f"Error clearing vectors: {str(e)}")

                  self.vector_store.add_documents(documents)
-                 logger.info(f"Added {len(documents)} documents to vector store")

              llm = ChatGroq(
                  model="llama3-8b-8192",
                  api_key=GROQ_API_KEY,
                  temperature=0.7
              )

              graph_builder = StateGraph(MessagesState)

-             # Define a retrieval node that fetches relevant docs
              def retrieve_docs(state: MessagesState):
                  # Get the most recent human message
                  human_messages = [m for m in state["messages"] if m.type == "human"]
                  if not human_messages:
                      return {"messages": state["messages"]}

                  user_query = human_messages[-1].content
-                 logger.info(f"Retrieving documents for query: {user_query}")

                  # Query the vector store
                  try:
-                     retrieved_docs = self.vector_store.similarity_search(user_query, k=3)

-                     # Create tool messages for each retrieved document
                      tool_messages = []
                      for i, doc in enumerate(retrieved_docs):
                          tool_messages.append(
                              ToolMessage(
-                                 content=f"Document {i+1}: {doc.page_content}",
                                  tool_call_id=f"retrieval_{i}"
                              )
                          )

-                     logger.info(f"Retrieved {len(tool_messages)} relevant documents")
                      return {"messages": state["messages"] + tool_messages}

                  except Exception as e:
                      logger.error(f"Error retrieving documents: {str(e)}")
                      return {"messages": state["messages"]}

-             # Updated generate function that uses retrieved documents
              def generate(state: MessagesState):
                  # Extract retrieved documents (tool messages)
                  tool_messages = [m for m in state["messages"] if m.type == "tool"]

                  # Collect context from retrieved documents
                  if tool_messages:
-                     context = "\n".join([m.content for m in tool_messages])
                      logger.info(f"Using context from {len(tool_messages)} retrieved documents")
                  else:
-                     context = "No specific mountain bicycle documentation available."
-                     logger.info("No relevant documents retrieved, using default context")

                  system_prompt = (
                      "You are an AI assistant embedded within the Interactive Electronic Technical Manual (IETM) for Mountain Cycles. "
-                     "Always provide accurate responses with references to provided data. "
-                     "If the user query is not technical-specific, still respond from a IETM perspective."
                      f"\n\nContext from mountain bicycle documentation:\n{context}"
                  )

@@ -174,8 +263,14 @@ class QASystem:
                  logger.info(f"Sending query to LLM with {len(messages)} messages")

                  # Generate the response
-                 response = llm.invoke(messages)
-                 return {"messages": state["messages"] + [response]}

              # Add nodes to the graph
              graph_builder.add_node("retrieve_docs", retrieve_docs)
@@ -186,22 +281,35 @@ class QASystem:
              graph_builder.add_edge("retrieve_docs", "generate")
              graph_builder.add_edge("generate", END)

              self.memory = MemorySaver()
              self.graph = graph_builder.compile(checkpointer=self.memory)
              return True

          except Exception as e:
              logger.error(f"System initialization error: {str(e)}")
              return False

-     def process_query(self, query: str) -> Dict[str, str]:
          """Process a query and return a single final response"""
          try:
-             # Generate a unique thread ID for production use
-             # For simplicity, using a fixed ID here
              thread_id = "abc123"

-             # Use invoke instead of stream to get only the final result
              final_state = self.graph.invoke(
                  {"messages": [HumanMessage(content=query)]},
                  config={"configurable": {"thread_id": thread_id}}
@@ -211,31 +319,35 @@ class QASystem:
              ai_messages = [m for m in final_state["messages"] if m.type == "ai"]

              if ai_messages:
                  # Return only the last AI message
                  return {
                      'content': ai_messages[-1].content,
                      'type': ai_messages[-1].type
                  }

              return {
-                 'content': "No response generated",
                  'type': 'error'
              }

          except Exception as e:
              logger.error(f"Query processing error: {str(e)}")
              return {
-                 'content': f"Query processing error: {str(e)}",
                  'type': 'error'
              }

  qa_system = QASystem()
- if qa_system.initialize_system():
-     logger.info("QA System Initialized Successfully")
- else:
-     raise RuntimeError("Failed to initialize QA System")

  @app.post("/query")
  async def query_api(query: str):
      """API endpoint that returns a single response for a query"""
      response = qa_system.process_query(query)
      return {"response": response}
 
Updated version of app.py (added lines marked with "+"):

  from fastapi import FastAPI, HTTPException
  import os
+ from typing import List, Dict, Any
  from dotenv import load_dotenv
  import logging
  from pathlib import Path
 
  from langgraph.prebuilt import tools_condition
  from langgraph.checkpoint.memory import MemorySaver

+ # Configure logging
+ logging.basicConfig(level=logging.INFO,
+                     format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
  logger = logging.getLogger(__name__)

+ # Load environment variables
  load_dotenv()
  GOOGLE_API_KEY = os.getenv('GOOGLE_API_KEY')
  GROQ_API_KEY = os.getenv('GROQ_API_KEY')
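
Aside: a minimal sketch (not part of this commit) of what this basicConfig format string produces per record:

    import logging
    logging.basicConfig(level=logging.INFO,
                        format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    logging.getLogger("app").info("QA System Initialized")
    # 2025-01-01 12:00:00,000 - app - INFO - QA System Initialized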
 
          self.embeddings = None
          self.client = None
          self.pdf_dir = "pdfss"
+         self.is_initialized = False

      def load_pdf_documents(self):
+         """Load and process PDF documents from the pdf directory"""
          documents = []
          pdf_dir = Path(self.pdf_dir)

          if not pdf_dir.exists():
              raise FileNotFoundError(f"PDF directory not found: {self.pdf_dir}")

+         pdf_files = list(pdf_dir.glob("*.pdf"))
+         if not pdf_files:
+             logger.warning(f"No PDF files found in directory: {self.pdf_dir}")
+             return []
+
+         logger.info(f"Found {len(pdf_files)} PDF files to process")
+
+         for pdf_path in pdf_files:
              try:
+                 logger.info(f"Processing PDF: {pdf_path}")
                  loader = PyPDFLoader(str(pdf_path))
+                 pdf_documents = loader.load()
+
+                 # Add source information to metadata
+                 for doc in pdf_documents:
+                     if not hasattr(doc, 'metadata'):
+                         doc.metadata = {}
+                     doc.metadata['source'] = str(pdf_path.name)
+
+                 documents.extend(pdf_documents)
+                 logger.info(f"Loaded PDF: {pdf_path} - {len(pdf_documents)} pages/sections")
              except Exception as e:
                  logger.error(f"Error loading PDF {pdf_path}: {str(e)}")

+         if not documents:
+             logger.warning("No documents were loaded from PDFs. Check the PDF directory and file formats.")
+             return []
+
+         # Split documents into smaller chunks for better retrieval
          text_splitter = RecursiveCharacterTextSplitter(
              chunk_size=1000,
+             chunk_overlap=200
          )
          split_docs = text_splitter.split_documents(documents)
+         logger.info(f"Split {len(documents)} documents into {len(split_docs)} chunks")
+
+         # Verify content of the first few chunks
+         for i, doc in enumerate(split_docs[:3]):
+             if i >= len(split_docs):
+                 break
+             logger.info(f"Sample chunk {i+1} content preview: {doc.page_content[:100]}...")
+
          return split_docs
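
Aside: the bump from chunk_overlap=100 to 200 means consecutive chunks now share up to 200 characters, which helps retrieval when a sentence straddles a chunk boundary. A minimal sketch (not part of this commit; assumes the splitter comes from langchain_text_splitters):

    from langchain_text_splitters import RecursiveCharacterTextSplitter

    splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
    chunks = splitter.split_text("spoke tension " * 250)  # ~3500 characters of toy text
    print(len(chunks))     # a few chunks, each at most 1000 characters
    print(chunks[1][:60])  # starts with text repeated from the end of chunks[0]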

      def initialize_system(self):
+         """Initialize the RAG system with vector store and LLM"""
          try:
+             logger.info("Initializing QA System...")
+
+             # Initialize Qdrant client
              self.client = QdrantClient(":memory:")
+             logger.info("Qdrant client initialized (in-memory)")

+             # Create or get collection
              try:
+                 collection_info = self.client.get_collection("pdf_data")
+                 logger.info(f"Using existing collection: pdf_data")
              except Exception:
                  self.client.create_collection(
                      collection_name="pdf_data",
 
                  )
                  logger.info("Created new collection: pdf_data")

+             # Initialize embeddings model
              self.embeddings = GoogleGenerativeAIEmbeddings(
                  model="models/embedding-001",
                  google_api_key=GOOGLE_API_KEY
              )
+             logger.info("Google AI Embeddings initialized")

+             # Initialize vector store
              self.vector_store = QdrantVectorStore(
                  client=self.client,
                  collection_name="pdf_data",
                  embeddings=self.embeddings,
              )
+             logger.info("Qdrant vector store initialized")

+             # Load documents
              documents = self.load_pdf_documents()
+             if not documents:
+                 logger.warning("No documents loaded. The system will continue but may not provide relevant responses.")
+
+             # Clear existing vectors if any
              if documents:
                  try:
+                     points = self.client.scroll(collection_name="pdf_data", limit=1000)[0]
                      if points:
+                         logger.info(f"Clearing {len(points)} existing vectors from collection")
                          self.client.delete(
                              collection_name="pdf_data",
                              points_selector=PointIdsList(
 
                  except Exception as e:
                      logger.error(f"Error clearing vectors: {str(e)}")

+                 # Add documents to vector store
+                 logger.info(f"Adding {len(documents)} documents to vector store")
                  self.vector_store.add_documents(documents)
+                 logger.info(f"Successfully added documents to vector store")
+
+                 # Verify vector store has documents
+                 try:
+                     count = len(self.client.scroll(collection_name="pdf_data", limit=1)[0])
+                     logger.info(f"Vector store contains points: {count > 0}")
+                 except Exception as e:
+                     logger.error(f"Error verifying vector store: {str(e)}")

+             # Initialize LLM
              llm = ChatGroq(
                  model="llama3-8b-8192",
                  api_key=GROQ_API_KEY,
                  temperature=0.7
              )
+             logger.info("Groq LLM initialized")

+             # Create LangGraph
              graph_builder = StateGraph(MessagesState)
+             logger.info("Creating LangGraph for conversation flow")

+             # Define retrieval node (self reference for vector_store access)
+             vector_store_ref = self.vector_store
+
              def retrieve_docs(state: MessagesState):
+                 """Node that retrieves relevant documents from the vector store"""
                  # Get the most recent human message
                  human_messages = [m for m in state["messages"] if m.type == "human"]
                  if not human_messages:
+                     logger.warning("No human messages found in state")
                      return {"messages": state["messages"]}

                  user_query = human_messages[-1].content
+                 logger.info(f"Retrieving documents for query: '{user_query}'")
+
+                 # Check if vector store exists
+                 if not vector_store_ref:
+                     logger.error("Vector store not initialized or empty")
+                     return {"messages": state["messages"]}

                  # Query the vector store
                  try:
+                     retrieved_docs = vector_store_ref.similarity_search(user_query, k=3)
+
+                     if not retrieved_docs:
+                         logger.warning(f"No documents retrieved for query: '{user_query}'")
+                         return {"messages": state["messages"]}
+
+                     # Log what was actually retrieved
+                     for i, doc in enumerate(retrieved_docs):
+                         source = doc.metadata.get('source', 'Unknown') if hasattr(doc, 'metadata') else 'Unknown'
+                         content_preview = doc.page_content[:100] + "..." if len(doc.page_content) > 100 else doc.page_content
+                         logger.info(f"Retrieved doc {i+1} from {source}, preview: {content_preview}")

+                     # Create tool messages with more detailed content
                      tool_messages = []
                      for i, doc in enumerate(retrieved_docs):
+                         # Include source information if available
+                         source_info = f" (Source: {doc.metadata.get('source', 'Unknown')})" if hasattr(doc, 'metadata') else ""
+
                          tool_messages.append(
                              ToolMessage(
+                                 content=f"Document {i+1}{source_info}: {doc.page_content}",
                                  tool_call_id=f"retrieval_{i}"
                              )
                          )

+                     logger.info(f"Created {len(tool_messages)} tool messages with retrieved content")
                      return {"messages": state["messages"] + tool_messages}

                  except Exception as e:
                      logger.error(f"Error retrieving documents: {str(e)}")
                      return {"messages": state["messages"]}
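
Aside: similarity_search embeds the query with the same embedding model and returns the k nearest chunks, so it can be exercised on its own once the system is initialized. A minimal sketch (not part of this commit):

    docs = qa_system.vector_store.similarity_search("how to bleed hydraulic brakes", k=3)
    for d in docs:
        print(d.metadata.get("source", "?"), "->", d.page_content[:80])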

+             # Generate response using retrieved documents
              def generate(state: MessagesState):
+                 """Node that generates a response using the LLM and retrieved documents"""
                  # Extract retrieved documents (tool messages)
                  tool_messages = [m for m in state["messages"] if m.type == "tool"]

                  # Collect context from retrieved documents
                  if tool_messages:
+                     context = "\n\n".join([m.content for m in tool_messages])
                      logger.info(f"Using context from {len(tool_messages)} retrieved documents")
                  else:
+                     context = "No specific mountain bicycle documentation available for this query."
+                     logger.warning("No relevant documents retrieved, using default context")

                  system_prompt = (
                      "You are an AI assistant embedded within the Interactive Electronic Technical Manual (IETM) for Mountain Cycles. "
+                     "Your primary role is to provide accurate technical information about mountain bicycles. "
+                     "Always base your responses on the provided documentation. "
+                     "If you don't find specific information in the provided context, clearly state that the information "
+                     "is not available in the current documentation instead of making up details. "
+                     "When responding, reference specific parts of the documentation."
                      f"\n\nContext from mountain bicycle documentation:\n{context}"
                  )

                  logger.info(f"Sending query to LLM with {len(messages)} messages")

                  # Generate the response
+                 try:
+                     response = llm.invoke(messages)
+                     logger.info(f"LLM generated response successfully")
+                     return {"messages": state["messages"] + [response]}
+                 except Exception as e:
+                     logger.error(f"Error generating response: {str(e)}")
+                     error_message = SystemMessage(content=f"Error generating response: {str(e)}")
+                     return {"messages": state["messages"] + [error_message]}

              # Add nodes to the graph
              graph_builder.add_node("retrieve_docs", retrieve_docs)

              graph_builder.add_edge("retrieve_docs", "generate")
              graph_builder.add_edge("generate", END)

+             # Initialize memory
              self.memory = MemorySaver()
              self.graph = graph_builder.compile(checkpointer=self.memory)
+             logger.info("Graph compiled successfully")
+
+             self.is_initialized = True
              return True

          except Exception as e:
              logger.error(f"System initialization error: {str(e)}")
+             self.is_initialized = False
              return False

+     def process_query(self, query: str) -> Dict[str, Any]:
          """Process a query and return a single final response"""
          try:
+             if not self.is_initialized:
+                 logger.error("System not initialized. Cannot process query.")
+                 return {
+                     'content': "Error: QA System not initialized properly",
+                     'type': 'error'
+                 }
+
+             logger.info(f"Processing query: '{query}'")
+
+             # Generate a thread ID (use a more sophisticated method for production)
              thread_id = "abc123"
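
Aside: with a fixed thread_id every caller shares one LangGraph checkpoint thread. A minimal sketch (not part of this commit) of a per-request ID so conversations stay separate:

    import uuid
    thread_id = str(uuid.uuid4())  # or reuse a client-supplied session ID per user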

+             # Use invoke to get only the final result
              final_state = self.graph.invoke(
                  {"messages": [HumanMessage(content=query)]},
                  config={"configurable": {"thread_id": thread_id}}

              ai_messages = [m for m in final_state["messages"] if m.type == "ai"]

              if ai_messages:
+                 logger.info("Successfully generated response")
                  # Return only the last AI message
                  return {
                      'content': ai_messages[-1].content,
                      'type': ai_messages[-1].type
                  }
+
+             logger.warning("No AI message generated in response")
              return {
+                 'content': "No response could be generated for your query. Please try a different question.",
                  'type': 'error'
              }

          except Exception as e:
              logger.error(f"Query processing error: {str(e)}")
              return {
+                 'content': f"Error processing your query: {str(e)}",
                  'type': 'error'
              }

+ # Initialize the QA system
  qa_system = QASystem()
+ initialization_success = qa_system.initialize_system()

  @app.post("/query")
  async def query_api(query: str):
      """API endpoint that returns a single response for a query"""
+     if not qa_system.is_initialized:
+         raise HTTPException(status_code=500, detail="QA System not initialized properly")
+
      response = qa_system.process_query(query)
      return {"response": response}
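
For reference, one way to exercise the endpoint once the app is served (a sketch, not part of this commit; assumes uvicorn on localhost:8000, and that FastAPI binds the bare `query: str` argument as a query-string parameter, which is its default for scalar parameters):

    import requests

    resp = requests.post("http://localhost:8000/query",
                         params={"query": "How do I adjust the rear derailleur?"})
    print(resp.json())  # {"response": {"content": "...", "type": "ai"}}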