Spaces:
Running
Running
from typing import TypedDict, Annotated, Sequence | |
from langchain_core.messages import BaseMessage, HumanMessage | |
from langchain_core.tools import tool | |
from langchain_openai import ChatOpenAI | |
from langgraph.graph import END, StateGraph | |
from langgraph.prebuilt import ToolNode | |
from langchain.tools import DuckDuckGoSearchResults | |
from langchain_community.utilities import WikipediaAPIWrapper | |
from langchain.agents import create_tool_calling_agent | |
from langchain.agents import AgentExecutor | |
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder | |
import operator | |
import json | |
# Load environment variables (presumably including the OpenAI API key that
# ChatOpenAI needs) from a local .env file. The intended call is
# python-dotenv's ``load_dotenv()``. The import is guarded so the module still
# loads where python-dotenv is not installed and the variables are set in the
# process environment directly.
try:
    from dotenv import load_dotenv
except ImportError:
    pass
else:
    load_dotenv()
# Shape of the state dict that the LangGraph workflow passes between nodes.
class AgentState(TypedDict, total=True):
    # Conversation history. The ``operator.add`` metadata is LangGraph's
    # reducer hook for this key: a node returning {"messages": [...]} has
    # those messages appended instead of replacing the list.
    messages: Annotated[Sequence[BaseMessage], operator.add]
    # Origin of the latest update; AdvancedAIAgent.__call__ seeds it with "user".
    sender: str
# Tool functions made available to the agent.
def wikipedia_search(query: str) -> str:
    """Search Wikipedia for information."""
    # A new wrapper is built on every call; its ``run`` result is returned as-is.
    api = WikipediaAPIWrapper()
    result = api.run(query)
    return result
def web_search(query: str, num_results: int = 3) -> str:
    """Search the web for current information."""
    # DuckDuckGoSearchResults.run renders its hits as one string (its default
    # output format), so the return type is ``str``, not ``list``.
    # ``num_results`` caps how many hits the search tool returns.
    # NOTE(review): the docstring above is likely used as the tool description
    # shown to the LLM, so it is kept short and unchanged.
    searcher = DuckDuckGoSearchResults(num_results=num_results)
    return searcher.run(query)
def calculate(expression: str) -> str:
    """Evaluate mathematical expressions."""
    # Imported inside the function, so the chain module is only loaded when
    # this tool actually runs.
    from langchain.chains import LLMMathChain

    # A deterministic (temperature 0) model translates the expression for the chain.
    llm = ChatOpenAI(model="gpt-3.5-turbo", temperature=0)
    # ``from_llm`` is the supported constructor (passing ``llm=`` directly is
    # deprecated), and ``invoke`` replaces the deprecated ``run``. The chain
    # reads its input from "question" and writes its result to "answer",
    # which is exactly what ``run`` used to return.
    chain = LLMMathChain.from_llm(llm)
    return chain.invoke({"question": expression})["answer"]
class AdvancedAIAgent:
    """Tool-using chat agent implemented as a LangGraph agent/tools loop.

    The compiled graph alternates between an ``"agent"`` node, which asks the
    tool-bound LLM for its next message, and a ``"tools"`` node
    (``ToolNode``), which executes the tool calls that message requested.
    The loop ends as soon as the LLM replies without requesting a tool call.

    Calling an instance with a query returns a dict with the final answer
    (``"response"``), the tool outputs that were consulted (``"sources"``),
    and the tool calls that were made (``"steps"``).
    """

    def __init__(self, model_name="gpt-4-turbo"):
        """Build the tools, the LLM, the tool-bound agent and the graph.

        Args:
            model_name: Chat model name passed to ``ChatOpenAI``.
        """
        # The module-level helpers are plain functions. Wrap them with ``tool``
        # so the LLM can be bound to them and ToolNode can execute them; the
        # functions themselves remain directly callable.
        self.tools = [tool(fn) for fn in (wikipedia_search, web_search, calculate)]
        self.llm = ChatOpenAI(model=model_name, temperature=0.7)
        # The agent only produces the next message; the graph drives the tool loop.
        self.agent = self._create_agent()
        self.workflow = self._build_graph()

    def _create_agent(self):
        """Return a runnable mapping ``{"messages": [...]}`` to the LLM's next message.

        The system prompt is prepended to the conversation, and the LLM is
        bound to ``self.tools`` so its reply may contain tool calls.
        """
        prompt = ChatPromptTemplate.from_messages([
            ("system", "You are a helpful AI assistant. Use tools when needed."),
            # The full conversation, including earlier tool calls and their
            # results, serves as the scratchpad, so no separate placeholder is needed.
            MessagesPlaceholder(variable_name="messages"),
        ])
        return prompt | self.llm.bind_tools(self.tools)

    def _build_graph(self):
        """Compile the agent <-> tools loop into a runnable LangGraph workflow."""
        workflow = StateGraph(AgentState)
        workflow.add_node("agent", self._call_agent)
        workflow.add_node("tools", ToolNode(self.tools))
        workflow.set_entry_point("agent")
        # After each agent turn, either run the requested tools or stop.
        workflow.add_conditional_edges(
            "agent",
            self._should_continue,
            {
                "continue": "tools",
                "end": END,
            },
        )
        # Tool results always go back to the agent for another turn.
        workflow.add_edge("tools", "agent")
        return workflow.compile()

    def _call_agent(self, state: AgentState):
        """Run one agent turn and add the LLM's reply to the conversation."""
        response = self.agent.invoke({"messages": state["messages"]})
        # ``response`` is a message object (not a plain string), so routing and
        # extraction can inspect its tool calls. It is appended via the
        # ``operator.add`` reducer on AgentState.messages.
        return {"messages": [response]}

    def _should_continue(self, state: AgentState):
        """Return ``"continue"`` if the last message requests tool calls, else ``"end"``."""
        last_message = state["messages"][-1]
        # Route on the same ``tool_calls`` attribute that ToolNode executes, so
        # the graph never enters the tools node with nothing to run.
        if getattr(last_message, "tool_calls", None):
            return "continue"
        return "end"

    def __call__(self, query: str) -> dict:
        """Answer ``query`` by running the graph to completion.

        Returns:
            A dict with ``"response"`` (content of the final message),
            ``"sources"`` and ``"steps"``, all computed from the final
            conversation.
        """
        state = AgentState(messages=[HumanMessage(content=query)], sender="user")
        # ``invoke`` returns the final accumulated state. Streamed chunks are
        # keyed by node name and carry only per-node updates, which is not
        # enough to build the full result.
        final_state = self.workflow.invoke(state)
        messages = final_state["messages"]
        return {
            "response": messages[-1].content,
            "sources": self._extract_sources(messages),
            "steps": self._extract_steps(messages),
        }

    def _extract_sources(self, messages: Sequence[BaseMessage]) -> list:
        """Return ``"<tool name>: <output>"`` for each tool-result message.

        Includes messages whose ``additional_kwargs`` carry a ``"name"``, and
        tool messages (message ``type == "tool"``). ToolNode is expected to set
        a tool message's ``name`` to the tool that produced it; ``"unknown"``
        is used when it is missing.
        """
        sources = []
        for msg in messages:
            name = getattr(msg, "additional_kwargs", {}).get("name")
            if name is None and getattr(msg, "type", None) == "tool":
                name = getattr(msg, "name", None) or "unknown"
            if name is not None:
                sources.append(f"{name}: {msg.content}")
        return sources

    def _extract_steps(self, messages: Sequence[BaseMessage]) -> list:
        """Return ``"Used <tool>: <arguments>"`` for every tool call made.

        Prefers the raw OpenAI-format calls in ``additional_kwargs["tool_calls"]``
        (arguments already a JSON string). Otherwise falls back to the message's
        ``tool_calls`` attribute, serializing the argument dict as JSON.
        """
        steps = []
        for msg in messages:
            raw_calls = getattr(msg, "additional_kwargs", {}).get("tool_calls")
            if raw_calls:
                for call in raw_calls:
                    function = call["function"]
                    steps.append(f"Used {function['name']}: {function['arguments']}")
            else:
                for call in getattr(msg, "tool_calls", None) or []:
                    steps.append(f"Used {call['name']}: {json.dumps(call['args'])}")
        return steps