Spaces:
Build error
Build error
Refactor main_v2.py to remove unused tools and agents, replacing them with the new SmartSearchTool for improved search functionality. Update prompt template loading to use the modified YAML file. Clean up imports and enhance overall code organization for better maintainability.
a3a55bb
unverified
import logging
import os
from pathlib import Path

import requests
import yaml
from dotenv import find_dotenv, load_dotenv
from litellm._logging import _disable_debugging
from openinference.instrumentation.smolagents import SmolagentsInstrumentor
from phoenix.otel import register

# from smolagents import CodeAgent, LiteLLMModel, LiteLLMRouterModel
from smolagents import CodeAgent, LiteLLMModel
from smolagents.monitoring import LogLevel

from tools.smart_search.tool import SmartSearchTool
from utils import extract_final_answer
# Load .env before anything that reads configuration from the environment.
# phoenix.otel.register() below takes its settings (e.g. collector endpoint,
# project name) from PHOENIX_* environment variables. Loading .env after it
# would silently ignore any Phoenix settings kept in that file.
load_dotenv(find_dotenv())

# Turn off LiteLLM's internal debug logging.
_disable_debugging()

# Configure OpenTelemetry tracing with Phoenix, then instrument smolagents so
# agent runs are exported as traces.
register()
SmolagentsInstrumentor().instrument()

logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger(__name__)

# Model connection settings, read from the environment (or .env).
# Each is None when the variable is unset; they are forwarded to LiteLLMModel
# unchanged.
API_BASE = os.getenv("API_BASE")
API_KEY = os.getenv("API_KEY")
MODEL_ID = os.getenv("MODEL_ID")

model = LiteLLMModel(
    api_base=API_BASE,
    api_key=API_KEY,
    model_id=MODEL_ID,
)
# The original relative path "prompts/..." only resolved when the script was
# started from its own directory. Resolving from __file__ finds the same file
# regardless of the current working directory.
PROMPTS_PATH = (
    Path(__file__).resolve().parent / "prompts" / "code_agent_modified.yaml"
)

# Use a context manager so the file handle is closed after parsing (the
# previous bare open() leaked it), and an explicit encoding so parsing does not
# depend on the platform's default locale.
with PROMPTS_PATH.open("r", encoding="utf-8") as prompts_file:
    # safe_load: the template file is plain data and must not construct objects.
    prompt_templates = yaml.safe_load(prompts_file)

# Single code agent whose only tool is SmartSearchTool. Output verbosity is
# reduced to errors; tracing goes through the instrumentation set up earlier.
agent = CodeAgent(
    model=model,
    prompt_templates=prompt_templates,
    tools=[SmartSearchTool()],
    step_callbacks=None,
    verbosity_level=LogLevel.ERROR,
)
# Render a summary of the configured agent once at startup.
agent.visualize()
def main(task: str, max_steps: int = 3):
    """Run the module-level code agent on one task and return its final answer.

    Args:
        task: The question or instruction, passed verbatim to the agent.
        max_steps: Maximum number of agent steps for this run. Defaults to 3,
            the value that was previously hard-coded here.

    Returns:
        The value produced by ``extract_final_answer`` from the agent's raw
        result.
    """
    result = agent.run(
        task=task,
        additional_args=None,
        images=None,
        max_steps=max_steps,
        reset=True,  # start each task from a fresh agent memory
        stream=False,  # block until the run finishes and return its result
    )
    # Lazy %-formatting: the message is only built if INFO is enabled.
    logger.info("Result: %s", result)
    return extract_final_answer(result)
if __name__ == "__main__":
    DEFAULT_API_URL = "https://agents-course-unit4-scoring.hf.space"
    # Number of questions to run from the fetched list (the original
    # hard-coded slice). Raise this to evaluate more of the set.
    MAX_QUESTIONS = 1

    api_url = DEFAULT_API_URL
    questions_url = f"{api_url}/questions"

    response = requests.get(questions_url, timeout=15)
    response.raise_for_status()
    questions_data = response.json()

    for question_data in questions_data[:MAX_QUESTIONS]:
        question = question_data["question"]
        # The original only acts on a truthy file_name, so it may be empty.
        # Use .get so a missing key is treated the same way.
        file_name = question_data.get("file_name")

        logger.info("Question: %s", question)
        # Metadata fields are diagnostic only. Read them with .get so a missing
        # key cannot abort the run.
        logger.debug("Level: %s", question_data.get("Level"))
        logger.debug("Task ID: %s", question_data.get("task_id"))
        if file_name:
            # NOTE(review): the attachment is only logged; it is not downloaded
            # or passed to the agent, so file-based questions get no file input.
            logger.info("File Name: %s", file_name)

        final_answer = main(question)
        logger.info("Final Answer: %s", final_answer)
        logger.info("--------------------------------")