VishnuRamDebyez committed
Commit 0b10d8a · verified · 1 Parent(s): 6400d6d

Update app.py

Files changed (1): app.py (+69 -74)
app.py CHANGED
@@ -14,16 +14,14 @@ from qdrant_client import QdrantClient
 from qdrant_client.http.models import Distance, VectorParams
 from qdrant_client.models import PointIdsList
 
-from langgraph.graph import MessagesState, StateGraph
-from langchain_core.messages import SystemMessage, HumanMessage
-from langgraph.prebuilt import ToolNode
-from langgraph.graph import END
-from langgraph.prebuilt import tools_condition
+from langgraph.graph import MessagesState, StateGraph, END
 from langgraph.checkpoint.memory import MemorySaver
+from langchain_core.messages import SystemMessage, HumanMessage
 
 logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger(__name__)
 
+# Load environment variables
 load_dotenv()
 GOOGLE_API_KEY = os.getenv('GOOGLE_API_KEY')
 GROQ_API_KEY = os.getenv('GROQ_API_KEY')
@@ -37,7 +35,7 @@ class QASystem:
     def __init__(self):
         self.vector_store = None
         self.graph = None
-        self.memory = None
+        self.memory = MemorySaver()  # LangGraph memory saver for conversation history
         self.embeddings = None
         self.client = None
         self.pdf_dir = "pdfss"
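Creating the MemorySaver once in __init__ (instead of at the end of initialize_system) means a single checkpointer owns all conversation state for the app's lifetime, and LangGraph partitions that state by the thread_id supplied at call time. A minimal sketch of that mechanism, separate from this app (the echo node and thread ID are made up):

    from langchain_core.messages import AIMessage, HumanMessage
    from langgraph.checkpoint.memory import MemorySaver
    from langgraph.graph import END, MessagesState, StateGraph

    # Trivial node: reply by echoing the latest human message
    def echo(state: MessagesState):
        return {"messages": [AIMessage(content=f"echo: {state['messages'][-1].content}")]}

    builder = StateGraph(MessagesState)
    builder.add_node("echo", echo)
    builder.set_entry_point("echo")
    builder.add_edge("echo", END)
    graph = builder.compile(checkpointer=MemorySaver())

    # Same thread_id -> history accumulates; a different thread_id starts fresh
    cfg = {"configurable": {"thread_id": "user-a"}}
    graph.invoke({"messages": [HumanMessage(content="hello")]}, cfg)
    out = graph.invoke({"messages": [HumanMessage(content="again")]}, cfg)
    print(len(out["messages"]))  # 4: two human turns plus two echo replies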
@@ -45,10 +43,10 @@ class QASystem:
     def load_pdf_documents(self):
         documents = []
         pdf_dir = Path(self.pdf_dir)
 
         if not pdf_dir.exists():
             raise FileNotFoundError(f"PDF directory not found: {self.pdf_dir}")
 
         for pdf_path in pdf_dir.glob("*.pdf"):
             try:
                 loader = PyPDFLoader(str(pdf_path))
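Each PyPDFLoader call yields one Document per page, so the loop above grows `documents` page by page across every PDF in pdfss. A standalone sketch of what it collects (the import path assumes the langchain-community package split; the file name in the comment is illustrative):

    from pathlib import Path
    from langchain_community.document_loaders import PyPDFLoader

    docs = []
    for pdf_path in Path("pdfss").glob("*.pdf"):
        docs.extend(PyPDFLoader(str(pdf_path)).load())  # one Document per page

    # Every Document pairs the page text with source/page metadata
    print(docs[0].metadata)  # e.g. {'source': 'pdfss/manual.pdf', 'page': 0}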
@@ -57,18 +55,16 @@ class QASystem:
             except Exception as e:
                 logger.error(f"Error loading PDF {pdf_path}: {str(e)}")
 
-        text_splitter = RecursiveCharacterTextSplitter(
-            chunk_size=1000,
-            chunk_overlap=100
-        )
+        text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
         split_docs = text_splitter.split_documents(documents)
         logger.info(f"Split documents into {len(split_docs)} chunks")
         return split_docs
 
     def initialize_system(self):
         try:
+            # Qdrant setup
             self.client = QdrantClient(":memory:")
 
             try:
                 self.client.get_collection("pdf_data")
             except Exception:
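Collapsing the splitter construction to one line changes nothing semantically: chunks stay capped at 1000 characters with a 100-character overlap, so text near a boundary lands in both neighbouring chunks. A quick standalone check (the import path assumes the langchain-text-splitters package; the sample string is made up):

    from langchain_text_splitters import RecursiveCharacterTextSplitter

    splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
    text = "Mountain bike maintenance manual. " * 120  # roughly 4000 characters
    chunks = splitter.split_text(text)

    print(len(chunks))      # a few chunks, each <= 1000 characters
    print(chunks[0][-40:])  # tail of chunk 0...
    print(chunks[1][:40])   # ...overlaps the head of chunk 1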
@@ -77,92 +73,88 @@ class QASystem:
                 vectors_config=VectorParams(size=768, distance=Distance.COSINE),
             )
             logger.info("Created new collection: pdf_data")
 
+            # Embeddings and vector store
             self.embeddings = GoogleGenerativeAIEmbeddings(
-                model="models/embedding-001",
-                google_api_key=GOOGLE_API_KEY
+                model="models/embedding-001", google_api_key=GOOGLE_API_KEY
             )
 
             self.vector_store = QdrantVectorStore(
                 client=self.client,
                 collection_name="pdf_data",
                 embeddings=self.embeddings,
             )
 
+            # Load and add documents
             documents = self.load_pdf_documents()
             if documents:
-                try:
-                    points = self.client.scroll(collection_name="pdf_data", limit=100)[0]
-                    if points:
-                        self.client.delete(
-                            collection_name="pdf_data",
-                            points_selector=PointIdsList(
-                                points=[p.id for p in points]
-                            )
-                        )
-                except Exception as e:
-                    logger.error(f"Error clearing vectors: {str(e)}")
-
+                points = self.client.scroll(collection_name="pdf_data", limit=100)[0]
+                if points:
+                    self.client.delete(
+                        collection_name="pdf_data",
+                        points_selector=PointIdsList(points=[p.id for p in points])
+                    )
                 self.vector_store.add_documents(documents)
                 logger.info(f"Added {len(documents)} documents to vector store")
 
+            # LLM setup
             llm = ChatGroq(
                 model="llama3-8b-8192",
                 api_key=GROQ_API_KEY,
                 temperature=0.7
             )
 
+            # Graph building
             graph_builder = StateGraph(MessagesState)
 
-            def query_or_respond(state: MessagesState):
-                retrieved_docs = [m for m in state["messages"] if m.type == "tool"]
-
-                if retrieved_docs:
-                    context = ' '.join(m.content for m in retrieved_docs)
-                else:
-                    context = "mountain bicycle documentation knowledge"
-
-                system_prompt = (
-                    "You are an AI assistant embedded within the Interactive Electronic Technical Manual (IETM) for Mountain Cycles. "
-                    "Always provide accurate responses with references to provided data. "
-                    "If the user query is not technical-specific, still respond from an IETM perspective."
-                    f"\n\nContext:\n{context}"
-                )
-
-                messages = [SystemMessage(content=system_prompt)] + state["messages"]
-                logger.info(f"Sending to LLM: {[m.content for m in messages]}")  # Debugging log
-                response = llm.invoke(messages)
-                return {"messages": [response]}
-
-            def generate(state: MessagesState):
-                retrieved_docs = [m for m in reversed(state["messages"]) if m.type == "tool"][::-1]
-                context = ' '.join(m.content for m in retrieved_docs) if retrieved_docs else "mountain bicycle documentation knowledge"
-
-                system_prompt = (
-                    "You are an AI assistant embedded within the Interactive Electronic Technical Manual (IETM) for Mountain Cycles. "
-                    "Your responses MUST be accurate, concise (5 sentences max)."
-                    f"\n\nContext:\n{context}"
-                )
-
-                messages = [SystemMessage(content=system_prompt)] + state["messages"]
-                logger.info(f"Sending to LLM: {[m.content for m in messages]}")  # Debugging log
-                response = llm.invoke(messages)
-                return {"messages": [response]}
-
-            graph_builder.add_node("query_or_respond", query_or_respond)
-            graph_builder.add_node("generate", generate)
-
-            graph_builder.set_entry_point("query_or_respond")
-            graph_builder.add_edge("query_or_respond", "generate")
+            # === RETRIEVAL NODE for context fetching from Qdrant ===
+            def retrieve_documents(state: MessagesState):
+                query = [m.content for m in state["messages"] if m.type == "human"][-1]
+                results = self.vector_store.similarity_search(query, k=4)
+                context = "\n\n".join(doc.page_content for doc in results)
+                # Tagged name="retrieval" so the generator can pick it out
+                return {"messages": [SystemMessage(content=context, name="retrieval")]}
+
+            # === GENERATOR NODE that uses the thread's chat history ===
+            def generate_response(state: MessagesState):
+                # With a checkpointer, state["messages"] already carries the
+                # accumulated history for the current thread_id
+                all_messages = state["messages"]
+                logger.info(f"History: {[m.content for m in all_messages]}")
+
+                # Extract context from retrieval messages
+                retrieved_docs = [m for m in all_messages if getattr(m, "name", None) == "retrieval"]
+                context = ' '.join(m.content for m in retrieved_docs) if retrieved_docs else "mountain bicycle documentation knowledge"
+
+                # Compose system prompt
+                system_prompt = (
+                    "You are an AI assistant embedded within the Interactive Electronic Technical Manual (IETM) for Mountain Cycles. "
+                    "Your responses MUST be accurate, concise (5 sentences max). "
+                    "If you don't know the answer, say 'I don't know based on available data.'\n\n"
+                    f"Context:\n{context}"
+                )
+
+                final_messages = [SystemMessage(content=system_prompt)] + all_messages
+                response = llm.invoke(final_messages)
+
+                # The checkpointer persists the updated thread state automatically
+                return {"messages": [response]}
+
+            # Add graph nodes
+            graph_builder.add_node("retrieval", retrieve_documents)
+            graph_builder.add_node("generate", generate_response)
+
+            # Graph edges
+            graph_builder.set_entry_point("retrieval")
+            graph_builder.add_edge("retrieval", "generate")
             graph_builder.add_edge("generate", END)
 
-            self.memory = MemorySaver()
+            # Compile graph with memory
             self.graph = graph_builder.compile(checkpointer=self.memory)
             return True
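One caveat the new clearing code inherits: client.scroll() returns a single page (here at most 100 points), so anything beyond the first page survives. A paging variant that drains the collection completely — a sketch using the same qdrant-client calls as above:

    from qdrant_client import QdrantClient
    from qdrant_client.models import PointIdsList

    def clear_collection(client: QdrantClient, name: str, page_size: int = 100) -> None:
        # scroll() returns (points, next_page_offset); loop until a page comes back empty
        while True:
            points, _ = client.scroll(collection_name=name, limit=page_size)
            if not points:
                break
            client.delete(
                collection_name=name,
                points_selector=PointIdsList(points=[p.id for p in points]),
            )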
 
@@ -170,13 +162,14 @@ class QASystem:
             logger.error(f"System initialization error: {str(e)}")
             return False
 
-    def process_query(self, query: str) -> List[Dict[str, str]]:
+    # === Query Processor with Memory ===
+    def process_query(self, query: str, user_id: str) -> List[Dict[str, str]]:
         try:
             responses = []
             for step in self.graph.stream(
                 {"messages": [HumanMessage(content=query)]},
                 stream_mode="values",
-                config={"configurable": {"thread_id": "abc123"}}
+                config={"configurable": {"thread_id": user_id}}  # thread ID for user memory
             ):
                 if step["messages"]:
                     responses.append({
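Threading user_id into the stream config is what turns the checkpointer into per-user memory: each distinct user_id maps to its own thread, so histories never mix. A usage sketch (assumes the API keys and the pdfss directory are in place; the user IDs and questions are illustrative):

    qa = QASystem()
    assert qa.initialize_system()

    # Separate users, separate LangGraph threads
    qa.process_query("How do I adjust the rear derailleur?", user_id="alice")
    qa.process_query("What torque for the stem bolts?", user_id="bob")

    # A follow-up from alice is answered against alice's history only
    responses = qa.process_query("And the front one?", user_id="alice")
    print(responses[-1]['content'])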
@@ -188,13 +181,15 @@
             logger.error(f"Query processing error: {str(e)}")
             return [{'content': f"Query processing error: {str(e)}", 'type': 'error'}]
 
+# === Initialize QA System ===
 qa_system = QASystem()
 if qa_system.initialize_system():
     logger.info("QA System Initialized Successfully")
 else:
     raise RuntimeError("Failed to initialize QA System")
 
+# === FastAPI Route ===
 @app.post("/query")
-async def query_api(query: str):
-    responses = qa_system.process_query(query)
+async def query_api(query: str, user_id: str):  # Pass user_id for session-specific memory
+    responses = qa_system.process_query(query, user_id)
     return {"responses": responses}
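Because query and user_id are declared as plain function parameters, FastAPI reads both from the query string even though the route is a POST. A client-side sketch (assumes the app is served with uvicorn on localhost:8000):

    import requests

    resp = requests.post(
        "http://localhost:8000/query",
        params={"query": "How often should I lube the chain?", "user_id": "alice"},
    )
    print(resp.json()["responses"][-1]["content"])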
 