Adiii143 commited on
Commit
16c2f38
·
verified ·
1 Parent(s): 081aad7
Files changed (1) hide show
  1. agent.py +78 -24
agent.py CHANGED
@@ -1,25 +1,74 @@
 
1
  import os
2
  from dotenv import load_dotenv
3
  from langgraph.graph import START, StateGraph, MessagesState
4
  from langgraph.prebuilt import tools_condition
5
  from langgraph.prebuilt import ToolNode
6
  from langchain_google_genai import ChatGoogleGenerativeAI
7
- from langchain_nvidia_ai_endpoints import ChatNVIDIA
8
-
9
  from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint, HuggingFaceEmbeddings
10
  from langchain_community.tools.tavily_search import TavilySearchResults
11
  from langchain_community.document_loaders import WikipediaLoader
12
  from langchain_community.document_loaders import ArxivLoader
 
13
  from langchain_core.messages import SystemMessage, HumanMessage
14
  from langchain_core.tools import tool
15
  from langchain.tools.retriever import create_retriever_tool
16
- from langchain_chroma import Chroma
17
- from langchain_core.messages import SystemMessage, HumanMessage
18
- from langchain.tools.retriever import create_retriever_tool
19
-
20
 
21
  load_dotenv()
22
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23
  @tool
24
  def wiki_search(query: str) -> str:
25
  """Search Wikipedia for a query and return maximum 2 results.
@@ -64,38 +113,43 @@ def arvix_search(query: str) -> str:
64
 
65
 
66
 
67
-
68
  # load the system prompt from the file
69
  with open("system_prompt.txt", "r", encoding="utf-8") as f:
70
  system_prompt = f.read()
71
 
 
72
  sys_msg = SystemMessage(content=system_prompt)
73
 
 
74
  embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2") # dim=768
75
-
76
- vector_store = Chroma(
77
- collection_name="example_collection",
78
- embedding_function=embeddings,
79
- persist_directory="./chroma_langchain_db_1", # Where to save data locally, remove if not necessary
 
 
 
80
  )
81
-
82
-
83
- # Assign the result to a new variable name, like 'question_retriever_tool'
84
- question_retriever_tool = create_retriever_tool(
85
  retriever=vector_store.as_retriever(),
86
- name="question_search", # Changed name to be valid
87
  description="A tool to retrieve similar questions from a vector store.",
88
  )
89
 
 
 
90
  tools = [
 
 
 
 
 
91
  wiki_search,
92
  web_search,
93
  arvix_search,
94
- question_retriever_tool,
95
  ]
96
 
97
-
98
-
99
  # Build graph function
100
  def build_graph(provider: str = "groq"):
101
  """Build the graph"""
@@ -103,9 +157,9 @@ def build_graph(provider: str = "groq"):
103
  if provider == "google":
104
  # Google Gemini
105
  llm = ChatGoogleGenerativeAI(model="gemini-2.0-flash", temperature=0)
106
- elif provider == "Nvidia":
107
-
108
- llm = ChatNVIDIA(model="meta/llama-3.1-70b-instruct", temperature=0)
109
  elif provider == "huggingface":
110
  # TODO: Add huggingface endpoint
111
  llm = ChatHuggingFace(
@@ -151,7 +205,7 @@ def build_graph(provider: str = "groq"):
151
  if __name__ == "__main__":
152
  question = "When was a picture of St. Thomas Aquinas first added to the Wikipedia page on the Principle of double effect?"
153
  # Build the graph
154
- graph = build_graph(provider="Nvidia")
155
  # Run the graph
156
  messages = [HumanMessage(content=question)]
157
  messages = graph.invoke({"messages": messages})
 
1
+ """LangGraph Agent"""
2
  import os
3
  from dotenv import load_dotenv
4
  from langgraph.graph import START, StateGraph, MessagesState
5
  from langgraph.prebuilt import tools_condition
6
  from langgraph.prebuilt import ToolNode
7
  from langchain_google_genai import ChatGoogleGenerativeAI
8
+ from langchain_groq import ChatGroq
 
9
  from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint, HuggingFaceEmbeddings
10
  from langchain_community.tools.tavily_search import TavilySearchResults
11
  from langchain_community.document_loaders import WikipediaLoader
12
  from langchain_community.document_loaders import ArxivLoader
13
+ from langchain_community.vectorstores import SupabaseVectorStore
14
  from langchain_core.messages import SystemMessage, HumanMessage
15
  from langchain_core.tools import tool
16
  from langchain.tools.retriever import create_retriever_tool
17
+ from supabase.client import Client, create_client
 
 
 
18
 
19
  load_dotenv()
20
 
21
+ @tool
22
+ def multiply(a: int, b: int) -> int:
23
+ """Multiply two numbers.
24
+ Args:
25
+ a: first int
26
+ b: second int
27
+ """
28
+ return a * b
29
+
30
+ @tool
31
+ def add(a: int, b: int) -> int:
32
+ """Add two numbers.
33
+
34
+ Args:
35
+ a: first int
36
+ b: second int
37
+ """
38
+ return a + b
39
+
40
+ @tool
41
+ def subtract(a: int, b: int) -> int:
42
+ """Subtract two numbers.
43
+
44
+ Args:
45
+ a: first int
46
+ b: second int
47
+ """
48
+ return a - b
49
+
50
+ @tool
51
+ def divide(a: int, b: int) -> int:
52
+ """Divide two numbers.
53
+
54
+ Args:
55
+ a: first int
56
+ b: second int
57
+ """
58
+ if b == 0:
59
+ raise ValueError("Cannot divide by zero.")
60
+ return a / b
61
+
62
+ @tool
63
+ def modulus(a: int, b: int) -> int:
64
+ """Get the modulus of two numbers.
65
+
66
+ Args:
67
+ a: first int
68
+ b: second int
69
+ """
70
+ return a % b
71
+
72
  @tool
73
  def wiki_search(query: str) -> str:
74
  """Search Wikipedia for a query and return maximum 2 results.
 
113
 
114
 
115
 
 
116
  # load the system prompt from the file
117
  with open("system_prompt.txt", "r", encoding="utf-8") as f:
118
  system_prompt = f.read()
119
 
120
+ # System message
121
  sys_msg = SystemMessage(content=system_prompt)
122
 
123
+ # build a retriever
124
  embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2") # dim=768
125
+ supabase: Client = create_client(
126
+ os.environ.get("SUPABASE_URL"),
127
+ os.environ.get("SUPABASE_SERVICE_KEY"))
128
+ vector_store = SupabaseVectorStore(
129
+ client=supabase,
130
+ embedding= embeddings,
131
+ table_name="documents",
132
+ query_name="match_documents_langchain",
133
  )
134
+ create_retriever_tool = create_retriever_tool(
 
 
 
135
  retriever=vector_store.as_retriever(),
136
+ name="Question Search",
137
  description="A tool to retrieve similar questions from a vector store.",
138
  )
139
 
140
+
141
+
142
  tools = [
143
+ multiply,
144
+ add,
145
+ subtract,
146
+ divide,
147
+ modulus,
148
  wiki_search,
149
  web_search,
150
  arvix_search,
 
151
  ]
152
 
 
 
153
  # Build graph function
154
  def build_graph(provider: str = "groq"):
155
  """Build the graph"""
 
157
  if provider == "google":
158
  # Google Gemini
159
  llm = ChatGoogleGenerativeAI(model="gemini-2.0-flash", temperature=0)
160
+ elif provider == "groq":
161
+ # Groq https://console.groq.com/docs/models
162
+ llm = ChatGroq(model="qwen-qwq-32b", temperature=0) # optional : qwen-qwq-32b gemma2-9b-it
163
  elif provider == "huggingface":
164
  # TODO: Add huggingface endpoint
165
  llm = ChatHuggingFace(
 
205
  if __name__ == "__main__":
206
  question = "When was a picture of St. Thomas Aquinas first added to the Wikipedia page on the Principle of double effect?"
207
  # Build the graph
208
+ graph = build_graph(provider="groq")
209
  # Run the graph
210
  messages = [HumanMessage(content=question)]
211
  messages = graph.invoke({"messages": messages})