# LangGraph agent example: an LLM wired to Wikipedia search, web search,
# and calculator tools via a StateGraph workflow.
import json
import operator
from typing import Annotated, Sequence, TypedDict

from dotenv import load_dotenv
from langchain.agents import AgentExecutor, create_tool_calling_agent
from langchain.tools import DuckDuckGoSearchResults
from langchain_community.utilities import WikipediaAPIWrapper
from langchain_core.messages import (
    AIMessage,
    BaseMessage,
    HumanMessage,
    ToolMessage,
)
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.tools import tool
from langchain_openai import ChatOpenAI
from langgraph.graph import END, StateGraph
from langgraph.prebuilt import ToolExecutor, ToolInvocation
load_dotenv() | |
# Define the agent state | |
class AgentState(TypedDict): | |
messages: Annotated[Sequence[BaseMessage], operator.add] | |
sender: str | |
# Initialize tools | |
def wikipedia_search(query: str) -> str: | |
"""Search Wikipedia for information.""" | |
return WikipediaAPIWrapper().run(query) | |
def web_search(query: str, num_results: int = 3) -> list: | |
"""Search the web for current information.""" | |
return DuckDuckGoSearchResults(num_results=num_results).run(query) | |
def calculate(expression: str) -> str: | |
"""Evaluate mathematical expressions.""" | |
from langchain.chains import LLMMathChain | |
llm = ChatOpenAI(model="gpt-3.5-turbo", temperature=0) | |
return LLMMathChain(llm=llm).run(expression) | |
class AdvancedAIAgent: | |
def __init__(self, model_name="gpt-4-turbo"): | |
# Initialize tools and LLM | |
self.tools = [wikipedia_search, web_search, calculate] | |
self.llm = ChatOpenAI(model=model_name, temperature=0.7) | |
# Create the agent | |
self.agent = self._create_agent() | |
self.tool_executor = ToolExecutor(self.tools) | |
# Build the graph workflow | |
self.workflow = self._build_graph() | |
def _create_agent(self) -> AgentExecutor: | |
"""Create the agent with tools and prompt""" | |
prompt = ChatPromptTemplate.from_messages([ | |
("system", "You are a helpful AI assistant. Use tools when needed."), | |
MessagesPlaceholder(variable_name="messages"), | |
MessagesPlaceholder(variable_name="agent_scratchpad"), | |
]) | |
agent = create_tool_calling_agent(self.llm, self.tools, prompt) | |
return AgentExecutor(agent=agent, tools=self.tools, verbose=True) | |
def _build_graph(self): | |
"""Build the LangGraph workflow""" | |
workflow = StateGraph(AgentState) | |
# Define nodes | |
workflow.add_node("agent", self._call_agent) | |
workflow.add_node("tools", self._call_tools) | |
# Define edges | |
workflow.set_entry_point("agent") | |
workflow.add_conditional_edges( | |
"agent", | |
self._should_continue, | |
{ | |
"continue": "tools", | |
"end": END | |
} | |
) | |
workflow.add_edge("tools", "agent") | |
return workflow.compile() | |
def _call_agent(self, state: AgentState): | |
"""Execute the agent""" | |
response = self.agent.invoke({"messages": state["messages"]}) | |
return {"messages": [response["output"]]} | |
def _call_tools(self, state: AgentState): | |
"""Execute tools""" | |
last_message = state["messages"][-1] | |
# Find the tool calls | |
tool_calls = last_message.additional_kwargs.get("tool_calls", []) | |
# Execute each tool | |
for tool_call in tool_calls: | |
action = ToolInvocation( | |
tool=tool_call["function"]["name"], | |
tool_input=json.loads(tool_call["function"]["arguments"]), | |
) | |
output = self.tool_executor.invoke(action) | |
# Create tool message | |
tool_message = ToolMessage( | |
content=str(output), | |
name=action.tool, | |
tool_call_id=tool_call["id"], | |
) | |
state["messages"].append(tool_message) | |
return {"messages": state["messages"]} | |
def _should_continue(self, state: AgentState): | |
"""Determine if the workflow should continue""" | |
last_message = state["messages"][-1] | |
# If no tool calls, end | |
if not last_message.additional_kwargs.get("tool_calls"): | |
return "end" | |
return "continue" | |
def __call__(self, query: str) -> dict: | |
"""Process a user query""" | |
# Initialize state | |
state = AgentState(messages=[HumanMessage(content=query)], sender="user") | |
# Execute the workflow | |
for output in self.workflow.stream(state): | |
for key, value in output.items(): | |
if key == "messages": | |
for message in value: | |
if isinstance(message, BaseMessage): | |
return { | |
"response": message.content, | |
"sources": self._extract_sources(state["messages"]), | |
"steps": self._extract_steps(state["messages"]) | |
} | |
def _extract_sources(self, messages: Sequence[BaseMessage]) -> list: | |
"""Extract sources from tool messages""" | |
return [ | |
f"{msg.additional_kwargs.get('name', 'unknown')}: {msg.content}" | |
for msg in messages | |
if hasattr(msg, 'additional_kwargs') and 'name' in msg.additional_kwargs | |
] | |
def _extract_steps(self, messages: Sequence[BaseMessage]) -> list: | |
"""Extract reasoning steps""" | |
steps = [] | |
for msg in messages: | |
if hasattr(msg, 'additional_kwargs') and 'tool_calls' in msg.additional_kwargs: | |
for call in msg.additional_kwargs['tool_calls']: | |
steps.append(f"Used {call['function']['name']}: {call['function']['arguments']}") | |
return steps | |
# Example usage | |
if __name__ == "__main__": | |
agent = AdvancedAIAgent() | |
queries = [ | |
"What is the capital of France?", | |
"Calculate 15% of 200", | |
"Tell me about the latest developments in quantum computing" | |
] | |
for query in queries: | |
print(f"\nQuestion: {query}") | |
response = agent(query) | |
print(f"Answer: {response['response']}") | |
if response['sources']: | |
print("Sources:") | |
for source in response['sources']: | |
print(f"- {source}") | |
if response['steps']: | |
print("Steps taken:") | |
for step in response['steps']: | |
print(f"- {step}") |