Commit 6a04a88 (parent 28e1a34): Langgraph-Agent V0.5

Files changed:
- .gitignore         +5 -1
- agent.py           +11 -19
- fetch.py           +1 -1
- langgraph_agent.py +200 -0
- main.py            +12 -6
.gitignore CHANGED

@@ -2,4 +2,8 @@ agent_advance.py
 *.ipynb
 __pycache__/
 hackathon-healthcare-solutions-9e6f46d0a21e.json
-venv/
+venv/
+.vscode/
+.env
+test.py
+app.py
agent.py CHANGED

@@ -3,7 +3,6 @@ from dotenv import load_dotenv
 from langchain_community.document_loaders import TextLoader, DirectoryLoader, UnstructuredPDFLoader, UnstructuredWordDocumentLoader
 from langchain_google_genai import ChatGoogleGenerativeAI
 from langchain.text_splitter import RecursiveCharacterTextSplitter
-from langchain.chains import RetrievalQA
 from langchain.prompts import PromptTemplate
 import json
 from google.oauth2 import service_account

@@ -14,36 +13,31 @@ GEMINI_API_KEY = os.getenv("GEMINI_API_KEY")
 if GEMINI_API_KEY is None:
     GEMINI_API_KEY = os.environ.get("GEMINI_API_KEY")
 
-
+prod = os.environ.get("PROD")
+if prod == "true":
+    conf = os.environ.get('GOOGLE_APPLICATION_CREDENTIALS')
 
-service_account_info = json.loads(conf)
-service_account_info = eval(service_account_info)
-
-credentials = service_account.Credentials.from_service_account_info(service_account_info)
+    service_account_info = json.loads(conf)
+    service_account_info = eval(service_account_info)
 
+    credentials = service_account.Credentials.from_service_account_info(service_account_info)
+    llm = ChatGoogleGenerativeAI(model="gemini-2.0-flash", GEMINI_API_KEY=GEMINI_API_KEY, temperature=0.7, credentials=credentials)
+else:
+    # Initialize the language model with your API key
+    llm = ChatGoogleGenerativeAI(model="gemini-2.0-flash", GEMINI_API_KEY=GEMINI_API_KEY, temperature=0.7)
 DOCUMENT_DIR = 'document/'
 COLLECTION_NAME = "health_documents"
 
-llm = ChatGoogleGenerativeAI(model="gemini-2.0-flash", GEMINI_API_KEY=GEMINI_API_KEY, temperature=0.7, credentials=credentials)
 
 
 print("Models initialized successfully.")
 
-import os
-from dotenv import load_dotenv
-from langchain_community.llms import HuggingFacePipeline
 from langchain_huggingface import HuggingFaceEmbeddings
-from langchain_community.document_loaders import TextLoader, DirectoryLoader, UnstructuredPDFLoader, UnstructuredWordDocumentLoader
-from langchain.text_splitter import RecursiveCharacterTextSplitter
-from langchain.chains import RetrievalQA
-from langchain.prompts import PromptTemplate
-from langchain_community.vectorstores.utils import filter_complex_metadata
 from langchain_community.vectorstores import Chroma
 import torch
 from constants import CHROMA_PATH
 
 # Load environment variables (if needed)
-load_dotenv()
 
 # Define the directory containing the documents
 DOCUMENT_DIR = 'document/'

@@ -191,7 +185,6 @@ def create_health_agent(vector_store):
             docs = retriever.get_relevant_documents(query)
             context = "\n".join([doc.page_content for doc in docs])
 
-            # Prepare inputs for the LLM chain
             llm_inputs = {
                 'context': context,
                 'question': query,

@@ -203,12 +196,11 @@ def create_health_agent(vector_store):
             result = self.llm_chain(llm_inputs)
             return {'result': result['text']}
 
-    # Create the LLM chain
     llm_chain = LLMChain(llm=llm, prompt=PROMPT)
 
-    # Create and return the custom chain
     return CustomRetrievalQA(retriever=retriever, llm_chain=llm_chain, user_data=None)
 
+# from langgraph_agent import initialize_health_agent
 
 def agent_with_db():
     # 1. Load documents
fetch.py CHANGED

@@ -5,7 +5,7 @@ import json
 Contains the example code to retrieve response from the server in python-requests"""
 
 ## without previous_state
-url = "https://arpit-bansal-healthbridge.hf.space/"
+url = "https://arpit-bansal-healthbridge.hf.space/retrieve"
 headers = {
     "accept": "application/json",
     "Content-Type": "application/json"
langgraph_agent.py ADDED

@@ -0,0 +1,200 @@
+# Import LangGraph components
+from langgraph.graph import StateGraph, END
+from typing import TypedDict, List, Dict, Any, Annotated, Union, Literal
+import operator
+from pydantic import BaseModel, Field
+from agent import LLMChain, PromptTemplate, llm, DOCUMENT_DIR, load_documents, split_documents, CHROMA_PATH, load_vectordb, create_and_store_embeddings
+import os
+
+# Define state schema
+class AgentState(TypedDict):
+    query: str
+    previous_conversation: str
+    user_data: Dict[str, Any]
+    requires_rag: bool
+    context: List[str]
+    response: str
+
+# Define tools and nodes for the LangGraph
+
+def query_classifier(state: AgentState) -> AgentState:
+    """Determine if the query requires RAG retrieval."""
+    query_lower = state["query"].lower()
+    rag_keywords = [
+        "scheme", "schemes", "program", "programs", "policy", "policies",
+        "public health engineering", "phe", "public health", "government",
+        "benefit", "financial", "assistance", "aid", "initiative"
+    ]
+
+    state["requires_rag"] = any(keyword in query_lower for keyword in rag_keywords)
+    return state
+
+def retrieve_documents(state: AgentState) -> AgentState:
+    """Retrieve documents from the vector store if needed."""
+    if state["requires_rag"]:
+        # Uses the global vector_store variable set in agent_with_db();
+        # this assumes vector_store is accessible in this scope.
+        docs = vector_store.as_retriever(search_kwargs={"k": 5}).get_relevant_documents(state["query"])
+        state["context"] = [doc.page_content for doc in docs]
+    else:
+        state["context"] = []
+    return state
+
+def generate_response(state: AgentState) -> AgentState:
+    """Generate a response with or without retrieved context."""
+    # style = state["user_data"].get("style", "normal") if isinstance(state["user_data"], dict) else "normal"
+
+    base_prompt = """You are a helpful health assistant who talks to the user like a human and resolves their queries.
+
+    Use Previous_Conversation to maintain consistency in the conversation.
+    These are the previous conversations between you and the user.
+    Previous_Conversation: {previous_conversation}
+
+    This is info about the person.
+    User_Data: {user_data}
+
+    Points to adhere to:
+    1. Only mention schemes if the user specifically asked; otherwise don't share scheme information.
+    2. If the user asks about schemes, first ask which state they belong to.
+    3. You can act as a mental-health counselor if needed.
+    4. Give precautions and natural remedies for a disease if the user asked or it's needed, but only for common diseases such as the common cold, flu, etc.
+    5. Ask for the user's preferred language at the start of the conversation.
+    6. Give the answer in a friendly and conversational tone.
+    7. Answer in a {style} style.
+    Question: {question}
+    """
+
+    if state["requires_rag"] and state["context"]:
+        # Add context to the prompt if we're using RAG
+        context = "\n".join(state["context"])
+        prompt_template = base_prompt + "\nContext from knowledge base:\n{context}\n\nAnswer:"
+        prompt = PromptTemplate(
+            template=prompt_template,
+            input_variables=["context", "question", "previous_conversation", "user_data", "style"]
+        )
+
+        llm_chain = LLMChain(llm=llm, prompt=prompt)
+        result = llm_chain({
+            'context': context,
+            'question': state["query"],
+            'previous_conversation': state["previous_conversation"],
+            'user_data': state["user_data"],
+            'style': state["user_data"].get("style", "normal")
+        })
+    else:
+        # Answer directly without context
+        prompt_template = base_prompt + "\nAnswer:"
+        prompt = PromptTemplate(
+            template=prompt_template,
+            input_variables=["question", "previous_conversation", "user_data", "style"]
+        )
+
+        llm_chain = LLMChain(llm=llm, prompt=prompt)
+        result = llm_chain({
+            'question': state["query"],
+            'previous_conversation': state["previous_conversation"],
+            'user_data': state["user_data"],
+            'style': state["user_data"].get("style", "normal")
+        })
+
+    state["response"] = result["text"]
+    return state
+
+def create_agent_workflow():
+    """Create the LangGraph workflow for the health agent."""
+    # Initialize the state graph
+    workflow = StateGraph(AgentState)
+
+    # Add nodes
+    workflow.add_node("classifier", query_classifier)
+    workflow.add_node("retriever", retrieve_documents)
+    workflow.add_node("responder", generate_response)
+
+    # Create edges
+    workflow.add_edge("classifier", "retriever")
+    workflow.add_edge("retriever", "responder")
+    workflow.add_edge("responder", END)
+
+    # Set the entry point
+    workflow.set_entry_point("classifier")
+
+    # Compile the graph
+    return workflow.compile()
+
+def agent_with_db():
+    # Load or create the vector store
+    global vector_store
+    vector_store = None
+    try:
+        vector_store = load_vectordb(CHROMA_PATH)
+    except ValueError:
+        pass
+
+    UPDATE_DB = os.getenv("UPDATE_DB", "false")
+    if UPDATE_DB.lower() == "true" or vector_store is None:
+        print("Loading and processing documents...")
+        documents = load_documents(DOCUMENT_DIR)
+        chunks = split_documents(documents)
+
+        try:
+            vector_store = create_and_store_embeddings(chunks)
+        except Exception as e:
+            print(f"Error creating embeddings: {e}")
+            return None
+
+    print("Creating the LangGraph health agent workflow...")
+    agent_workflow = create_agent_workflow()
+
+    class HealthAgent:
+        def __init__(self, workflow):
+            self.workflow = workflow
+            self.conversation_history = ""
+
+        def __call__(self, input_data):
+            # Handle both dictionary input and direct arguments
+            if isinstance(input_data, dict):
+                query = input_data.get("query", "")
+                previous_conversation = input_data.get("previous_conversation", "")
+                user_data = input_data.get("user_data", {})
+                style = input_data.get("style", "normal")
+            else:
+                # Assume it's a direct query string
+                query = input_data
+                previous_conversation = ""
+                user_data = {}
+                style = "normal"
+
+            # Store the previous conversation if provided
+            if previous_conversation:
+                self.conversation_history = previous_conversation
+
+            # Update conversation history
+            if self.conversation_history:
+                self.conversation_history += f"\nUser: {query}\n"
+            else:
+                self.conversation_history = f"User: {query}\n"
+
+            if "style" not in user_data:
+                user_data["style"] = style
+            # Prepare the initial state
+            initial_state = {
+                "query": query,
+                "previous_conversation": self.conversation_history,
+                "user_data": user_data,
+                "requires_rag": False,
+                "context": [],
+                "response": "",
+                # "style": style
+            }
+            print("Initial state:", initial_state)
+
+            # Run the workflow
+            final_state = self.workflow.invoke(initial_state)
+            print("Final state:", final_state)
+
+            # Update conversation history with the response
+            self.conversation_history += f"Assistant: {final_state['response']}\n"
+
+            # Return in the expected format
+            return {"result": final_state["response"]}
+
+    return HealthAgent(agent_workflow)
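
A minimal usage sketch for the new module, assuming a Chroma store already exists at CHROMA_PATH (otherwise agent_with_db() rebuilds it from DOCUMENT_DIR):

# Sketch: driving the LangGraph agent directly.
from langgraph_agent import agent_with_db

agent = agent_with_db()

# Dict input: keys mirror what HealthAgent.__call__ expects.
reply = agent({
    "query": "Are there any government schemes for maternal health?",
    "previous_conversation": "",
    "user_data": {"state": "Bihar"},
    "style": "normal",
})
print(reply["result"])

# A plain string also works; __call__ fills in the defaults.
print(agent("What home remedies help with a common cold?")["result"])

Note that the graph is linear (classifier -> retriever -> responder), so every query passes through the retriever node, which simply sets an empty context when requires_rag is False.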
main.py CHANGED

@@ -1,8 +1,9 @@
-from fastapi import FastAPI
+from fastapi import FastAPI, Request
 from fastapi.middleware.cors import CORSMiddleware
 import os
 from fastapi import HTTPException
-from agent import agent_with_db
+# from agent import agent_with_db
+from langgraph_agent import agent_with_db
 from schemas import request
 from dotenv import load_dotenv
 load_dotenv()

@@ -44,7 +45,7 @@ async def parse_user_data(user_data):
     return user_info
 
 @app.post("/retrieve", status_code=200)
-async def retrieve(request:request):
+async def retrieve(request:request, url:Request):
     try:
         prev_conv = request.previous_state
         user_info = await parse_user_data(request.user_data)

@@ -53,9 +54,14 @@ async def retrieve(request:request):
         prev_conv = "No previous conversation available, first time"
         query = request.query
         prev_conv = str(prev_conv)
-        user_info = str(user_info)
-
-
+        # user_info = str(user_info)  # Was needed by the old RAG agent; not needed with the LangGraph RAG agent.
+        # Using string formatting for the old RAG agent was a mistake.
+        response = agent({"query": query, "previous_conversation": prev_conv, "user_data": user_info, "style": request.user_data["style"]})
+        origin = url.headers.get('origin')
+        if origin is None:
+            origin = url.headers.get('referer')
+        print("origin: ", origin)
+        print("response: ", response)
         return {"response": response["result"]}
 
     except Exception as e: