|
import os |
|
|
|
import PIL.Image |
|
from dotenv import load_dotenv |
|
from loguru import logger |
|
from smolagents import ( |
|
AzureOpenAIServerModel, |
|
CodeAgent, |
|
GoogleSearchTool, |
|
PythonInterpreterTool, |
|
SpeechToTextTool, |
|
VisitWebpageTool, |
|
) |
|
|
|
from src.file_handler.parse import parse_file |
|
from src.tools.reasoning import ReasoningToolkit |
|
from src.tools.reverse_question import reverse_question |
|
from src.tracing import add_tracing |
|
|
|
# Load variables from a local .env file into the process environment before
# Agent.__init__ reads its Azure OpenAI settings with os.getenv().
load_dotenv()
# Project-local hook from src.tracing; presumably enables tracing of agent
# runs -- confirm against that module.
add_tracing()
|
|
|
|
|
class Agent:
    """Callable question-answering agent built on a smolagents ``CodeAgent``.

    Calling an instance formats the question (plus any attached file content
    obtained through ``parse_file``) into ``user_prompt``, runs the agent and
    returns its answer as a string.
    """

    def __init__(self):
        """Create the model, the tool list, the agent and the prompt template.

        Model settings are read from the ``AZURE_OPENAI_MODEL_ID``,
        ``AZURE_OPENAI_ENDPOINT``, ``AZURE_OPENAI_API_KEY`` and
        ``OPENAI_API_VERSION`` environment variables.
        """
        model = AzureOpenAIServerModel(
            model_id=os.getenv("AZURE_OPENAI_MODEL_ID"),
            azure_endpoint=os.getenv("AZURE_OPENAI_ENDPOINT"),
            api_key=os.getenv("AZURE_OPENAI_API_KEY"),
            api_version=os.getenv("OPENAI_API_VERSION"),
        )
        reasoning_toolkit = ReasoningToolkit()
        tools = [
            GoogleSearchTool(provider="serper"),
            VisitWebpageTool(),
            PythonInterpreterTool(),
            SpeechToTextTool(),
            *reasoning_toolkit.tools,
            reverse_question,
        ]
        self.agent = CodeAgent(
            tools=tools,
            model=model,
        )
        # Template filled by __call__ through str.format(); the only
        # placeholders are {question} and {content}.
        self.user_prompt = """
I will ask you a question.
YOUR FINAL ANSWER should be a number OR as few words as possible OR a comma separated list of numbers and/or strings.
If you are asked for a number, don't use comma to write your number neither use units such as $ or percent sign unless specified otherwise.
If you are asked for a string, don't use articles, neither abbreviations (e.g. for cities), and write the digits in plain text unless specified otherwise.
If you are asked for a comma separated list, apply the above rules depending of whether the element to be put in the list is a number or a string.

You MUST use the following tools:
- think, used before all other tool call and before the final answer
- analyze, used after all other tool call and before the final answer

Question: {question}

Attached content: {content}
"""
        logger.info("BasicAgent initialized.")

    def __call__(
        self, question: str, task_id: str, file_name: str, api_url: str
    ) -> str:
        """Answer ``question``, optionally using the file attached to the task.

        Args:
            question: The question text to answer.
            task_id: Task identifier, forwarded to ``parse_file``.
            file_name: Name of the attached file; a falsy value means the
                task has no attachment and ``parse_file`` is not called.
            api_url: Base URL of the task API, forwarded to ``parse_file``.

        Returns:
            The agent's answer converted to ``str``, with any
            ``"FINAL ANSWER:"`` marker removed and surrounding whitespace
            stripped.
        """
        logger.info(
            f"Agent received question (first 50 chars): {question[:50]}..."
        )
        images = None
        content = ""

        if file_name:
            parsed = parse_file(task_id, file_name, api_url)
            if isinstance(parsed, PIL.Image.Image):
                # The image reaches the model through ``images=``. Formatting
                # the object itself into the prompt would only insert its
                # repr (class name and memory address), so a short textual
                # placeholder is used instead.
                images = [parsed]
                content = "[image attached]"
            elif parsed:
                # Only overwrite the empty default with real content, so a
                # None/empty result is not rendered as the text "None".
                content = parsed
                logger.info(f"Question with content: {question}")

        prompt = self.user_prompt.format(question=question, content=content)

        answer = self.agent.run(prompt, images=images)
        answer = str(answer).replace("FINAL ANSWER:", "").strip()
        logger.info(f"Agent returning answer: {answer}")
        return answer
|
|
|
|
|
if __name__ == "__main__":
    # Manual smoke test: fetch one random question from the scoring API and
    # run the agent on it.
    import requests

    api_url = "https://agents-course-unit4-scoring.hf.space"
    question_url = f"{api_url}/random-question"

    # Without a timeout requests can block indefinitely on an unresponsive
    # server; checking the status surfaces HTTP errors directly instead of as
    # a JSON decoding error or KeyError further down.
    response = requests.get(question_url, timeout=30)
    response.raise_for_status()
    data = response.json()
    agent = Agent()

    task_id = data["task_id"]
    question = data["question"]
    file_name = data["file_name"]
    logger.info(
        f"Task ID: {task_id}\nQuestion: {question}\nFile Name: {file_name}\n\n"
    )

    # The answer is logged inside Agent.__call__.
    answer = agent(question, task_id, file_name, api_url)
|
|