wt002 commited on
Commit
b49b95b
·
verified ·
1 Parent(s): 8f90b3d

Update agent.py

Browse files
Files changed (1) hide show
  1. app.py +110 -42
app.py CHANGED
@@ -3,8 +3,7 @@ import gradio as gr
3
  import requests
4
  import inspect
5
  import pandas as pd
6
- from langchain_core.messages import HumanMessage
7
- from typing import TypedDict, Annotated, Sequence, Dict, Any, List
8
  from langchain_core.messages import BaseMessage, HumanMessage
9
  from langchain_core.tools import tool
10
  from langchain_openai import ChatOpenAI
@@ -15,6 +14,9 @@ from langchain_community.utilities import WikipediaAPIWrapper
15
  from langchain.agents import create_tool_calling_agent, AgentExecutor
16
  from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
17
  import operator
 
 
 
18
 
19
  # (Keep Constants as is)
20
  # --- Constants ---
@@ -24,44 +26,97 @@ DEFAULT_API_URL = "https://agents-course-unit4-scoring.hf.space"
24
  # ----- THIS IS WERE YOU CAN BUILD WHAT YOU WANT ------
25
 
26
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27
  class AgentState(TypedDict):
28
  messages: Annotated[Sequence[BaseMessage], operator.add]
29
  sender: str
30
 
31
  @tool
32
  def wikipedia_search(query: str) -> str:
33
- """Search Wikipedia for information."""
34
- return WikipediaAPIWrapper().run(query)
 
 
 
35
 
36
  @tool
37
  def web_search(query: str, num_results: int = 3) -> list:
38
- """Search the web for current information."""
39
- return DuckDuckGoSearchResults(num_results=num_results).run(query)
 
 
 
 
40
 
41
  @tool
42
  def calculate(expression: str) -> str:
43
- """Evaluate mathematical expressions."""
44
- from langchain_experimental.utilities import PythonREPL
45
- python_repl = PythonREPL()
46
- return python_repl.run(expression)
 
 
 
47
 
48
  class BasicAgent:
49
- """A langgraph agent."""
50
- def __init__(self):
51
- print("BasicAgent initialized.")
52
- self.graph = build_graph()
53
-
54
- def __call__(self, question: str) -> str:
55
- print(f"Agent received question (first 50 chars): {question[:50]}...")
56
- # Wrap the question in a HumanMessage from langchain_core
57
- messages = [HumanMessage(content=question)]
58
- messages = self.graph.invoke({"messages": messages})
59
- answer = messages['messages'][-1].content
60
- return answer[14:]
61
-
62
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
63
  def _build_workflow(self) -> StateGraph:
64
- """Build and return the compiled workflow"""
65
  workflow = StateGraph(AgentState)
66
 
67
  workflow.add_node("agent", self._run_agent)
@@ -76,19 +131,25 @@ class BasicAgent:
76
  workflow.add_edge("tools", "agent")
77
 
78
  return workflow.compile()
79
-
 
80
  def _run_agent(self, state: AgentState) -> Dict[str, Any]:
81
- """Execute the agent"""
82
  response = self.agent_executor.invoke({"messages": state["messages"]})
83
  return {"messages": [response["output"]]}
84
-
85
  def _should_continue(self, state: AgentState) -> str:
86
- """Determine if the workflow should continue"""
87
  last_message = state["messages"][-1]
88
  return "continue" if last_message.additional_kwargs.get("tool_calls") else "end"
89
-
 
90
  def __call__(self, query: str) -> Dict[str, Any]:
91
- """Process a user query"""
 
 
 
 
92
  state = AgentState(messages=[HumanMessage(content=query)], sender="user")
93
 
94
  for output in self.workflow.stream(state):
@@ -96,29 +157,36 @@ class BasicAgent:
96
  if key == "messages":
97
  for message in value:
98
  if isinstance(message, BaseMessage):
 
99
  return {
100
- "response": message.content,
101
  "sources": self._extract_sources(state["messages"]),
102
- "steps": self._extract_steps(state["messages"])
 
103
  }
104
  return {"response": "No response generated", "sources": [], "steps": []}
105
-
106
  def _extract_sources(self, messages: Sequence[BaseMessage]) -> List[str]:
107
- """Extract sources from tool messages"""
108
- return [
109
- f"{msg.additional_kwargs.get('name', 'unknown')}: {msg.content}"
110
- for msg in messages
111
- if hasattr(msg, 'additional_kwargs') and 'name' in msg.additional_kwargs
112
- ]
113
-
 
 
114
  def _extract_steps(self, messages: Sequence[BaseMessage]) -> List[str]:
115
- """Extract reasoning steps"""
116
  steps = []
117
  for msg in messages:
118
  if hasattr(msg, 'additional_kwargs') and 'tool_calls' in msg.additional_kwargs:
119
  for call in msg.additional_kwargs['tool_calls']:
120
- steps.append(f"Used {call['function']['name']}: {call['function']['arguments']}")
 
 
121
  return steps
 
122
 
123
  def run_and_submit_all( profile: gr.OAuthProfile | None):
124
  """
 
3
  import requests
4
  import inspect
5
  import pandas as pd
6
+ from typing import TypedDict, Annotated, Sequence, Dict, Any, List, Optional
 
7
  from langchain_core.messages import BaseMessage, HumanMessage
8
  from langchain_core.tools import tool
9
  from langchain_openai import ChatOpenAI
 
14
  from langchain.agents import create_tool_calling_agent, AgentExecutor
15
  from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
16
  import operator
17
+ from langchain_experimental.utilities import PythonREPL
18
+ from functools import wraps
19
+ import logging
20
 
21
  # (Keep Constants as is)
22
  # --- Constants ---
 
26
  # ----- THIS IS WERE YOU CAN BUILD WHAT YOU WANT ------
27
 
28
 
29
+
30
+ # --- Configure logging ---
31
+ logging.basicConfig(level=logging.INFO)
32
+ logger = logging.getLogger(__name__)
33
+
34
+ # --- Constants ---
35
+ DEFAULT_API_URL = "https://agents-course-unit4-scoring.hf.space"
36
+ DEFAULT_MODEL = "gpt-3.5-turbo"
37
+ MAX_RESPONSE_LENGTH = 2000 # Prevent overly long responses
38
+
39
+ def handle_errors(func):
40
+ """Decorator to handle common errors in agent operations."""
41
+ @wraps(func)
42
+ def wrapper(*args, **kwargs):
43
+ try:
44
+ return func(*args, **kwargs)
45
+ except Exception as e:
46
+ logger.error(f"Error in {func.__name__}: {str(e)}")
47
+ return {"error": str(e)}
48
+ return wrapper
49
+
50
  class AgentState(TypedDict):
51
  messages: Annotated[Sequence[BaseMessage], operator.add]
52
  sender: str
53
 
54
  @tool
55
  def wikipedia_search(query: str) -> str:
56
+ """Search Wikipedia for information. Useful for historical facts, scientific concepts, and general knowledge."""
57
+ try:
58
+ return WikipediaAPIWrapper().run(query)[:MAX_RESPONSE_LENGTH]
59
+ except Exception as e:
60
+ return f"Wikipedia search failed: {str(e)}"
61
 
62
  @tool
63
  def web_search(query: str, num_results: int = 3) -> list:
64
+ """Search the web for current information. Useful for news, recent events, and up-to-date data."""
65
+ try:
66
+ results = DuckDuckGoSearchResults(num_results=num_results).run(query)
67
+ return [str(r)[:500] for r in results][:num_results] # Limit result size
68
+ except Exception as e:
69
+ return [f"Web search failed: {str(e)}"]
70
 
71
  @tool
72
  def calculate(expression: str) -> str:
73
+ """Evaluate mathematical expressions. Supports basic arithmetic and complex formulas."""
74
+ try:
75
+ python_repl = PythonREPL()
76
+ result = python_repl.run(expression)
77
+ return str(result)[:100] # Limit numeric output length
78
+ except Exception as e:
79
+ return f"Calculation failed: {str(e)}"
80
 
81
  class BasicAgent:
82
+ """An enhanced LangGraph agent with better error handling and response processing."""
 
 
 
 
 
 
 
 
 
 
 
 
83
 
84
+ def __init__(self, model_name: str = DEFAULT_MODEL, temperature: float = 0.7):
85
+ """Initialize the agent with tools and workflow."""
86
+ self.model_name = model_name
87
+ self.temperature = temperature
88
+ self.tools = [wikipedia_search, web_search, calculate]
89
+ self.llm = ChatOpenAI(model=model_name, temperature=temperature)
90
+ self.agent_executor = self._build_agent_executor()
91
+ self.workflow = self._build_workflow()
92
+ logger.info(f"AdvancedAgent initialized with model: {model_name}")
93
+
94
+ def _build_agent_executor(self) -> AgentExecutor:
95
+ """Build the agent executor with proper prompt and tools."""
96
+ prompt = ChatPromptTemplate.from_messages([
97
+ ("system", self._get_system_prompt()),
98
+ MessagesPlaceholder(variable_name="messages"),
99
+ MessagesPlaceholder(variable_name="agent_scratchpad"),
100
+ ])
101
+ agent = create_tool_calling_agent(self.llm, self.tools, prompt)
102
+ return AgentExecutor(
103
+ agent=agent,
104
+ tools=self.tools,
105
+ verbose=True,
106
+ handle_parsing_errors=True
107
+ )
108
+
109
+ def _get_system_prompt(self) -> str:
110
+ """Return a comprehensive system prompt for the agent."""
111
+ return """You are an advanced AI assistant with access to tools. Follow these rules:
112
+ 1. Be precise and factual
113
+ 2. Use tools when needed
114
+ 3. Cite your sources
115
+ 4. Break complex problems into steps
116
+ 5. Admit when you don't know something"""
117
+
118
  def _build_workflow(self) -> StateGraph:
119
+ """Build and compile the agent workflow."""
120
  workflow = StateGraph(AgentState)
121
 
122
  workflow.add_node("agent", self._run_agent)
 
131
  workflow.add_edge("tools", "agent")
132
 
133
  return workflow.compile()
134
+
135
+ @handle_errors
136
  def _run_agent(self, state: AgentState) -> Dict[str, Any]:
137
+ """Execute the agent with error handling."""
138
  response = self.agent_executor.invoke({"messages": state["messages"]})
139
  return {"messages": [response["output"]]}
140
+
141
  def _should_continue(self, state: AgentState) -> str:
142
+ """Determine if the workflow should continue based on tool calls."""
143
  last_message = state["messages"][-1]
144
  return "continue" if last_message.additional_kwargs.get("tool_calls") else "end"
145
+
146
+ @handle_errors
147
  def __call__(self, query: str) -> Dict[str, Any]:
148
+ """Process a user query and return a structured response."""
149
+ if not query or len(query.strip()) == 0:
150
+ return {"error": "Empty query provided"}
151
+
152
+ logger.info(f"Processing query: {query[:50]}...")
153
  state = AgentState(messages=[HumanMessage(content=query)], sender="user")
154
 
155
  for output in self.workflow.stream(state):
 
157
  if key == "messages":
158
  for message in value:
159
  if isinstance(message, BaseMessage):
160
+ response = message.content[:MAX_RESPONSE_LENGTH]
161
  return {
162
+ "response": response,
163
  "sources": self._extract_sources(state["messages"]),
164
+ "steps": self._extract_steps(state["messages"]),
165
+ "model": self.model_name
166
  }
167
  return {"response": "No response generated", "sources": [], "steps": []}
168
+
169
  def _extract_sources(self, messages: Sequence[BaseMessage]) -> List[str]:
170
+ """Extract and format sources from tool messages."""
171
+ sources = []
172
+ for msg in messages:
173
+ if hasattr(msg, 'additional_kwargs') and 'name' in msg.additional_kwargs:
174
+ source_name = msg.additional_kwargs.get('name', 'unknown')
175
+ content = str(msg.content)[:200] # Truncate long content
176
+ sources.append(f"{source_name}: {content}")
177
+ return sources
178
+
179
  def _extract_steps(self, messages: Sequence[BaseMessage]) -> List[str]:
180
+ """Extract and format the reasoning steps."""
181
  steps = []
182
  for msg in messages:
183
  if hasattr(msg, 'additional_kwargs') and 'tool_calls' in msg.additional_kwargs:
184
  for call in msg.additional_kwargs['tool_calls']:
185
+ tool_name = call['function']['name']
186
+ args = call['function']['arguments'][:100] # Truncate long args
187
+ steps.append(f"Used {tool_name} with args: {args}")
188
  return steps
189
+
190
 
191
  def run_and_submit_all( profile: gr.OAuthProfile | None):
192
  """