mjschock's picture
Refactor agent structure by modularizing agent implementations into separate directories for web, data analysis, and media agents. Remove legacy code from agents.py, prompts.py, and tools.py, enhancing maintainability. Update main_v2.py to reflect new import paths and agent initialization. Add new tools for enhanced functionality, including web searching and data extraction. Update requirements.txt to include necessary dependencies for new tools.
837e221 unverified
raw
history blame
3.15 kB
import importlib
import importlib.resources
import logging
import os

import requests
import yaml
from dotenv import find_dotenv, load_dotenv
from litellm._logging import _disable_debugging
from openinference.instrumentation.smolagents import SmolagentsInstrumentor
from phoenix.otel import register
# from smolagents import CodeAgent, LiteLLMModel, LiteLLMRouterModel
from smolagents import CodeAgent, LiteLLMModel
from smolagents.default_tools import DuckDuckGoSearchTool, VisitWebpageTool
from smolagents.monitoring import LogLevel

from agents.data_agent.agent import create_data_agent
from agents.media_agent.agent import create_media_agent
from agents.web_agent.agent import create_web_agent
from utils import extract_final_answer
# Load variables from the nearest .env file before anything reads the
# environment. The API_* / MODEL_ID lookups below read os.environ, and Phoenix's
# register() can also take its settings from environment variables, so loading
# .env later (as before) left those values invisible to tracing setup.
# load_dotenv does not override variables that are already set.
load_dotenv(find_dotenv())

# Silence LiteLLM's internal debug output.
_disable_debugging()

# Configure OpenTelemetry with Phoenix and instrument smolagents runs.
register()
SmolagentsInstrumentor().instrument()

logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger(__name__)

# LLM connection settings; each is None when the variable is unset.
API_BASE = os.getenv("API_BASE")
API_KEY = os.getenv("API_KEY")
MODEL_ID = os.getenv("MODEL_ID")

# A single model instance, passed to every sub-agent factory below.
model = LiteLLMModel(
    api_base=API_BASE,
    api_key=API_KEY,
    model_id=MODEL_ID,
)

# Specialised sub-agents built from the shared model.
data_agent = create_data_agent(model)
media_agent = create_media_agent(model)
web_agent = create_web_agent(model)
# Start from the stock CodeAgent prompt templates shipped inside the smolagents
# package. The resource is decoded explicitly as UTF-8; Traversable.read_text()
# otherwise falls back to the platform's locale encoding.
prompt_templates = yaml.safe_load(
    importlib.resources.files("smolagents.prompts")
    .joinpath("code_agent.yaml")
    .read_text(encoding="utf-8")
)

# Top-level agent. It currently answers with its own web tools only; the
# data/media/web sub-agents created earlier are NOT attached (managed_agents
# is commented out below).
agent = CodeAgent(
    # add_base_tools=True,
    # Extra modules the agent's generated code is allowed to import.
    additional_authorized_imports=[
        "json",
        "pandas",
        "numpy",
        "re",
        # "requests"
        # "urllib.request",
    ],
    # max_steps=10,
    # To delegate to the sub-agents, enable:
    # managed_agents=[web_agent, data_agent, media_agent],
    model=model,
    prompt_templates=prompt_templates,
    tools=[
        # Presumably keeps the context small: a single search hit and
        # page text truncated to 256 characters.
        DuckDuckGoSearchTool(max_results=1),
        VisitWebpageTool(max_output_length=256),
    ],
    step_callbacks=None,
    verbosity_level=LogLevel.ERROR,
)

# Render the agent's configuration once, at import time.
agent.visualize()
def main(task: str, *, max_steps: int = 3):
    """Run the top-level agent on one task and return its extracted final answer.

    Args:
        task: The question/task text handed to the agent.
        max_steps: Maximum number of agent steps for this run. Keyword-only;
            defaults to 3, the value that was previously hard-coded.

    Returns:
        The value produced by ``extract_final_answer`` from the agent's raw result.
    """
    result = agent.run(
        task=task,
        stream=False,
        # Start fresh for every task rather than continuing a previous run.
        reset=True,
        images=None,
        additional_args=None,
        max_steps=max_steps,
    )
    logger.info("Result: %s", result)
    return extract_final_answer(result)
if __name__ == "__main__":
    DEFAULT_API_URL = "https://agents-course-unit4-scoring.hf.space"
    # How many questions (from the front of the list) to run; 1 = smoke test.
    MAX_QUESTIONS = 1

    api_url = DEFAULT_API_URL
    questions_url = f"{api_url}/questions"

    # Fetch the question list; fail loudly on HTTP errors.
    response = requests.get(questions_url, timeout=15)
    response.raise_for_status()
    questions_data = response.json()

    for question_data in questions_data[:MAX_QUESTIONS]:
        file_name = question_data["file_name"]
        level = question_data["Level"]
        question = question_data["question"]
        task_id = question_data["task_id"]

        logger.info("Question: %s", question)
        logger.debug("Level: %s", level)
        if file_name:
            # NOTE(review): only the question text reaches main(); the attached
            # file is neither downloaded nor handed to the agent.
            logger.info("File Name: %s", file_name)
        logger.debug("Task ID: %s", task_id)

        final_answer = main(question)
        logger.info("Final Answer: %s", final_answer)
        logger.info("--------------------------------")