File size: 3,879 Bytes
a60b872
 
 
 
 
 
 
 
 
 
 
a3a55bb
a60b872
 
a3a55bb
a60b872
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2579190
 
 
a60b872
2579190
 
 
 
 
 
 
a3a55bb
a60b872
 
 
2579190
 
 
 
 
 
 
 
a60b872
 
2579190
a60b872
 
 
a3a55bb
 
a60b872
 
 
 
 
 
 
 
 
2aa9dc2
2579190
 
 
 
 
 
 
 
 
 
 
 
 
2aa9dc2
a60b872
 
 
2579190
a60b872
 
2579190
 
a60b872
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
import logging
import os

import requests
import yaml
from dotenv import find_dotenv, load_dotenv
from litellm._logging import _disable_debugging
from openinference.instrumentation.smolagents import SmolagentsInstrumentor
from phoenix.otel import register

# from smolagents import CodeAgent, LiteLLMModel, LiteLLMRouterModel
from smolagents import CodeAgent, LiteLLMModel
from smolagents.monitoring import LogLevel

from tools.smart_search.tool import SmartSearchTool
from utils import extract_final_answer

# Module-level setup. Everything below runs at import time, in this order.

# NOTE(review): private LiteLLM helper; presumably silences LiteLLM's debug
# output — confirm against the installed litellm version.
_disable_debugging()

# Configure OpenTelemetry with Phoenix
register()
# Instrument smolagents before any agent is created further down.
SmolagentsInstrumentor().instrument()

# Root logging configuration: INFO and above, timestamped, with logger name.
logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger(__name__)

# Load variables from the nearest .env file (if one is found) so the
# os.getenv calls below can see them.
load_dotenv(find_dotenv())

# Model connection settings; each is None when the variable is not set.
API_BASE = os.getenv("API_BASE")
API_KEY = os.getenv("API_KEY")
MODEL_ID = os.getenv("MODEL_ID")

# Shared LiteLLM-backed model handed to the agent below.
model = LiteLLMModel(
    api_base=API_BASE,
    api_key=API_KEY,
    model_id=MODEL_ID,
)

# data_agent = create_data_agent(model)
# media_agent = create_media_agent(model)
# web_agent = create_web_agent(model)

# search_agent = ToolCallingAgent(
#     tools=[DuckDuckGoSearchTool(), VisitWebpageTool()],
#     model=model,
#     name="search_agent",
#     description="This is an agent that can do web search.",
# )

# Load the agent's prompt templates. The context manager closes the file
# handle as soon as parsing finishes, and the explicit encoding keeps the
# result independent of the platform's locale default.
with open("prompts/code_agent_modified.yaml", "r", encoding="utf-8") as _prompt_file:
    prompt_templates = yaml.safe_load(_prompt_file)

# Build the single CodeAgent used by main(): it runs on `model`, uses the
# prompt templates loaded from YAML above, and has SmartSearchTool as its only
# enabled tool. The commented-out keyword arguments are disabled options kept
# for reference.
agent = CodeAgent(
    # add_base_tools=True,
    # additional_authorized_imports=[
    #     "json",
    #     "pandas",
    #     "numpy",
    #     "re",
    #     # "requests"
    #     # "urllib.request",
    # ],
    # max_steps=10,
    # managed_agents=[web_agent, data_agent, media_agent],
    # managed_agents=[search_agent],
    model=model,
    prompt_templates=prompt_templates,
    tools=[
        SmartSearchTool(),
        # VisitWebpageTool(max_output_length=1024),
    ],
    step_callbacks=None,
    # NOTE(review): presumably restricts the agent's own console output to
    # errors — confirm against smolagents' LogLevel documentation.
    verbosity_level=LogLevel.ERROR,
)

# Runs at import time. NOTE(review): presumably prints an overview of the
# agent's structure — confirm against the smolagents documentation.
agent.visualize()


def main(task: str, *, max_steps: int = 3):
    """Run the module-level ``agent`` on one task and return its final answer.

    Args:
        task: Question text, passed to ``agent.run`` unchanged.
        max_steps: Value forwarded as ``max_steps`` to ``agent.run``.
            Defaults to 3.

    Returns:
        Whatever ``extract_final_answer`` derives from the raw run result.
    """
    # Disabled alternative: format the task with GAIA-style instructions and
    # pass it as ``task=gaia_task`` instead of the raw task.
    # gaia_task = f"""Instructions:
    # 1. Your response must contain ONLY the answer to the question, nothing else
    # 2. Do not repeat the question or any part of it
    # 3. Do not include any explanations, reasoning, or context
    # 4. Do not include source attribution or references
    # 5. Do not use phrases like "The answer is" or "I found that"
    # 6. Do not include any formatting, bullet points, or line breaks
    # 7. If the answer is a number, return only the number
    # 8. If the answer requires multiple items, separate them with commas
    # 9. If the answer requires ordering, maintain the specified order
    # 10. Use the most direct and succinct form possible
    #
    # {task}"""

    result = agent.run(
        task=task,
        # task=gaia_task,
        additional_args=None,
        images=None,
        max_steps=max_steps,
        reset=True,  # each call starts a fresh run
        stream=False,
    )

    # Lazy %-style arguments: the message is only formatted if it is emitted.
    logger.info("Result: %s", result)

    return extract_final_answer(result)


if __name__ == "__main__":
    # Script entry point: fetch the question list from the scoring service and
    # run the agent on the first question, logging the answer.
    DEFAULT_API_URL = "https://agents-course-unit4-scoring.hf.space"

    api_url = DEFAULT_API_URL
    questions_url = f"{api_url}/questions"

    # Bounded wait so the script cannot hang on the request; HTTP error
    # statuses raise instead of being parsed as JSON.
    response = requests.get(questions_url, timeout=15)
    response.raise_for_status()
    questions_data = response.json()

    # Only the first question is processed.
    for question_data in questions_data[:1]:
        question = question_data["question"]
        # The remaining fields are only logged, so a missing key must not
        # abort the run.
        file_name = question_data.get("file_name")

        logger.info("Question: %s", question)
        # DEBUG records are not emitted under the INFO level configured above.
        logger.debug("Level: %s", question_data.get("Level"))
        if file_name:
            logger.info("File Name: %s", file_name)
        logger.debug("Task ID: %s", question_data.get("task_id"))

        final_answer = main(question)
        logger.info("Final Answer: %s", final_answer)
        logger.info("--------------------------------")