Arpit-Bansal committed on
Commit 6a04a88 · 1 Parent(s): 28e1a34

Langgraph-Agent V0.5

Files changed (5)
  1. .gitignore +5 -1
  2. agent.py +11 -19
  3. fetch.py +1 -1
  4. langgraph_agent.py +200 -0
  5. main.py +12 -6
.gitignore CHANGED
@@ -2,4 +2,8 @@ agent_advance.py
 *.ipynb
 __pycache__/
 hackathon-healthcare-solutions-9e6f46d0a21e.json
-venv/
+venv/
+.vscode/
+.env
+test.py
+app.py
agent.py CHANGED
@@ -3,7 +3,6 @@ from dotenv import load_dotenv
 from langchain_community.document_loaders import TextLoader, DirectoryLoader, UnstructuredPDFLoader, UnstructuredWordDocumentLoader
 from langchain_google_genai import ChatGoogleGenerativeAI
 from langchain.text_splitter import RecursiveCharacterTextSplitter
-from langchain.chains import RetrievalQA
 from langchain.prompts import PromptTemplate
 import json
 from google.oauth2 import service_account
@@ -14,36 +13,31 @@ GEMINI_API_KEY = os.getenv("GEMINI_API_KEY")
 if GEMINI_API_KEY is None:
     GEMINI_API_KEY = os.environ.get("GEMINI_API_KEY")
 
-conf = os.environ.get('GOOGLE_APPLICATION_CREDENTIALS')
-
-service_account_info = json.loads(conf)
-service_account_info = eval(service_account_info)
-
-credentials = service_account.Credentials.from_service_account_info(service_account_info)
-
+prod = os.environ.get("PROD")
+if prod == "true":
+    conf = os.environ.get('GOOGLE_APPLICATION_CREDENTIALS')
+
+    service_account_info = json.loads(conf)
+    service_account_info = eval(service_account_info)
+
+    credentials = service_account.Credentials.from_service_account_info(service_account_info)
+    llm = ChatGoogleGenerativeAI(model="gemini-2.0-flash", GEMINI_API_KEY=GEMINI_API_KEY, temperature=0.7, credentials=credentials)
+else:
+    # Initialize the language model with your API key
+    llm = ChatGoogleGenerativeAI(model="gemini-2.0-flash", GEMINI_API_KEY=GEMINI_API_KEY, temperature=0.7)
 DOCUMENT_DIR = 'document/'
 COLLECTION_NAME = "health_documents"
 
-llm = ChatGoogleGenerativeAI(model="gemini-2.0-flash", GEMINI_API_KEY=GEMINI_API_KEY, temperature=0.7, credentials=credentials)
 
 
 print("Models initialized successfully.")
 
-import os
-from dotenv import load_dotenv
-from langchain_community.llms import HuggingFacePipeline
 from langchain_huggingface import HuggingFaceEmbeddings
-from langchain_community.document_loaders import TextLoader, DirectoryLoader, UnstructuredPDFLoader, UnstructuredWordDocumentLoader
-from langchain.text_splitter import RecursiveCharacterTextSplitter
-from langchain.chains import RetrievalQA
-from langchain.prompts import PromptTemplate
-from langchain_community.vectorstores.utils import filter_complex_metadata
 from langchain_community.vectorstores import Chroma
 import torch
 from constants import CHROMA_PATH
 
 # Load environment variables (if needed)
-load_dotenv()
 
 # Define the directory containing the documents
 DOCUMENT_DIR = 'document/'
@@ -191,7 +185,6 @@ def create_health_agent(vector_store):
         docs = retriever.get_relevant_documents(query)
         context = "\n".join([doc.page_content for doc in docs])
 
-        # Prepare inputs for the LLM chain
         llm_inputs = {
             'context': context,
             'question': query,
@@ -203,12 +196,11 @@
         result = self.llm_chain(llm_inputs)
         return {'result': result['text']}
 
-    # Create the LLM chain
    llm_chain = LLMChain(llm=llm, prompt=PROMPT)
 
-    # Create and return the custom chain
    return CustomRetrievalQA(retriever=retriever, llm_chain=llm_chain, user_data=None)
 
+# from langgraph_agent import initialize_health_agent
 
 def agent_with_db():
     # 1. Load documents
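
Note on the production branch in the agent.py hunk above: json.loads(conf) returning a string that is then passed to eval implies GOOGLE_APPLICATION_CREDENTIALS holds a doubly-encoded dict literal. A minimal sketch under that same assumption would swap eval for ast.literal_eval, which evaluates only Python literals and cannot execute arbitrary code; this is an editor suggestion, not the committed code:

import ast
import json
import os
from google.oauth2 import service_account

# Hedged sketch: same double-decode bootstrap as the commit, assuming the
# env var stores a JSON-encoded string whose payload is a dict literal.
conf = os.environ['GOOGLE_APPLICATION_CREDENTIALS']
service_account_info = json.loads(conf)                        # outer JSON layer -> str
service_account_info = ast.literal_eval(service_account_info)  # dict literal -> dict, safely
credentials = service_account.Credentials.from_service_account_info(service_account_info)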
fetch.py CHANGED
@@ -5,7 +5,7 @@ import json
 Contains the example code to retrieve response from the server in python-requests"""
 
 ## without previous_state
-url = "https://arpit-bansal-healthbridge.hf.space/"
+url = "https://arpit-bansal-healthbridge.hf.space/retrieve"
 headers = {
     "accept": "application/json",
     "Content-Type": "application/json"
langgraph_agent.py ADDED
@@ -0,0 +1,200 @@
+# Import LangGraph components
+from langgraph.graph import StateGraph, END
+from typing import TypedDict, List, Dict, Any, Annotated, Union, Literal
+import operator
+from pydantic import BaseModel, Field
+from agent import LLMChain, PromptTemplate, llm, DOCUMENT_DIR, load_documents, split_documents, CHROMA_PATH, load_vectordb, create_and_store_embeddings
+import os
+
+# Define state schema
+class AgentState(TypedDict):
+    query: str
+    previous_conversation: str
+    user_data: Dict[str, Any]
+    requires_rag: bool
+    context: List[str]
+    response: str
+
+# Define tools and nodes for the LangGraph
+
+def query_classifier(state: AgentState) -> AgentState:
+    """Determine if the query requires RAG retrieval."""
+    query_lower = state["query"].lower()
+    rag_keywords = [
+        "scheme", "schemes", "program", "programs", "policy", "policies",
+        "public health engineering", "phe", "public health", "government",
+        "benefit", "financial", "assistance", "aid", "initiative"
+    ]
+
+    state["requires_rag"] = any(keyword in query_lower for keyword in rag_keywords)
+    return state
+
+def retrieve_documents(state: AgentState) -> AgentState:
+    """Retrieve documents from vector store if needed."""
+    if state["requires_rag"]:
+        # Get the global vector_store variable
+        # This assumes vector_store is accessible in this scope
+        docs = vector_store.as_retriever(search_kwargs={"k": 5}).get_relevant_documents(state["query"])
+        state["context"] = [doc.page_content for doc in docs]
+    else:
+        state["context"] = []
+    return state
+
+def generate_response(state: AgentState) -> AgentState:
+    """Generate response with or without context."""
+    # style = state["user_data"].get("style", "normal") if isinstance(state["user_data"], dict) else "normal"
+
+    base_prompt = """You are a helpful health assistant who will talk to the user like a human and resolve their queries.
+
+    Use Previous_Conversation to maintain consistency in the conversation.
+    This is the Previous_Conversation between you and the user.
+    Previous_Conversation: {previous_conversation}
+
+    This is info about the person.
+    User_Data: {user_data}
+
+    Points to Adhere:
+    1. Only mention schemes if the user specifically asks; otherwise don't share scheme information.
+    2. If the user asks about schemes, ask which state they belong to first.
+    3. You can act as a mental-health counselor if needed.
+    4. Give precautions and natural remedies for diseases, if the user asks or it's needed, but only for common diseases such as the common cold, flu, etc.
+    5. Ask the user's preferred language at the start of the conversation.
+    6. Give the answer in a friendly and conversational tone.
+    7. Answer in a {style} style.
+    Question: {question}
+    """
+
+    if state["requires_rag"] and state["context"]:
+        # Add context to prompt if we're using RAG
+        context = "\n".join(state["context"])
+        prompt_template = base_prompt + "\nContext from knowledge base:\n{context}\n\nAnswer:"
+        prompt = PromptTemplate(
+            template=prompt_template,
+            input_variables=["context", "question", "previous_conversation", "user_data", "style"]
+        )
+
+        llm_chain = LLMChain(llm=llm, prompt=prompt)
+        result = llm_chain({
+            'context': context,
+            'question': state["query"],
+            'previous_conversation': state["previous_conversation"],
+            'user_data': state["user_data"],
+            'style': state["user_data"].get("style", "normal")
+        })
+    else:
+        # Answer directly without context
+        prompt_template = base_prompt + "\nAnswer:"
+        prompt = PromptTemplate(
+            template=prompt_template,
+            input_variables=["question", "previous_conversation", "user_data", "style"]
+        )
+
+        llm_chain = LLMChain(llm=llm, prompt=prompt)
+        result = llm_chain({
+            'question': state["query"],
+            'previous_conversation': state["previous_conversation"],
+            'user_data': state["user_data"],
+            'style': state["user_data"].get("style", "normal")
+        })
+
+    state["response"] = result["text"]
+    return state
+
+def create_agent_workflow():
+    """Create the LangGraph workflow for the health agent."""
+    # Initialize the state graph
+    workflow = StateGraph(AgentState)
+
+    # Add nodes
+    workflow.add_node("classifier", query_classifier)
+    workflow.add_node("retriever", retrieve_documents)
+    workflow.add_node("responder", generate_response)
+
+    # Create edges
+    workflow.add_edge("classifier", "retriever")
+    workflow.add_edge("retriever", "responder")
+    workflow.add_edge("responder", END)
+
+    # Set the entry point
+    workflow.set_entry_point("classifier")
+
+    # Compile the graph
+    return workflow.compile()
+
+def agent_with_db():
+    # Load or create vector store
+    global vector_store
+    vector_store = None
+    try:
+        vector_store = load_vectordb(CHROMA_PATH)
+    except ValueError:
+        pass
+
+    UPDATE_DB = os.getenv("UPDATE_DB", "false")
+    if UPDATE_DB.lower() == "true" or vector_store is None:
+        print("Loading and processing documents...")
+        documents = load_documents(DOCUMENT_DIR)
+        chunks = split_documents(documents)
+
+        try:
+            vector_store = create_and_store_embeddings(chunks)
+        except Exception as e:
+            print(f"Error creating embeddings: {e}")
+            return None
+
+    print("Creating the LangGraph health agent workflow...")
+    agent_workflow = create_agent_workflow()
+
+    class HealthAgent:
+        def __init__(self, workflow):
+            self.workflow = workflow
+            self.conversation_history = ""
+
+        def __call__(self, input_data):
+            # Handle both dictionary input and direct arguments
+            if isinstance(input_data, dict):
+                query = input_data.get("query", "")
+                previous_conversation = input_data.get("previous_conversation", "")
+                user_data = input_data.get("user_data", {})
+                style = input_data.get("style", "normal")
+            else:
+                # Assume it's a direct query string
+                query = input_data
+                previous_conversation = ""
+                user_data = {}
+                style = "normal"
+
+            # Store previous conversation if provided
+            if previous_conversation:
+                self.conversation_history = previous_conversation
+
+            # Update conversation history
+            if self.conversation_history:
+                self.conversation_history += f"\nUser: {query}\n"
+            else:
+                self.conversation_history = f"User: {query}\n"
+
+            if "style" not in user_data:
+                user_data["style"] = style
+            # Prepare initial state
+            initial_state = {
+                "query": query,
+                "previous_conversation": self.conversation_history,
+                "user_data": user_data,
+                "requires_rag": False,
+                "context": [],
+                "response": "",
+                # "style": style
+            }
+            print("Initial state:", initial_state)
+
+            # Run the workflow
+            final_state = self.workflow.invoke(initial_state)
+            print("Final state:", final_state)
+
+            # Update conversation history with response
+            self.conversation_history += f"Assistant: {final_state['response']}\n"
+
+            # Return in the expected format
+            return {"result": final_state["response"]}
+
+    return HealthAgent(agent_workflow)
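
A minimal local smoke test for the new module, assuming the Chroma store under CHROMA_PATH and a valid GEMINI_API_KEY are in place; the returned agent is callable and yields {"result": ...}, matching what main.py expects:

from langgraph_agent import agent_with_db

agent = agent_with_db()
if agent is not None:  # agent_with_db returns None if embedding creation fails
    out = agent({
        "query": "Which schemes help with maternal care?",
        "previous_conversation": "",
        "user_data": {"age": 30},
        "style": "normal",
    })
    print(out["result"])

One caveat worth noting: retrieve_documents reads a module-level vector_store that is only assigned inside agent_with_db, so invoking the compiled workflow without going through agent_with_db first would raise a NameError on RAG-classified queries.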
main.py CHANGED
@@ -1,8 +1,9 @@
-from fastapi import FastAPI
+from fastapi import FastAPI, Request
 from fastapi.middleware.cors import CORSMiddleware
 import os
 from fastapi import HTTPException
-from agent import agent_with_db
+# from agent import agent_with_db
+from langgraph_agent import agent_with_db
 from schemas import request
 from dotenv import load_dotenv
 load_dotenv()
@@ -44,7 +45,7 @@ async def parse_user_data(user_data):
     return user_info
 
 @app.post("/retrieve", status_code=200)
-async def retrieve(request: request):
+async def retrieve(request: request, url: Request):
     try:
         prev_conv = request.previous_state
         user_info = await parse_user_data(request.user_data)
@@ -53,9 +54,14 @@ async def retrieve(request: request):
         prev_conv = "No previous conversation available, first time"
         query = request.query
         prev_conv = str(prev_conv)
-        user_info = str(user_info)
-        response = agent({"query": query, "previous_conversation": prev_conv, "user_data": user_info})
-
+        # user_info = str(user_info)  # Was needed by the old RAG agent; not needed with the LangGraph RAG.
+        # Passing user_info as a string was a mistake in the old RAG setup.
+        response = agent({"query": query, "previous_conversation": prev_conv, "user_data": user_info, "style": request.user_data["style"]})
+        origin = url.headers.get('origin')
+        if origin is None:
+            origin = url.headers.get('referer')
+        print("origin: ", origin)
+        print("response: ", response)
         return {"response": response["result"]}
 
     except Exception as e:
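
On the new url: Request parameter in /retrieve: FastAPI injects the raw request object for any parameter annotated with Request, which is what exposes the origin/referer headers here. A self-contained sketch of the same pattern (the route name is illustrative, not from this repo):

from fastapi import FastAPI, Request

app = FastAPI()

@app.post("/echo-origin")
async def echo_origin(raw: Request):
    # Prefer Origin, fall back to Referer, as /retrieve does above.
    origin = raw.headers.get("origin") or raw.headers.get("referer")
    return {"origin": origin}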