Update app.py
app.py
CHANGED
@@ -15,7 +15,7 @@ from qdrant_client.http.models import Distance, VectorParams
 from qdrant_client.models import PointIdsList
 
 from langgraph.graph import MessagesState, StateGraph
-from langchain_core.messages import SystemMessage, HumanMessage
+from langchain_core.messages import SystemMessage, HumanMessage, ToolMessage
 from langgraph.prebuilt import ToolNode
 from langgraph.graph import END
 from langgraph.prebuilt import tools_condition
@@ -114,52 +114,76 @@ class QASystem:
 
         graph_builder = StateGraph(MessagesState)
 
+        # Define a retrieval node that fetches relevant docs
+        def retrieve_docs(state: MessagesState):
+            # Get the most recent human message
+            human_messages = [m for m in state["messages"] if m.type == "human"]
+            if not human_messages:
+                return {"messages": state["messages"]}
+
+            user_query = human_messages[-1].content
+            logger.info(f"Retrieving documents for query: {user_query}")
+
+            # Query the vector store
+            try:
+                retrieved_docs = self.vector_store.similarity_search(user_query, k=3)
+
+                # Create tool messages for each retrieved document
+                tool_messages = []
+                for i, doc in enumerate(retrieved_docs):
+                    tool_messages.append(
+                        ToolMessage(
+                            content=f"Document {i+1}: {doc.page_content}",
+                            tool_call_id=f"retrieval_{i}"
+                        )
+                    )
+
+                logger.info(f"Retrieved {len(tool_messages)} relevant documents")
+                return {"messages": state["messages"] + tool_messages}
+
+            except Exception as e:
+                logger.error(f"Error retrieving documents: {str(e)}")
+                return {"messages": state["messages"]}
 
+        # Updated generate function that uses retrieved documents
         def generate(state: MessagesState):
+            # Extract retrieved documents (tool messages)
+            tool_messages = [m for m in state["messages"] if m.type == "tool"]
+
+            # Collect context from retrieved documents
+            if tool_messages:
+                context = "\n".join([m.content for m in tool_messages])
+                logger.info(f"Using context from {len(tool_messages)} retrieved documents")
+            else:
+                context = "No specific mountain bicycle documentation available."
+                logger.info("No relevant documents retrieved, using default context")
 
             system_prompt = (
                 "You are an AI assistant embedded within the Interactive Electronic Technical Manual (IETM) for Mountain Cycles. "
+                "Always provide accurate responses with references to the provided data. "
+                "If the user query is not strictly technical, still respond from an IETM perspective."
+                f"\n\nContext from mountain bicycle documentation:\n{context}"
             )
 
+            # Get all messages excluding tool messages to avoid redundancy
+            human_and_ai_messages = [m for m in state["messages"] if m.type != "tool"]
+
+            # Create the full message history for the LLM
+            messages = [SystemMessage(content=system_prompt)] + human_and_ai_messages
+
+            logger.info(f"Sending query to LLM with {len(messages)} messages")
+
+            # Generate the response
             response = llm.invoke(messages)
-            return {"messages": [response]}
+            return {"messages": state["messages"] + [response]}
 
+        # Add nodes to the graph
+        graph_builder.add_node("retrieve_docs", retrieve_docs)
         graph_builder.add_node("generate", generate)
 
+        # Set the flow of the graph
+        graph_builder.set_entry_point("retrieve_docs")
+        graph_builder.add_edge("retrieve_docs", "generate")
         graph_builder.add_edge("generate", END)
 
         self.memory = MemorySaver()
@@ -173,16 +197,25 @@ class QASystem:
     def process_query(self, query: str) -> List[Dict[str, str]]:
         try:
             responses = []
+
+            # Use a unique thread_id for each conversation
+            thread_id = "abc123"  # In production, generate a unique ID for each conversation
+
+            # Stream the responses
             for step in self.graph.stream(
                 {"messages": [HumanMessage(content=query)]},
                 stream_mode="values",
+                config={"configurable": {"thread_id": thread_id}}
             ):
                 if step["messages"]:
+                    # Only include AI messages in the response
+                    ai_messages = [m for m in step["messages"] if m.type == "ai"]
+                    if ai_messages:
+                        responses.append({
+                            'content': ai_messages[-1].content,
+                            'type': ai_messages[-1].type
+                        })
+
             return responses
         except Exception as e:
             logger.error(f"Query processing error: {str(e)}")
@@ -197,4 +230,4 @@ else:
     @app.post("/query")
     async def query_api(query: str):
         responses = qa_system.process_query(query)
-        return {"responses": responses}
+        return {"responses": responses}
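
Taken together, the new code builds a two-node LangGraph pipeline: retrieve_docs runs a similarity search against the vector store and appends the hits to the state as ToolMessages, and generate folds that context into the system prompt before calling the LLM. The sketch below is a minimal, self-contained illustration of that wiring with the vector store and the LLM replaced by stubs; the docs list, the keyword matching, and the echo-style answer are assumptions for demonstration only, not part of app.py.

from langchain_core.messages import AIMessage, HumanMessage, ToolMessage
from langgraph.checkpoint.memory import MemorySaver
from langgraph.graph import END, MessagesState, StateGraph

# Stand-in corpus; app.py uses self.vector_store.similarity_search(query, k=3)
docs = ["Check tire pressure before every ride.", "Torque the stem bolts to 5 Nm."]

def retrieve_docs(state: MessagesState):
    # Stub retrieval: naive keyword match instead of a Qdrant similarity search
    query = state["messages"][-1].content.lower()
    hits = [d for d in docs if any(word in d.lower() for word in query.split())]
    # Retrieved documents travel through the state as ToolMessages
    return {"messages": [
        ToolMessage(content=f"Document {i+1}: {d}", tool_call_id=f"retrieval_{i}")
        for i, d in enumerate(hits)
    ]}

def generate(state: MessagesState):
    # Stub "LLM": echo the retrieved context instead of calling a real model
    context = "\n".join(m.content for m in state["messages"] if m.type == "tool")
    return {"messages": [AIMessage(content=f"Answer based on:\n{context or 'no context'}")]}

graph_builder = StateGraph(MessagesState)
graph_builder.add_node("retrieve_docs", retrieve_docs)
graph_builder.add_node("generate", generate)
graph_builder.set_entry_point("retrieve_docs")
graph_builder.add_edge("retrieve_docs", "generate")
graph_builder.add_edge("generate", END)
graph = graph_builder.compile(checkpointer=MemorySaver())

for step in graph.stream(
    {"messages": [HumanMessage(content="What tire pressure should I use?")]},
    stream_mode="values",
    config={"configurable": {"thread_id": "demo-thread"}},
):
    if step["messages"][-1].type == "ai":
        print(step["messages"][-1].content)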
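
Because query_api declares query as a plain str parameter, FastAPI reads it from the query string rather than from a JSON body. A hypothetical client call could look like the following; the host and port are placeholders for wherever the Space serves the app, and the question text is illustrative.

import requests

# Placeholder base URL; substitute the actual address of the running Space
BASE_URL = "http://localhost:7860"

resp = requests.post(
    f"{BASE_URL}/query",
    params={"query": "How do I adjust the rear derailleur?"},
)
resp.raise_for_status()
for item in resp.json()["responses"]:
    print(item["type"], item["content"])

Note that process_query currently hard-codes thread_id = "abc123", so every caller shares one conversation thread; as the in-code comment says, a production setup would generate a unique ID (for example uuid.uuid4()) per conversation.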