Datawithsarah commited on
Commit
6d24d35
·
1 Parent(s): de718ca

claude API usuage

Browse files
Files changed (2) hide show
  1. agent.py +24 -28
  2. app.py +4 -4
agent.py CHANGED
@@ -4,6 +4,7 @@ from dotenv import load_dotenv
4
  from langgraph.graph import START, StateGraph, MessagesState
5
  from langgraph.prebuilt import tools_condition
6
  from langgraph.prebuilt import ToolNode
 
7
  from langchain_google_genai import ChatGoogleGenerativeAI
8
  from langchain_groq import ChatGroq
9
  from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint, HuggingFaceEmbeddings
@@ -22,46 +23,38 @@ load_dotenv()
22
  # === Tools ===
23
  @tool
24
  def multiply(a: int, b: int) -> int:
25
- """Multiply two integers."""
26
  return a * b
27
 
28
  @tool
29
  def add(a: int, b: int) -> int:
30
- """Add two integers."""
31
  return a + b
32
 
33
  @tool
34
  def subtract(a: int, b: int) -> int:
35
- """Subtract b from a."""
36
  return a - b
37
 
38
  @tool
39
  def divide(a: int, b: int) -> float:
40
- """Divide a by b."""
41
  if b == 0:
42
  raise ValueError("Cannot divide by zero.")
43
  return a / b
44
 
45
  @tool
46
  def modulus(a: int, b: int) -> int:
47
- """Return a modulo b."""
48
  return a % b
49
 
50
  @tool
51
  def wiki_search(query: str) -> str:
52
- """Search Wikipedia for a query."""
53
  search_docs = WikipediaLoader(query=query, load_max_docs=2).load()
54
  return "\n\n---\n\n".join([doc.page_content for doc in search_docs])
55
 
56
  @tool
57
  def web_search(query: str) -> str:
58
- """Search the web for a query."""
59
  search_docs = TavilySearchResults(max_results=3).invoke(query=query)
60
  return "\n\n---\n\n".join([doc.page_content for doc in search_docs])
61
 
62
  @tool
63
  def arvix_search(query: str) -> str:
64
- """Search Arxiv for a query."""
65
  search_docs = ArxivLoader(query=query, load_max_docs=3).load()
66
  return "\n\n---\n\n".join([doc.page_content[:1000] for doc in search_docs])
67
 
@@ -70,9 +63,9 @@ with open("system_prompt.txt", "r", encoding="utf-8") as f:
70
  system_prompt = f.read()
71
  sys_msg = SystemMessage(content=system_prompt)
72
 
73
- # === Embedding and Supabase Setup ===
74
  embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2")
75
- supabase: Client = create_client(os.environ.get("SUPABASE_URL"), os.environ.get("SUPABASE_SERVICE_KEY"))
76
  vector_store = SupabaseVectorStore(
77
  client=supabase,
78
  embedding=embeddings,
@@ -80,24 +73,27 @@ vector_store = SupabaseVectorStore(
80
  query_name="match_documents_langchain",
81
  )
82
 
83
- # === Tools List ===
84
  tools = [multiply, add, subtract, divide, modulus, wiki_search, web_search, arvix_search]
85
 
86
- # === Graph Builder ===
87
- def build_graph(provider: str = "groq"):
88
- if provider == "google":
89
- llm = ChatGoogleGenerativeAI(model="gemini-2.0-flash", temperature=0)
 
 
 
 
90
  elif provider == "groq":
91
  llm = ChatGroq(model="qwen-qwq-32b", temperature=0)
 
 
92
  elif provider == "huggingface":
93
- llm = ChatHuggingFace(
94
- llm=HuggingFaceEndpoint(
95
- url="https://api-inference.huggingface.co/models/Meta-DeepLearning/llama-2-7b-chat-hf",
96
- temperature=0,
97
- )
98
- )
99
  else:
100
- raise ValueError("Invalid provider.")
101
 
102
  llm_with_tools = llm.bind_tools(tools)
103
 
@@ -123,7 +119,6 @@ def build_graph(provider: str = "groq"):
123
  builder.add_node("assistant", assistant)
124
  builder.add_node("tools", ToolNode(tools))
125
  builder.add_node("formatter", formatter)
126
-
127
  builder.add_edge(START, "retriever")
128
  builder.add_edge("retriever", "assistant")
129
  builder.add_conditional_edges("assistant", tools_condition)
@@ -132,9 +127,10 @@ def build_graph(provider: str = "groq"):
132
 
133
  return builder.compile()
134
 
135
- # === Test Entry Point ===
136
  if __name__ == "__main__":
137
- graph = build_graph("groq")
138
- messages = graph.invoke({"messages": [HumanMessage(content="What is the capital of France?")]})
139
- for msg in messages["messages"]:
140
- msg.pretty_print()
 
 
4
  from langgraph.graph import START, StateGraph, MessagesState
5
  from langgraph.prebuilt import tools_condition
6
  from langgraph.prebuilt import ToolNode
7
+ from langchain_anthropic.ChatAnthropic import ChatAnthropi
8
  from langchain_google_genai import ChatGoogleGenerativeAI
9
  from langchain_groq import ChatGroq
10
  from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint, HuggingFaceEmbeddings
 
23
  # === Tools ===
24
  @tool
25
  def multiply(a: int, b: int) -> int:
 
26
  return a * b
27
 
28
  @tool
29
  def add(a: int, b: int) -> int:
 
30
  return a + b
31
 
32
  @tool
33
  def subtract(a: int, b: int) -> int:
 
34
  return a - b
35
 
36
  @tool
37
  def divide(a: int, b: int) -> float:
 
38
  if b == 0:
39
  raise ValueError("Cannot divide by zero.")
40
  return a / b
41
 
42
  @tool
43
  def modulus(a: int, b: int) -> int:
 
44
  return a % b
45
 
46
  @tool
47
  def wiki_search(query: str) -> str:
 
48
  search_docs = WikipediaLoader(query=query, load_max_docs=2).load()
49
  return "\n\n---\n\n".join([doc.page_content for doc in search_docs])
50
 
51
  @tool
52
  def web_search(query: str) -> str:
 
53
  search_docs = TavilySearchResults(max_results=3).invoke(query=query)
54
  return "\n\n---\n\n".join([doc.page_content for doc in search_docs])
55
 
56
  @tool
57
  def arvix_search(query: str) -> str:
 
58
  search_docs = ArxivLoader(query=query, load_max_docs=3).load()
59
  return "\n\n---\n\n".join([doc.page_content[:1000] for doc in search_docs])
60
 
 
63
  system_prompt = f.read()
64
  sys_msg = SystemMessage(content=system_prompt)
65
 
66
+ # === Embedding & Vector DB ===
67
  embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2")
68
+ supabase: Client = create_client(os.getenv("SUPABASE_URL"), os.getenv("SUPABASE_SERVICE_KEY"))
69
  vector_store = SupabaseVectorStore(
70
  client=supabase,
71
  embedding=embeddings,
 
73
  query_name="match_documents_langchain",
74
  )
75
 
76
+ # === Tools ===
77
  tools = [multiply, add, subtract, divide, modulus, wiki_search, web_search, arvix_search]
78
 
79
+ # === LangGraph Agent Definition ===
80
+ def build_graph(provider: str = "claude"):
81
+ if provider == "claude":
82
+ llm = ChatAnthropic(
83
+ model="claude-3-sonnet-20240229",
84
+ temperature=0,
85
+ anthropic_api_key=os.getenv("CLAUDE_API_KEY")
86
+ )
87
  elif provider == "groq":
88
  llm = ChatGroq(model="qwen-qwq-32b", temperature=0)
89
+ elif provider == "google":
90
+ llm = ChatGoogleGenerativeAI(model="gemini-2.0-flash", temperature=0)
91
  elif provider == "huggingface":
92
+ llm = ChatHuggingFace(llm=HuggingFaceEndpoint(
93
+ url="https://api-inference.huggingface.co/models/Meta-DeepLearning/llama-2-7b-chat-hf",
94
+ temperature=0))
 
 
 
95
  else:
96
+ raise ValueError("Invalid provider")
97
 
98
  llm_with_tools = llm.bind_tools(tools)
99
 
 
119
  builder.add_node("assistant", assistant)
120
  builder.add_node("tools", ToolNode(tools))
121
  builder.add_node("formatter", formatter)
 
122
  builder.add_edge(START, "retriever")
123
  builder.add_edge("retriever", "assistant")
124
  builder.add_conditional_edges("assistant", tools_condition)
 
127
 
128
  return builder.compile()
129
 
130
+ # === Test ===
131
  if __name__ == "__main__":
132
+ graph = build_graph("claude")
133
+ result = graph.invoke({"messages": [HumanMessage(content="What is the capital of France?")]})
134
+ for m in result["messages"]:
135
+ m.pretty_print()
136
+
app.py CHANGED
@@ -13,8 +13,8 @@ cached_answers = []
13
 
14
  class ChatAgent:
15
  def __init__(self):
16
- print("ChatAgent initialized with LangGraph workflow.")
17
- self.graph = build_graph()
18
 
19
  def __call__(self, question: str) -> str:
20
  print(f"Processing question: {question[:60]}...")
@@ -95,8 +95,8 @@ with gr.Blocks() as demo:
95
  gr.Markdown("Run the agent on all tasks, then submit for scoring.")
96
  gr.LoginButton()
97
 
98
- run_button = gr.Button("🧠 Run Agent")
99
- submit_button = gr.Button("📤 Submit Answers")
100
 
101
  status_box = gr.Textbox(label="Status", lines=3)
102
  table = gr.DataFrame(label="Results", wrap=True)
 
13
 
14
  class ChatAgent:
15
  def __init__(self):
16
+ print("ChatAgent initialized with Claude LangGraph workflow.")
17
+ self.graph = build_graph("claude")
18
 
19
  def __call__(self, question: str) -> str:
20
  print(f"Processing question: {question[:60]}...")
 
95
  gr.Markdown("Run the agent on all tasks, then submit for scoring.")
96
  gr.LoginButton()
97
 
98
+ run_button = gr.Button("\U0001F9E0 Run Agent")
99
+ submit_button = gr.Button("\U0001F4E4 Submit Answers")
100
 
101
  status_box = gr.Textbox(label="Status", lines=3)
102
  table = gr.DataFrame(label="Results", wrap=True)