File size: 3,562 Bytes
912f746
 
 
 
 
fb8728b
 
 
 
 
652eb00
fb8728b
 
912f746
 
652eb00
 
 
912f746
 
652eb00
912f746
 
0866aba
 
912f746
 
 
 
 
 
652eb00
fb8728b
 
 
 
652eb00
 
fb8728b
 
912f746
fb8728b
912f746
 
fb8728b
 
 
 
 
 
 
652eb00
 
 
 
fb8728b
 
 
 
912f746
 
4a8d3f6
 
 
912f746
 
 
 
652eb00
912f746
 
 
 
 
 
 
 
 
 
652eb00
 
912f746
fb8728b
652eb00
912f746
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
652eb00
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
import os

import PIL.Image
from dotenv import load_dotenv
from loguru import logger
from smolagents import (
    AzureOpenAIServerModel,
    CodeAgent,
    GoogleSearchTool,
    PythonInterpreterTool,
    SpeechToTextTool,
    VisitWebpageTool,
)

from src.file_handler.parse import parse_file
from src.tools.reasoning import ReasoningToolkit
from src.tools.reverse_question import reverse_question
from src.tracing import add_tracing

# Import-time setup. Load settings from a local .env file into os.environ
# first, because Agent.__init__ reads its Azure OpenAI configuration via
# os.getenv. Then enable tracing; the setup lives in src.tracing.add_tracing.
# NOTE(review): order kept as-is in case add_tracing reads env vars — confirm.
load_dotenv()
add_tracing()


class Agent:
    """Question-answering agent wrapping a smolagents ``CodeAgent``.

    The underlying model is an Azure OpenAI deployment configured through
    environment variables. The agent is equipped with web search, page
    visiting, a Python interpreter, speech-to-text, the reasoning toolkit's
    tools and ``reverse_question``.
    """

    def __init__(self):
        """Build the model, the tool list and the prompt template."""
        # All Azure OpenAI settings come from the environment (populated from
        # .env by load_dotenv at import time). Missing variables are passed
        # through as None.
        model = AzureOpenAIServerModel(
            model_id=os.getenv("AZURE_OPENAI_MODEL_ID"),
            azure_endpoint=os.getenv("AZURE_OPENAI_ENDPOINT"),
            api_key=os.getenv("AZURE_OPENAI_API_KEY"),
            api_version=os.getenv("OPENAI_API_VERSION"),
        )
        reasoning_toolkit = ReasoningToolkit()
        tools = [
            GoogleSearchTool(provider="serper"),
            VisitWebpageTool(),
            PythonInterpreterTool(),
            SpeechToTextTool(),
            *reasoning_toolkit.tools,
            reverse_question,
        ]
        self.agent = CodeAgent(
            tools=tools,
            model=model,
        )
        # Template filled with str.format in __call__. Only {question} and
        # {content} are placeholders, so literal braces must not be added here.
        self.user_prompt = """
        I will ask you a question.
        YOUR FINAL ANSWER should be a number OR as few words as possible OR a comma separated list of numbers and/or strings.
        If you are asked for a number, don't use comma to write your number neither use units such as $ or percent sign unless specified otherwise.
        If you are asked for a string, don't use articles, neither abbreviations (e.g. for cities), and write the digits in plain text unless specified otherwise.
        If you are asked for a comma separated list, apply the above rules depending of whether the element to be put in the list is a number or a string.

        You MUST use the following tools:
        - think, used before all other tool call and before the final answer
        - analyze, used after all other tool call and before the final answer

        Question: {question}

        Attached content: {content}
        """
        logger.info("BasicAgent initialized.")

    def __call__(
        self, question: str, task_id: str, file_name: str, api_url: str
    ) -> str:
        """Answer ``question``, optionally using the task's attached file.

        Args:
            question: The question text.
            task_id: Task identifier, forwarded to ``parse_file``.
            file_name: Name of the attached file; falsy when the task has none.
            api_url: Base URL of the scoring API, forwarded to ``parse_file``.

        Returns:
            The agent's answer as a string, with every ``"FINAL ANSWER:"``
            marker removed and surrounding whitespace stripped.
        """
        logger.info(
            f"Agent received question (first 50 chars): {question[:50]}..."
        )
        images = None
        content = ""

        if file_name:
            parsed = parse_file(task_id, file_name, api_url)
            if isinstance(parsed, PIL.Image.Image):
                # Images reach the model through ``images``. Interpolating the
                # object into the template would only put its repr in the text.
                images = [parsed]
                content = "(image provided separately)"
            elif parsed:
                # Non-image content is inlined into the prompt template.
                content = parsed
                logger.info(
                    f"Attached content (first 50 chars): {str(parsed)[:50]}..."
                )
            # Otherwise parse_file produced nothing usable (e.g. None or an
            # empty value). Keep "" so the prompt never reads "None".

        prompt = self.user_prompt.format(question=question, content=content)

        answer = self.agent.run(prompt, images=images)
        # The model may echo the "FINAL ANSWER:" marker; drop it.
        answer = str(answer).replace("FINAL ANSWER:", "").strip()
        logger.info(f"Agent returning answer: {answer}")
        return answer


if __name__ == "__main__":
    # Manual smoke test: fetch one random question from the scoring API and
    # run the agent on it.
    import requests

    api_url = "https://agents-course-unit4-scoring.hf.space"
    question_url = f"{api_url}/random-question"

    # requests applies no timeout by default, so a stalled server would hang
    # the script forever. raise_for_status() reports HTTP errors immediately
    # instead of failing later on .json() or a missing key.
    response = requests.get(question_url, timeout=30)
    response.raise_for_status()
    data = response.json()
    agent = Agent()

    task_id = data["task_id"]
    question = data["question"]
    file_name = data["file_name"]
    logger.info(
        f"Task ID: {task_id}\nQuestion: {question}\nFile Name: {file_name}\n\n"
    )

    answer = agent(question, task_id, file_name, api_url)