Test_Magus / agent.py
SergeyO7's picture
Update agent.py
117cd9e verified
raw
history blame
17 kB
from smolagents import CodeAgent, LiteLLMModel, tool, Tool, load_tool, DuckDuckGoSearchTool, WikipediaSearchTool #, HfApiModel, OpenAIServerModel
import asyncio
import os
import re
import pandas as pd
from typing import Optional
from token_bucket import Limiter, MemoryStorage
import yaml
from PIL import Image, ImageOps
import requests
from io import BytesIO
from markdownify import markdownify
import whisper
import time
import shutil
import traceback
@tool
def GoogleSearchTool(query: str) -> str:
    """Tool for performing Google searches using Custom Search JSON API
    Args:
        query (str): Search query string
    Returns:
        str: Formatted search results
    """
    # Credentials are read on every call; previously api_key was never
    # assigned, so every invocation failed with a NameError.
    api_key = os.environ.get("GOOGLE_API_KEY")
    cse_id = os.environ.get("GOOGLE_CSE_ID")
    if not api_key or not cse_id:
        raise ValueError("GOOGLE_API_KEY and GOOGLE_CSE_ID must be set in environment variables.")
    url = "https://www.googleapis.com/customsearch/v1"
    params = {
        "key": api_key,
        "cx": cse_id,
        "q": query,
        "num": 5,  # Number of results to return
    }
    try:
        # A timeout keeps a stalled connection from hanging the agent step.
        response = requests.get(url, params=params, timeout=15)
        response.raise_for_status()
        results = response.json().get("items", [])
        # .get() tolerates result items that lack a title or link.
        lines = [f"{item.get('title', '')}: {item.get('link', '')}" for item in results]
        return "\n".join(lines) or "No results found."
    except Exception as e:
        return f"Error performing Google search: {str(e)}"
class VisitWebpageTool(Tool):
    """Fetch a URL and return its body converted to markdown (truncated to 10000 chars)."""

    name = "visit_webpage"
    description = "Visits a webpage at the given url and reads its content as a markdown string. Use this to browse webpages."
    inputs = {'url': {'type': 'string', 'description': 'The url of the webpage to visit.'}}
    output_type = "string"

    def __init__(self, *args, **kwargs):
        # Delegate to Tool.__init__ so base-class initialization is not skipped.
        super().__init__(*args, **kwargs)
        self.is_initialized = False

    def forward(self, url: str) -> str:
        """Download ``url`` and return its content as markdown.

        Args:
            url: Address of the page to fetch.

        Returns:
            The page as markdown with runs of 3+ newlines collapsed, truncated
            to 10000 characters, or an error string if the fetch fails.
        """
        # Imported outside the try so a missing helper surfaces as an import
        # error instead of being reported as a failed page fetch.
        from smolagents.utils import truncate_content

        try:
            response = requests.get(url, timeout=20)
            response.raise_for_status()
            markdown_content = markdownify(response.text).strip()
            # Collapse long runs of blank lines left over from HTML layout.
            markdown_content = re.sub(r"\n{3,}", "\n\n", markdown_content)
            return truncate_content(markdown_content, 10000)
        except requests.exceptions.Timeout:
            return "The request timed out. Please try again later or check the URL."
        except requests.exceptions.RequestException as e:
            return f"Error fetching the webpage: {str(e)}"
        except Exception as e:
            return f"An unexpected error occurred: {str(e)}"
class DownloadTaskAttachmentTool(Tool):
    """Download a task's attachment from the scoring API into ``downloads/``."""

    name = "download_file"
    description = "Downloads the file attached to the task ID and returns the local file path. Supports Excel (.xlsx), image (.png, .jpg), audio (.mp3), PDF (.pdf), and Python (.py) files."
    inputs = {'task_id': {'type': 'string', 'description': 'The task id to download attachment from.'}}
    output_type = "string"
    DEFAULT_API_URL = "https://agents-course-unit4-scoring.hf.space"
    # (Content-Type substring, saved-file extension) pairs, checked in order.
    _CONTENT_TYPE_EXTENSIONS = (
        ('image/png', '.png'),
        ('image/jpeg', '.jpg'),
        ('application/vnd.openxmlformats-officedocument.spreadsheetml.sheet', '.xlsx'),
        ('audio/mpeg', '.mp3'),
        ('application/pdf', '.pdf'),
        ('text/x-python', '.py'),
    )
    # Extensions accepted from a Content-Disposition filename when the
    # Content-Type is not recognized (e.g. a generic application/octet-stream).
    _FALLBACK_EXTENSIONS = frozenset({'.png', '.jpg', '.jpeg', '.xlsx', '.mp3', '.pdf', '.py'})

    def __init__(self, rate_limiter: Optional[Limiter] = None, default_api_url: str = DEFAULT_API_URL, *args, **kwargs):
        """Create the tool.

        Args:
            rate_limiter: Optional token-bucket limiter consulted before each download.
            default_api_url: Base URL of the scoring API serving ``/files/{task_id}``.
        """
        super().__init__(*args, **kwargs)
        self.is_initialized = False
        self.rate_limiter = rate_limiter
        self.default_api_url = default_api_url

    @classmethod
    def _pick_extension(cls, content_type: str, disposition: str) -> Optional[str]:
        """Return the extension to save the file with, or None if unsupported."""
        for marker, extension in cls._CONTENT_TYPE_EXTENSIONS:
            if marker in content_type:
                return extension
        match = re.search(r'filename\*?=(?:[\w-]+\'\')?"?([^";]+)', disposition, re.IGNORECASE)
        if match:
            extension = os.path.splitext(match.group(1).strip())[1].lower()
            if extension in cls._FALLBACK_EXTENSIONS:
                return extension
        return None

    def forward(self, task_id: str) -> str:
        """Download the attachment for ``task_id``.

        Returns:
            The local path ``downloads/<task_id><ext>`` on success, otherwise an
            error string (unsupported type, HTTP/network failure, or save failure).
        """
        file_url = f"{self.default_api_url}/files/{task_id}"
        print(f"Downloading file for task ID {task_id} from {file_url}...")
        try:
            if self.rate_limiter:
                # NOTE(review): 1 is passed as the first argument of consume();
                # in token_bucket that is the bucket key — confirm intent.
                while not self.rate_limiter.consume(1):
                    print(f"Rate limit reached for downloading file for task {task_id}. Waiting...")
                    time.sleep(60 / 15)  # Assuming 15 RPM
            # The context manager releases the streamed connection on every
            # path, including the unsupported-type early return below.
            with requests.get(file_url, stream=True, timeout=15) as response:
                response.raise_for_status()
                content_type = response.headers.get('Content-Type', '').lower()
                extension = self._pick_extension(
                    content_type, response.headers.get('Content-Disposition', '')
                )
                if extension is None:
                    return f"Error: Unsupported file type {content_type} for task {task_id}. Try using visit_webpage or web_search if the content is online."
                local_file_path = f"downloads/{task_id}{extension}"
                os.makedirs("downloads", exist_ok=True)
                with open(local_file_path, "wb") as file:
                    for chunk in response.iter_content(chunk_size=8192):
                        file.write(chunk)
            print(f"File downloaded successfully: {local_file_path}")
            return local_file_path
        except requests.exceptions.HTTPError as e:
            if e.response is not None and e.response.status_code == 429:
                return f"Error: Rate limit exceeded for task {task_id}. Try again later."
            return f"Error downloading file for task {task_id}: {str(e)}"
        except requests.exceptions.RequestException as e:
            return f"Error downloading file for task {task_id}: {str(e)}"
        except OSError as e:
            # Local filesystem failure (directory creation or write); listed after
            # RequestException because requests' exceptions subclass OSError.
            return f"Error saving file for task {task_id}: {str(e)}"
class SpeechToTextTool(Tool):
    """Transcribe a local audio file with the OpenAI Whisper ``base`` model."""

    name = "speech_to_text"
    description = (
        "Converts an audio file to text using OpenAI Whisper."
    )
    inputs = {
        "audio_path": {"type": "string", "description": "Path to audio file (.mp3, .wav)"},
    }
    output_type = "string"

    def __init__(self):
        super().__init__()
        # Loaded lazily in setup() so constructing the agent does not pay the
        # Whisper model load cost unless an audio task actually needs it.
        self.model = None

    def setup(self):
        """Load the Whisper model once and mark the tool as initialized."""
        if self.model is None:
            self.model = whisper.load_model("base")
        self.is_initialized = True

    def forward(self, audio_path: str) -> str:
        """Return the transcript of ``audio_path``, or an error string if it is missing."""
        if not os.path.exists(audio_path):
            return f"Error: File not found at {audio_path}"
        if self.model is None:
            # forward() may be called directly, bypassing the Tool call path
            # that normally runs setup() first.
            self.setup()
        result = self.model.transcribe(audio_path)
        # Transcripts can carry leading/trailing whitespace; strip for clean answers.
        return result.get("text", "").strip()
class ExcelReaderTool(Tool):
    """Summarize one sheet of an Excel workbook: shape, column dtypes, 5-row preview."""

    name = "excel_reader"
    description = (
        "\n"
        "This tool reads and processes Excel files (.xlsx, .xls).\n"
        "It can extract data, calculate statistics, and perform data analysis on spreadsheets.\n"
    )
    inputs = {
        "excel_path": {
            "type": "string",
            "description": "The path to the Excel file to read",
        },
        "sheet_name": {
            "type": "string",
            "description": "The name of the sheet to read (optional, defaults to first sheet)",
            "nullable": True,
        },
    }
    output_type = "string"

    def forward(self, excel_path: str, sheet_name: str = None) -> str:
        """Read ``excel_path`` and return a text summary of the sheet.

        Args:
            excel_path: Path to the workbook on disk.
            sheet_name: Sheet to read; a falsy value reads the first sheet.

        Returns:
            The summary text, or an ``"Error ..."`` string on any failure.
        """
        try:
            if not os.path.exists(excel_path):
                return f"Error: Excel file not found at {excel_path}"
            # Uses the module-level pandas import; omitting sheet_name lets
            # pandas read the first sheet.
            if sheet_name:
                df = pd.read_excel(excel_path, sheet_name=sheet_name)
            else:
                df = pd.read_excel(excel_path)
            rows, cols = df.shape
            dtypes = df.dtypes.to_dict()
            lines = [
                f"Excel file: {excel_path}",
                f"Shape: {rows} rows × {cols} columns",
                "",
                "Columns:",
            ]
            lines.extend(f"- {col} ({dtypes.get(col)})" for col in df.columns)
            lines.extend(["", "Preview (first 5 rows):"])
            return "\n".join(lines) + "\n" + df.head(5).to_string()
        except Exception as e:
            return f"Error reading Excel file: {str(e)}"
# EasyOCR reader shared across calls; constructing one loads its models, so it
# is created once on first use instead of on every image.
_EASYOCR_READER = None


def _get_easyocr_reader():
    """Return the shared English EasyOCR reader, creating it on first use."""
    global _EASYOCR_READER
    if _EASYOCR_READER is None:
        import easyocr
        _EASYOCR_READER = easyocr.Reader(['en'])
    return _EASYOCR_READER


@tool
def PNG2FENTool(png_file: str) -> str:
    """Tool for converting a PNG file containing a chess board to a FEN position string.
    Args:
        png_file (str): The path to the PNG file.
    Returns:
        str: The FEN position string representing the chess board.
    """
    # NOTE(review): this OCRs text in the image, so it can only succeed when the
    # image literally contains FEN text, not when it only shows a drawn board.
    # Every failure (missing file, OCR error, no FEN found) surfaces as ValueError.
    try:
        with Image.open(png_file) as raw:
            img = ImageOps.exif_transpose(raw).convert("L")
        # Upscale 2x to help OCR on small text.
        img = img.resize((img.width * 2, img.height * 2), Image.Resampling.LANCZOS)
        # OCR the preprocessed pixels directly; previously the preprocessed image
        # was written to a temp file but the original file was OCR'd instead.
        import numpy as np
        texts = _get_easyocr_reader().readtext(np.asarray(img), detail=0)
        for text in texts:
            tokens = text.split()
            # Try the piece-placement field of a full FEN, then the text with
            # any OCR-inserted whitespace removed.
            for candidate in (tokens[0] if tokens else "", "".join(tokens)):
                if candidate and validate_fen_format(candidate):
                    return candidate
        raise ValueError("No valid FEN found in image")
    except Exception as e:
        raise ValueError(f"OCR processing failed: {str(e)}") from e
def process_text_to_fen(text):
    """
    Normalize OCR-extracted text into a FEN piece-placement string.

    Parameters:
    - text: str
        Raw text recognized in the board image.

    Returns:
    - str:
        The text with surrounding whitespace trimmed and every newline and
        space character removed.

    Raises:
    - ValueError:
        If the normalized text is rejected by validate_fen_format.
    """
    candidate = text.strip()
    for junk in ("\n", " "):
        candidate = candidate.replace(junk, "")
    if validate_fen_format(candidate):
        return candidate
    raise ValueError("Invalid chess board.")
def validate_fen_format(fen_string):
    """
    Check whether a string is a well-formed FEN piece-placement field.

    Only the board part of a FEN is accepted (for example
    "rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR"). A full FEN with the side to
    move, castling rights, etc. appended is rejected.

    Parameters:
    - fen_string: str
        The string to be validated.

    Returns:
    - bool:
        True if the string has 8 '/'-separated ranks and each rank meets all
        of these conditions:
        - it uses only piece letters (rnbqkp / RNBQKP) and digits 1-8;
        - it has no two digits in a row;
        - it describes exactly 8 squares.
        False otherwise.
    """
    ranks = fen_string.split("/")
    if len(ranks) != 8:
        return False
    return all(_fen_rank_is_valid(rank) for rank in ranks)


def _fen_rank_is_valid(rank):
    """Return True if one FEN rank is well-formed and covers exactly 8 squares."""
    squares = 0
    previous_was_digit = False
    for char in rank:
        if char in "12345678":
            # Strict FEN never writes two empty-square counts back to back ("44").
            if previous_was_digit:
                return False
            squares += int(char)
            previous_was_digit = True
        elif char in "rnbqkpRNBQKP":
            squares += 1
            previous_was_digit = False
        else:
            return False
    return squares == 8
def validate_section(section):
    """
    Check that one '/'-separated FEN section uses only legal characters.

    Parameters:
    - section: str
        The section to be validated.

    Returns:
    - bool:
        True when every character is a digit 1-8 or one of the piece letters
        rnbqkp / RNBQKP (an empty section counts as valid), False otherwise.
        The number of squares the section describes is not checked here.
    """
    allowed = frozenset("12345678rnbqkpRNBQKP")
    return set(section) <= allowed
import chess
import chess.engine
class ChessEngineTool(Tool):
    """Ask a local Stockfish binary for the best move in a FEN position."""

    name = "chess_engine"
    description = "Analyzes a chess position (FEN) with Stockfish and returns the best move."
    inputs = {
        "fen": {"type": "string", "description": "FEN string of the position."},
        "time_limit": {"type": "number", "description": "Time in seconds for engine analysis.", "nullable": True}
    }
    output_type = "string"

    def forward(self, fen: str, time_limit: float = 0.1) -> str:
        """Return Stockfish's chosen move for ``fen`` in SAN notation.

        Args:
            fen: Position to analyze.
            time_limit: Seconds of engine thinking; None or <= 0 falls back to 0.1.

        Returns:
            The move in SAN, or a message when the position has no legal moves.

        Raises:
            RuntimeError: If no stockfish binary can be found.
            ValueError: If ``fen`` is not a valid position (raised by chess.Board).
        """
        # Prefer stockfish on PATH; only fall back to the apt location if it
        # really exists (previously the `or` fallback made the check below dead).
        sf_bin = shutil.which("stockfish")
        if sf_bin is None and os.path.exists("/usr/games/stockfish"):
            sf_bin = "/usr/games/stockfish"
        if sf_bin is None:
            raise RuntimeError(
                "Cannot find stockfish on PATH or at /usr/games/stockfish. "
                "Did you install it in apt.txt or via apt-get?"
            )
        # The input is nullable, so the agent may pass None explicitly; a
        # Limit(time=None) would leave the search unbounded.
        if time_limit is None or time_limit <= 0:
            time_limit = 0.1
        board = chess.Board(fen)
        if board.legal_moves.count() == 0:
            return "No legal moves available in this position."
        engine = chess.engine.SimpleEngine.popen_uci(sf_bin)
        try:
            result = engine.play(board, chess.engine.Limit(time=time_limit))
        finally:
            # Shut the engine subprocess down even if analysis fails.
            engine.quit()
        if result.move is None:
            return "Engine did not return a move for this position."
        return board.san(result.move)
class PythonCodeReaderTool(Tool):
    """Expose the raw source text of a local ``.py`` file to the agent."""

    name = "read_python_code"
    description = "Reads a Python (.py) file and returns its content as a string."
    inputs = {
        "file_path": {"type": "string", "description": "The path to the Python file to read"}
    }
    output_type = "string"

    def forward(self, file_path: str) -> str:
        """Read ``file_path`` as UTF-8 text.

        Returns the file contents, or an ``"Error ..."`` string when the path
        does not exist or reading it raises.
        """
        try:
            if os.path.exists(file_path):
                with open(file_path, "r", encoding="utf-8") as handle:
                    return handle.read()
            return f"Error: Python file not found at {file_path}"
        except Exception as exc:
            return f"Error reading Python file: {str(exc)}"
class MagAgent:
    """Async front end for a smolagents CodeAgent configured with this module's tools."""

    def __init__(self, rate_limiter: Optional["Limiter"] = None):
        """Initialize the MagAgent with search tools.

        Args:
            rate_limiter: Optional token-bucket limiter used both before each
                question and by the attachment-download tool; None disables
                throttling.
        """
        self.rate_limiter = rate_limiter
        print("Initializing MagAgent with search tools...")
        model = LiteLLMModel(
            model_id="gemini/gemini-1.5-flash",  # Use standard multimodal model
            api_key=os.environ.get("GEMINI_KEY"),
            max_tokens=8192,
            api_base="https://generativelanguage.googleapis.com/v1beta"  # Correct endpoint
        )
        # Load prompt templates
        with open("prompts.yaml", "r", encoding="utf-8") as stream:
            prompt_templates = yaml.safe_load(stream)
        # NOTE(review): add_base_tools=True may also register smolagents' own
        # tools (e.g. its own visit_webpage); confirm which implementation wins.
        self.agent = CodeAgent(
            model=model,
            tools=[
                DownloadTaskAttachmentTool(rate_limiter=rate_limiter),
                # DuckDuckGoSearchTool(),
                # WikipediaSearchTool(),
                SpeechToTextTool(),
                ExcelReaderTool(),
                VisitWebpageTool(),
                PythonCodeReaderTool(),
                PNG2FENTool,
                ChessEngineTool(),
                # GoogleSearchTool,
                # ImageAnalysisTool,
            ],
            verbosity_level=2,
            prompt_templates=prompt_templates,
            add_base_tools=True,
            max_steps=15
        )
        print("MagAgent initialized.")

    async def __call__(self, question: str, task_id: str) -> str:
        """Process a question asynchronously using the MagAgent.

        Returns:
            The agent's answer with surrounding whitespace stripped,
            ``"No answer found."`` when the agent returned None or blank text,
            or an error message if anything raised.
        """
        print(f"MagAgent received question (first 50 chars): {question[:50]}... Task ID: {task_id}")
        try:
            if self.rate_limiter:
                # NOTE(review): 1 is passed as the first argument of consume();
                # in token_bucket that is the bucket key — confirm intent.
                while not self.rate_limiter.consume(1):
                    print(f"Rate limit reached for task {task_id}. Waiting...")
                    await asyncio.sleep(60 / 15)  # Assuming 15 RPM
            # Include task_id in the task prompt to guide the agent
            task = (
                f"{question} \n"
                f"If the question references an attachment, use tool to download it with task_id: {task_id}\n"
            )
            print(f"Calling agent.run for task {task_id}...")
            # agent.run is synchronous; run it in a worker thread so the event
            # loop is not blocked while the agent works.
            result = await asyncio.to_thread(
                self.agent.run,
                task=task
            )
            print(f"Agent.run completed for task {task_id}.")
            # Check None before str(): str(None) is the truthy string "None",
            # which used to slip past the empty-answer fallback.
            response = "" if result is None else str(result).strip()
            if not response:
                print(f"No answer found for task {task_id}.")
                response = "No answer found."
            print(f"MagAgent response: {response[:50]}...")
            return response
        except Exception as e:
            error_msg = f"Error processing question for task {task_id}: {str(e)}. Check API key or network connectivity."
            print(error_msg)
            return error_msg