mjschock's picture
Refactor main_v2.py to remove unused tools and agents, replacing them with the new SmartSearchTool for improved search functionality. Update prompt template loading to use the modified YAML file. Clean up imports and enhance overall code organization for better maintainability.
a3a55bb unverified
raw
history blame
3.88 kB
import logging
import os
import requests
import yaml
from dotenv import find_dotenv, load_dotenv
from litellm._logging import _disable_debugging
from openinference.instrumentation.smolagents import SmolagentsInstrumentor
from phoenix.otel import register
# from smolagents import CodeAgent, LiteLLMModel, LiteLLMRouterModel
from smolagents import CodeAgent, LiteLLMModel
from smolagents.monitoring import LogLevel
from tools.smart_search.tool import SmartSearchTool
from utils import extract_final_answer
# Turn off LiteLLM's internal debugging output.
_disable_debugging()

# Tracing: configure OpenTelemetry with Phoenix first, then attach the
# OpenInference instrumentor for smolagents.
register()
SmolagentsInstrumentor().instrument()

# Root logging configuration shared by this module's logger.
logging.basicConfig(
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)
logger = logging.getLogger(__name__)
# Load variables from the .env file located by find_dotenv() into the
# process environment before reading them.
load_dotenv(find_dotenv())

# LLM endpoint settings; each is None when the variable is not set.
API_BASE, API_KEY, MODEL_ID = (
    os.getenv("API_BASE"),
    os.getenv("API_KEY"),
    os.getenv("MODEL_ID"),
)

# LiteLLM-backed model built from the environment settings above.
model = LiteLLMModel(
    model_id=MODEL_ID,
    api_key=API_KEY,
    api_base=API_BASE,
)
# Prompt templates for the CodeAgent. The path is resolved relative to the
# current working directory. The `with` block closes the file as soon as it
# is parsed instead of leaving the handle open until garbage collection.
with open("prompts/code_agent_modified.yaml", "r", encoding="utf-8") as prompt_file:
    prompt_templates = yaml.safe_load(prompt_file)

agent = CodeAgent(
    # Optional settings currently left disabled:
    # add_base_tools=True,
    # additional_authorized_imports=["json", "pandas", "numpy", "re"],
    # max_steps=10,
    model=model,
    prompt_templates=prompt_templates,
    # SmartSearchTool is the only tool exposed to the agent.
    tools=[SmartSearchTool()],
    step_callbacks=None,
    verbosity_level=LogLevel.ERROR,
)

# Display the agent's configuration via smolagents' visualize(); note this
# runs at import time.
agent.visualize()
def main(task: str, max_steps: int = 3):
    """Run the module-level agent on one task and return its final answer.

    Args:
        task: The question text, passed unchanged to ``agent.run``.
        max_steps: Maximum number of agent steps for this run. Defaults to 3.

    Returns:
        The value ``extract_final_answer`` derives from the agent's result.
    """
    # Alternative: wrap the task in GAIA-style answer-format instructions and
    # pass ``task=gaia_task`` to agent.run instead of ``task=task``.
    # gaia_task = f"""Instructions:
    # 1. Your response must contain ONLY the answer to the question, nothing else
    # 2. Do not repeat the question or any part of it
    # 3. Do not include any explanations, reasoning, or context
    # 4. Do not include source attribution or references
    # 5. Do not use phrases like "The answer is" or "I found that"
    # 6. Do not include any formatting, bullet points, or line breaks
    # 7. If the answer is a number, return only the number
    # 8. If the answer requires multiple items, separate them with commas
    # 9. If the answer requires ordering, maintain the specified order
    # 10. Use the most direct and succinct form possible
    # {task}"""
    result = agent.run(
        task=task,
        # task=gaia_task,
        additional_args=None,
        images=None,
        max_steps=max_steps,
        # Reset the agent's conversation each call so one question's history
        # does not carry into the next (see smolagents run() `reset`).
        reset=True,
        stream=False,
    )
    logger.info("Result: %s", result)
    return extract_final_answer(result)
if __name__ == "__main__":
    DEFAULT_API_URL = "https://agents-course-unit4-scoring.hf.space"
    # Number of questions (from the start of the list) to answer in this run.
    QUESTION_LIMIT = 1

    api_url = DEFAULT_API_URL
    questions_url = f"{api_url}/questions"

    response = requests.get(questions_url, timeout=15)
    response.raise_for_status()
    questions_data = response.json()

    for question_data in questions_data[:QUESTION_LIMIT]:
        question = question_data["question"]
        # file_name is only logged when present and non-empty, so a missing
        # key is treated the same as an empty one.
        file_name = question_data.get("file_name")

        logger.info("Question: %s", question)
        if file_name:
            logger.info("File Name: %s", file_name)

        final_answer = main(question)
        logger.info("Final Answer: %s", final_answer)
        logger.info("--------------------------------")