# Exported from the Hugging Face file viewer (author: mjschock, commit bc3bc22
# unverified, 2.8 kB). Commit message: "Enhance graph.py to load prompt
# templates from local file and configure LiteLLM; add new prompt files for
# tool calling and code agent"
import logging
from typing import TypedDict
from langgraph.graph import StateGraph, END
from smolagents import ToolCallingAgent, LiteLLMModel
from tools import tools
import yaml
import os
import litellm
# Module-wide logging: INFO level with a per-module named logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Tell LiteLLM to silently drop request parameters the target provider
# does not support, instead of raising an error.
litellm.drop_params = True
# Define the state for our agent graph
class AgentState(TypedDict):
    """Shared state threaded through the agent graph's nodes."""

    # Conversation history; not read by AgentNode in this file — presumably
    # reserved for callers or future nodes (verify against graph users).
    messages: list
    # The user question the agent node should answer.
    question: str
    # The agent's answer; None until the agent node has run.
    answer: str | None
class AgentNode:
    """LangGraph node that answers a question with a smolagents ToolCallingAgent.

    On construction, loads the agent's prompt templates from
    ``prompts/toolcalling_agent.yaml`` located next to this file and wires up
    a LiteLLM-backed model. Calling the instance runs the agent on
    ``state["question"]`` and writes the result into ``state["answer"]``.
    """

    def __init__(self):
        # Load the default prompt templates that ship alongside this module.
        # Explicit UTF-8: the platform default encoding is not guaranteed to
        # decode the YAML file correctly on every OS.
        current_dir = os.path.dirname(os.path.abspath(__file__))
        yaml_path = os.path.join(current_dir, "prompts", "toolcalling_agent.yaml")
        with open(yaml_path, "r", encoding="utf-8") as f:
            prompt_templates = yaml.safe_load(f)

        # Log the system prompt actually in use, fenced for readability.
        logger.info("Default system prompt:")
        logger.info("-" * 80)
        logger.info(prompt_templates["system_prompt"])
        logger.info("-" * 80)

        # Initialize the model and agent.
        # NOTE(review): api_base assumes a local Ollama server on the default
        # port — confirm before deploying elsewhere.
        self.model = LiteLLMModel(
            api_base="http://localhost:11434",
            api_key=None,
            model_id="ollama/codellama",
        )
        self.agent = ToolCallingAgent(
            model=self.model,
            prompt_templates=prompt_templates,
            tools=tools,
        )

    def __call__(self, state: AgentState) -> AgentState:
        """Run the agent on ``state["question"]`` and return the updated state.

        Errors are captured into ``state["answer"]`` rather than re-raised, so
        the graph always receives a well-formed state even on failure.
        """
        try:
            # Process the question through the agent and record the answer.
            result = self.agent.run(state["question"])
            state["answer"] = result
        except Exception as e:
            # logger.exception records the full traceback; lazy %-formatting
            # defers message construction until the record is actually emitted.
            logger.exception("Error in agent node: %s", e)
            state["answer"] = f"Error: {str(e)}"
        return state
def build_agent_graph():
    """Compile and return the single-node LangGraph workflow.

    The workflow contains one "agent" node (an AgentNode instance) that serves
    as the entry point and runs straight into END.
    """
    workflow = StateGraph(AgentState)
    workflow.add_node("agent", AgentNode())
    workflow.set_entry_point("agent")
    workflow.add_edge("agent", END)
    return workflow.compile()
# Module-level singleton: building the graph here means that merely importing
# this module constructs the agent (model, prompt loading) as a side effect.
agent_graph = build_agent_graph()