VishnuRamDebyez committed
Commit d05ce95 · verified · 1 Parent(s): ab11098

Update app.py

Files changed (1)
  1. app.py +75 -42
app.py CHANGED
@@ -15,7 +15,7 @@ from qdrant_client.http.models import Distance, VectorParams
 from qdrant_client.models import PointIdsList
 
 from langgraph.graph import MessagesState, StateGraph
-from langchain_core.messages import SystemMessage, HumanMessage
+from langchain_core.messages import SystemMessage, HumanMessage, ToolMessage
 from langgraph.prebuilt import ToolNode
 from langgraph.graph import END
 from langgraph.prebuilt import tools_condition
@@ -114,52 +114,76 @@ class QASystem:
 
         graph_builder = StateGraph(MessagesState)
 
-        def query_or_respond(state: MessagesState):
-            retrieved_docs = [m for m in state["messages"] if m.type == "tool"]
-
-            if retrieved_docs:
-                context = ' '.join(m.content for m in retrieved_docs)
-            else:
-                context = "mountain bicycle documentation knowledge"
-
-            system_prompt = (
-                "You are an AI assistant embedded within the Interactive Electronic Technical Manual (IETM) for Mountain Cycles.. "
-                "Always provide accurate responses with references to provided data. "
-                "If the user query is not technical-specific, still respond from a IETM perspective."
-                f"\n\nContext:\n{context}"
-            )
-
-            messages = [SystemMessage(content=system_prompt)] + state["messages"]
-
-            logger.info(f"Sending to LLM: {[m.content for m in messages]}")  # Debugging log
-
-            response = llm.invoke(messages)
-            return {"messages": [response]}
+        # Define a retrieval node that fetches relevant docs
+        def retrieve_docs(state: MessagesState):
+            # Get the most recent human message
+            human_messages = [m for m in state["messages"] if m.type == "human"]
+            if not human_messages:
+                return {"messages": state["messages"]}
+
+            user_query = human_messages[-1].content
+            logger.info(f"Retrieving documents for query: {user_query}")
+
+            # Query the vector store
+            try:
+                retrieved_docs = self.vector_store.similarity_search(user_query, k=3)
+
+                # Create tool messages for each retrieved document
+                tool_messages = []
+                for i, doc in enumerate(retrieved_docs):
+                    tool_messages.append(
+                        ToolMessage(
+                            content=f"Document {i+1}: {doc.page_content}",
+                            tool_call_id=f"retrieval_{i}"
+                        )
+                    )
+
+                logger.info(f"Retrieved {len(tool_messages)} relevant documents")
+                return {"messages": state["messages"] + tool_messages}
+
+            except Exception as e:
+                logger.error(f"Error retrieving documents: {str(e)}")
+                return {"messages": state["messages"]}
 
+        # Updated generate function that uses retrieved documents
         def generate(state: MessagesState):
-            retrieved_docs = [m for m in reversed(state["messages"]) if m.type == "tool"][::-1]
-
-            context = ' '.join(m.content for m in retrieved_docs) if retrieved_docs else "mountain bicycle documentation knowledge"
+            # Extract retrieved documents (tool messages)
+            tool_messages = [m for m in state["messages"] if m.type == "tool"]
+
+            # Collect context from retrieved documents
+            if tool_messages:
+                context = "\n".join([m.content for m in tool_messages])
+                logger.info(f"Using context from {len(tool_messages)} retrieved documents")
+            else:
+                context = "No specific mountain bicycle documentation available."
+                logger.info("No relevant documents retrieved, using default context")
 
             system_prompt = (
                 "You are an AI assistant embedded within the Interactive Electronic Technical Manual (IETM) for Mountain Cycles. "
-                "Your responses MUST be accurate, concise (5 sentences max)."
-                f"\n\nContext:\n{context}"
+                "Always provide accurate responses with references to provided data. "
+                "If the user query is not technical-specific, still respond from a IETM perspective."
+                f"\n\nContext from mountain bicycle documentation:\n{context}"
             )
 
-            messages = [SystemMessage(content=system_prompt)] + state["messages"]
-
-            logger.info(f"Sending to LLM: {[m.content for m in messages]}")  # Debugging log
-
+            # Get all messages excluding tool messages to avoid redundancy
+            human_and_ai_messages = [m for m in state["messages"] if m.type != "tool"]
+
+            # Create the full message history for the LLM
+            messages = [SystemMessage(content=system_prompt)] + human_and_ai_messages
+
+            logger.info(f"Sending query to LLM with {len(messages)} messages")
+
+            # Generate the response
             response = llm.invoke(messages)
-            return {"messages": [response]}
-
+            return {"messages": state["messages"] + [response]}
 
-        graph_builder.add_node("query_or_respond", query_or_respond)
+        # Add nodes to the graph
+        graph_builder.add_node("retrieve_docs", retrieve_docs)
         graph_builder.add_node("generate", generate)
 
-        graph_builder.set_entry_point("query_or_respond")
-        graph_builder.add_edge("query_or_respond", "generate")
+        # Set the flow of the graph
+        graph_builder.set_entry_point("retrieve_docs")
+        graph_builder.add_edge("retrieve_docs", "generate")
         graph_builder.add_edge("generate", END)
 
         self.memory = MemorySaver()
@@ -173,16 +197,25 @@ class QASystem:
     def process_query(self, query: str) -> List[Dict[str, str]]:
         try:
             responses = []
+
+            # Use a unique thread_id for each conversation
+            thread_id = "abc123"  # In production, generate a unique ID for each conversation
+
+            # Stream the responses
             for step in self.graph.stream(
                 {"messages": [HumanMessage(content=query)]},
                 stream_mode="values",
-                config={"configurable": {"thread_id": "abc123"}}
+                config={"configurable": {"thread_id": thread_id}}
             ):
                 if step["messages"]:
-                    responses.append({
-                        'content': step["messages"][-1].content,
-                        'type': step["messages"][-1].type
-                    })
+                    # Only include AI messages in the response
+                    ai_messages = [m for m in step["messages"] if m.type == "ai"]
+                    if ai_messages:
+                        responses.append({
+                            'content': ai_messages[-1].content,
+                            'type': ai_messages[-1].type
+                        })
+
             return responses
         except Exception as e:
             logger.error(f"Query processing error: {str(e)}")
@@ -197,4 +230,4 @@ else:
     @app.post("/query")
     async def query_api(query: str):
         responses = qa_system.process_query(query)
-        return {"responses": responses}
+        return {"responses": responses}
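A quick way to sanity-check the updated retrieve-then-generate pipeline end to end is a small smoke test against the /query route. This is only a sketch: the local host, port, and sample question are assumptions, not part of this commit. Because query_api declares a plain `query: str` parameter, FastAPI expects it in the URL query string rather than in a JSON body.

import requests

# Hypothetical smoke test for the /query endpoint; assumes the app is
# being served locally on port 8000 (host, port, and question are
# illustrative only).
response = requests.post(
    "http://localhost:8000/query",
    params={"query": "How do I adjust the rear derailleur?"},
)
response.raise_for_status()

# Each entry mirrors process_query's output: {'content': ..., 'type': 'ai'}
for item in response.json()["responses"]:
    print(f"[{item['type']}] {item['content']}")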
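The inline note next to the hard-coded thread_id ("In production, generate a unique ID for each conversation") could be acted on along these lines. A minimal sketch, assuming a uuid4-based ID is acceptable and that the caller keeps track of some conversation identifier; neither is established by this commit.

import uuid

# Hypothetical helper: mint a fresh LangGraph thread ID per conversation
# instead of the hard-coded "abc123". How a conversation is mapped to
# this ID (and passed into process_query) is assumed, not shown here.
def new_thread_id() -> str:
    return str(uuid.uuid4())

config = {"configurable": {"thread_id": new_thread_id()}}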