Spaces:
Running
Running
from smolagents import CodeAgent, LiteLLMModel, tool, Tool, load_tool, DuckDuckGoSearchTool, WikipediaSearchTool #, HfApiModel, OpenAIServerModel | |
import asyncio | |
import os | |
import re | |
import pandas as pd | |
from typing import Optional | |
from token_bucket import Limiter, MemoryStorage | |
import yaml | |
from PIL import Image, ImageOps | |
import requests | |
from io import BytesIO | |
from markdownify import markdownify | |
import whisper | |
import time | |
import shutil | |
import traceback | |
#@tool | |
#def GoogleSearchTool(query: str) -> str: | |
# """Tool for performing Google searches using Custom Search JSON API | |
# Args: | |
# query (str): Search query string | |
# Returns: | |
# str: Formatted search results | |
# """ | |
# api_key = os.environ.get("GOOGLE_API_KEY") | |
# cse_id = os.environ.get("GOOGLE_CSE_ID") | |
# if not api_key or not cse_id: | |
# raise ValueError("GOOGLE_API_KEY and GOOGLE_CSE_ID must be set in environment variables.") | |
# url = "https://www.googleapis.com/customsearch/v1" | |
# params = { | |
# "key": api_key, | |
# "cx": cse_id, | |
# "q": query, | |
# "num": 5 # Number of results to return | |
# } | |
# try: | |
# response = requests.get(url, params=params) | |
# response.raise_for_status() | |
# results = response.json().get("items", []) | |
# return "\n".join([f"{item['title']}: {item['link']}" for item in results]) or "No results found." | |
# except Exception as e: | |
# return f"Error performing Google search: {str(e)}" | |
from langchain_community.document_loaders import ArxivLoader | |
def search_arxiv(query: str) -> str:
    """Search Arxiv for a query and return at most 3 results.
    Args:
        query: The search query.
    Returns:
        str: Formatted search results
    """
    # NOTE: the docstring above doubles as the tool description when this function
    # is wrapped with smolagents' ``tool``; keep it short and keep the Args section.
    try:
        search_docs = ArxivLoader(query=query, load_max_docs=3).load()
    except Exception as e:
        # Match the other tools: report failures as text instead of raising.
        return f"Error searching Arxiv: {str(e)}"
    if not search_docs:
        return "No results found."
    formatted_docs = []
    for doc in search_docs:
        # ArxivLoader metadata uses keys such as "Title" and "Published"; there is no
        # "source" key, so look fields up with .get() instead of indexing.
        metadata = doc.metadata
        formatted_docs.append(
            f'<Document title="{metadata.get("Title", "")}" published="{metadata.get("Published", "")}"/>\n'
            f"{doc.page_content[:1000]}\n</Document>"
        )
    # Return the plain string promised by the annotation (previously a dict).
    return "\n\n---\n\n".join(formatted_docs)
class VisitWebpageTool(Tool):
    """Fetch a web page and return its content converted to Markdown text."""

    name = "visit_webpage"
    description = "Visits a webpage at the given url and reads its content as a markdown string. Use this to browse webpages."
    inputs = {'url': {'type': 'string', 'description': 'The url of the webpage to visit.'}}
    output_type = "string"

    # No ``__init__`` override: the inherited ``Tool.__init__`` accepts and handles
    # ``*args``/``**kwargs`` and initializes ``is_initialized`` itself.

    def forward(self, url: str) -> str:
        """Download *url* and return its Markdown rendering.

        Args:
            url: Address of the page to fetch.

        Returns:
            The page converted with ``markdownify``, runs of 3+ newlines collapsed to
            one blank line, passed through smolagents' ``truncate_content`` with a
            10,000-character limit. On failure, a human-readable error message is
            returned instead of raising.
        """
        try:
            from smolagents.utils import truncate_content

            response = requests.get(url, timeout=50)
            response.raise_for_status()
            markdown_content = markdownify(response.text).strip()
            # Collapse excessive blank lines so the truncation budget goes to content.
            markdown_content = re.sub(r"\n{3,}", "\n\n", markdown_content)
            return truncate_content(markdown_content, 10000)
        except requests.exceptions.Timeout:
            return "The request timed out. Please try again later or check the URL."
        except requests.exceptions.RequestException as e:
            return f"Error fetching the webpage: {str(e)}"
        except Exception as e:
            return f"An unexpected error occurred: {str(e)}"
class DownloadTaskAttachmentTool(Tool):
    """Download the attachment of a task from the scoring API into ``downloads/``."""

    name = "download_file"
    description = "Downloads the file attached to the task ID and returns the local file path. Supports Excel (.xlsx), image (.png, .jpg), audio (.mp3), PDF (.pdf), and Python (.py) files."
    inputs = {'task_id': {'type': 'string', 'description': 'The task id to download attachment from.'}}
    output_type = "string"
    DEFAULT_API_URL = "https://agents-course-unit4-scoring.hf.space"
    # Ordered (Content-Type substring, local extension) pairs; the first match wins.
    CONTENT_TYPE_EXTENSIONS = (
        ("image/png", ".png"),
        ("image/jpeg", ".jpg"),
        ("application/vnd.openxmlformats-officedocument.spreadsheetml.sheet", ".xlsx"),
        ("audio/mpeg", ".mp3"),
        ("application/pdf", ".pdf"),
        ("text/x-python", ".py"),
    )

    def __init__(self, rate_limiter: Optional[Limiter] = None, default_api_url: str = DEFAULT_API_URL, *args, **kwargs):
        """Create the tool.

        Args:
            rate_limiter: Optional limiter consulted before every download; ``None``
                disables throttling.
            default_api_url: Base URL of the file API; files are fetched from
                ``{default_api_url}/files/{task_id}``.
        """
        super().__init__(*args, **kwargs)
        self.rate_limiter = rate_limiter
        self.default_api_url = default_api_url

    def _extension_for(self, response) -> Optional[str]:
        """Return the local extension for *response*, or ``None`` if the type is unsupported."""
        content_type = response.headers.get('Content-Type', '').lower()
        for marker, extension in self.CONTENT_TYPE_EXTENSIONS:
            if marker in content_type:
                return extension
        # Fallback for generic types (e.g. application/octet-stream): accept the
        # suffix of a Content-Disposition filename if it is one we support.
        disposition = response.headers.get('Content-Disposition', '')
        match = re.search(r"filename\*?=(?:UTF-8'')?\"?([^\";]+)\"?", disposition, re.IGNORECASE)
        if match:
            suffix = os.path.splitext(match.group(1).strip())[1].lower()
            if suffix == ".jpeg":
                suffix = ".jpg"
            if suffix in {ext for _, ext in self.CONTENT_TYPE_EXTENSIONS}:
                return suffix
        return None

    def forward(self, task_id: str) -> str:
        """Download the attachment for *task_id*.

        Args:
            task_id: Identifier of the task whose attachment should be fetched.

        Returns:
            The local path ``downloads/<task_id><ext>`` on success, otherwise an
            ``"Error ..."`` message (the tool reports failures instead of raising).
        """
        task_id = str(task_id)
        # The id becomes part of a local path; reject anything that could escape
        # the downloads directory.
        if not task_id or task_id in (".", "..") or os.path.basename(task_id) != task_id:
            return f"Error: Invalid task id {task_id!r}."
        file_url = f"{self.default_api_url}/files/{task_id}"
        print(f"Downloading file for task ID {task_id} from {file_url}...")
        try:
            if self.rate_limiter:
                # NOTE(review): token_bucket's Limiter.consume takes a bucket key first;
                # `1` presumably acts as a key shared with MagAgent's calls — confirm.
                while not self.rate_limiter.consume(1):
                    print(f"Rate limit reached for downloading file for task {task_id}. Waiting...")
                    time.sleep(4)  # Assuming 15 RPM
            # Context manager releases the streamed connection on every path,
            # including the unsupported-type early return below.
            with requests.get(file_url, stream=True, timeout=50) as response:
                response.raise_for_status()
                content_type = response.headers.get('Content-Type', '').lower()
                extension = self._extension_for(response)
                if extension is None:
                    return f"Error: Unsupported file type {content_type} for task {task_id}. Try using visit_webpage or web_search if the content is online."
                local_file_path = f"downloads/{task_id}{extension}"
                os.makedirs("downloads", exist_ok=True)
                with open(local_file_path, "wb") as file:
                    for chunk in response.iter_content(chunk_size=8192):
                        file.write(chunk)
            print(f"File downloaded successfully: {local_file_path}")
            return local_file_path
        except requests.exceptions.HTTPError as e:
            if e.response is not None and e.response.status_code == 429:
                return f"Error: Rate limit exceeded for task {task_id}. Try again later."
            return f"Error downloading file for task {task_id}: {str(e)}"
        except requests.exceptions.RequestException as e:
            # Must precede OSError: requests' RequestException derives from IOError.
            return f"Error downloading file for task {task_id}: {str(e)}"
        except OSError as e:
            return f"Error saving file for task {task_id}: {str(e)}"
class SpeechToTextTool(Tool):
    """Transcribe a local audio file with an OpenAI Whisper model."""

    name = "speech_to_text"
    description = (
        "Converts an audio file to text using OpenAI Whisper."
    )
    inputs = {
        "audio_path": {"type": "string", "description": "Path to audio file (.mp3, .wav)"},
    }
    output_type = "string"

    def __init__(self, model_name: str = "base"):
        """Load the Whisper model eagerly.

        Args:
            model_name: Name passed to ``whisper.load_model``; defaults to ``"base"``.
        """
        super().__init__()
        self.model = whisper.load_model(model_name)

    def forward(self, audio_path: str) -> str:
        """Return the transcript of *audio_path*, or an ``"Error ..."`` message.

        Args:
            audio_path: Path to the audio file on local disk.
        """
        if not os.path.exists(audio_path):
            return f"Error: File not found at {audio_path}"
        try:
            result = self.model.transcribe(audio_path)
        except Exception as e:
            # Report failures (unreadable/corrupt audio, decoder problems) as text,
            # like the other tools, instead of aborting the agent step.
            return f"Error transcribing audio file {audio_path}: {str(e)}"
        # Strip surrounding whitespace, e.g. a leading space, from the transcript text.
        return result.get("text", "").strip()
class ExcelReaderTool(Tool):
    """Summarize one sheet of an Excel workbook (shape, column dtypes, cell data) as text."""

    name = "excel_reader"
    description = """
    This tool reads and processes Excel files (.xlsx, .xls).
    It can extract data, calculate statistics, and perform data analysis on spreadsheets.
    """
    inputs = {
        "excel_path": {
            "type": "string",
            "description": "The path to the Excel file to read",
        },
        "sheet_name": {
            "type": "string",
            "description": "The name of the sheet to read (optional, defaults to first sheet)",
            "nullable": True
        }
    }
    output_type = "string"
    # Sheets with at most this many rows are returned in full.
    max_full_rows = 100
    # Rows shown for sheets larger than ``max_full_rows``.
    preview_rows = 5

    def forward(self, excel_path: str, sheet_name: str = None) -> str:
        """Read *excel_path* and return a plain-text summary of one sheet.

        Args:
            excel_path: Path to the ``.xlsx``/``.xls`` file.
            sheet_name: Sheet to read; when empty or ``None`` the first sheet is used.

        Returns:
            File name, shape, column dtypes and the sheet's rows (all of them for
            small sheets, otherwise a preview). Errors are returned as an
            ``"Error ..."`` message instead of being raised.
        """
        try:
            if not os.path.exists(excel_path):
                return f"Error: Excel file not found at {excel_path}"
            if sheet_name:
                df = pd.read_excel(excel_path, sheet_name=sheet_name)
            else:
                df = pd.read_excel(excel_path)
            n_rows, n_cols = df.shape
            lines = [
                f"Excel file: {excel_path}",
                f"Shape: {n_rows} rows × {n_cols} columns",
                "",
                "Columns:",
            ]
            lines.extend(f"- {col} ({dtype})" for col, dtype in df.dtypes.items())
            lines.append("")
            if n_rows <= self.max_full_rows:
                # The description promises data extraction; a 5-row preview cannot
                # support questions that aggregate over the whole sheet.
                lines.append(f"Data (all {n_rows} rows):")
                lines.append(df.to_string())
            else:
                lines.append(f"Preview (first {self.preview_rows} rows):")
                lines.append(df.head(self.preview_rows).to_string())
            return "\n".join(lines)
        except Exception as e:
            return f"Error reading Excel file: {str(e)}"
class PythonCodeReaderTool(Tool):
    """Expose the raw source text of a local ``.py`` file to the agent."""

    name = "read_python_code"
    description = "Reads a Python (.py) file and returns its content as a string."
    inputs = {
        "file_path": {"type": "string", "description": "The path to the Python file to read"}
    }
    output_type = "string"

    def forward(self, file_path: str) -> str:
        """Return the UTF-8 text of *file_path*, or an ``"Error ..."`` message.

        Args:
            file_path: Location of the Python file on local disk.
        """
        try:
            if os.path.exists(file_path):
                with open(file_path, "r", encoding="utf-8") as source_file:
                    return source_file.read()
            return f"Error: Python file not found at {file_path}"
        except Exception as exc:
            return f"Error reading Python file: {str(exc)}"
from tenacity import retry, stop_after_attempt, wait_exponential, retry_if_exception_type | |
#from smolagents.tools import DuckDuckGoSearchException # Replace with the actual exception if different | |
class RetryDuckDuckGoSearchTool(DuckDuckGoSearchTool):
    """DuckDuckGo web search that retries failed queries with exponential backoff.

    Errors raised by the parent search (presumably transient ones such as rate
    limiting) are retried up to ``max_attempts`` times in total before the last
    error is re-raised.
    """

    # Total number of attempts, including the first one.
    max_attempts = 3
    # Seconds to wait before the first retry; doubled after every failure.
    initial_backoff = 2.0

    def forward(self, query: str) -> str:
        """Run the parent search for *query*, retrying on failure."""
        attempts = max(1, self.max_attempts)
        delay = self.initial_backoff
        for attempt in range(1, attempts + 1):
            try:
                return super().forward(query)
            except Exception as e:
                # NOTE(review): smolagents' DuckDuckGoSearchTool appears to signal an
                # empty result with a "No results found" exception; that is not
                # transient, so re-raise it immediately rather than wait and retry.
                if attempt >= attempts or "No results found" in str(e):
                    raise
                print(f"DuckDuckGo search failed (attempt {attempt}/{attempts}): {e}. Retrying in {delay:.0f}s...")
                time.sleep(delay)
                delay *= 2
class MagAgent:
    """GAIA question-answering agent: a smolagents ``CodeAgent`` on Gemini with task tools."""

    def __init__(self, rate_limiter: Optional[Limiter] = None):
        """Initialize the MagAgent with search tools.

        Args:
            rate_limiter: Optional limiter consulted before each question (in
                ``__call__``) and before each attachment download; ``None``
                disables throttling.
        """
        self.rate_limiter = rate_limiter
        print("Initializing MagAgent with search tools...")
        api_key = os.environ.get("GEMINI_KEY")
        if not api_key:
            print("Warning: GEMINI_KEY is not set; model calls will likely fail.")
        model = LiteLLMModel(
            model_id="gemini/gemini-2.0-flash",
            api_key=api_key,
            max_tokens=8192,
        )
        # smolagents agents take Tool objects; a bare function has no name/inputs
        # schema, so wrap search_arxiv with the @tool decorator (unless it already is one).
        arxiv_tool = search_arxiv if isinstance(search_arxiv, Tool) else tool(search_arxiv)
        self.agent = CodeAgent(
            model=model,
            tools=[
                DownloadTaskAttachmentTool(rate_limiter=rate_limiter),
                RetryDuckDuckGoSearchTool(),
                WikipediaSearchTool(),
                SpeechToTextTool(),
                ExcelReaderTool(),
                VisitWebpageTool(),
                PythonCodeReaderTool(),
                arxiv_tool,
            ],
            verbosity_level=2,
            add_base_tools=False,
            max_steps=20,
        )
        print("MagAgent initialized.")

    @staticmethod
    def _build_task(question: str, task_id: str) -> str:
        """Return the instruction prompt for *question*, naming *task_id* for attachments."""
        return (
            "You are an advanced AI assistant tasked with answering questions from the GAIA benchmark accurately and concisely. Follow these guidelines:\n\n"
            "1. **Question Parsing**:\n"
            "   - If the question includes direct speech or quoted text (e.g., \"Isn't that hot?\"), treat it as a precise query and preserve the quoted structure in your response.\n\n"
            "2. **Handling Input Data**:\n"
            f"   - If the question references an attachment, use the download_file tool to download it with task_id: {task_id}\n"
            "   - When processing external data (e.g., YouTube transcripts, web searches), expect potential issues like missing punctuation, inconsistent formatting, or conversational text.\n"
            "   - If the input is ambiguous, prioritize extracting key information relevant to the question.\n\n"
            "3. **Response Formatting**:\n"
            "   - Provide answers that are concise, accurate, and properly punctuated according to standard English grammar.\n"
            "   - Use quotation marks for direct quotes (e.g., \"Extremely.\") and appropriate punctuation for lists, sentences, or clarifications.\n"
            "   - If asked about name of place or city, use full complete name without abbreviations (e.g. use Saint Petersburg instead of St.Petersburg).\n\n"
            "4. **Error Handling**:\n"
            "   - If you cannot retrieve or process data (e.g., due to blocked requests), return a clear error message: \"Unable to retrieve data. Please refine the question or check external sources.\"\n\n"
            f"Answer the following question: \n {question} \n"
        )

    async def __call__(self, question: str, task_id: str) -> str:
        """Process a question asynchronously using the MagAgent.

        Args:
            question: The question text.
            task_id: Task identifier, passed to the agent so it can fetch attachments.

        Returns:
            The agent's answer as a string, ``"No answer found."`` when the agent
            returns nothing, or an error message if any exception occurs.
        """
        print(f"MagAgent received question (first 50 chars): {question[:50]}... Task ID: {task_id}")
        try:
            if self.rate_limiter:
                while not self.rate_limiter.consume(1):
                    print(f"Rate limit reached for task {task_id}. Waiting...")
                    await asyncio.sleep(4)  # Assuming 15 RPM
            task = self._build_task(question, task_id)
            print(f"Calling agent.run for task {task_id}...")
            # agent.run is blocking; run it in a worker thread to keep the loop free.
            raw_response = await asyncio.to_thread(self.agent.run, task=task)
            print(f"Agent.run completed for task {task_id}.")
            # Check for None/blank before converting: str(None) == "None" is truthy,
            # which previously made the fallback below unreachable.
            response = "" if raw_response is None else str(raw_response)
            if not response.strip():
                print(f"No answer found for task {task_id}.")
                response = "No answer found."
            print(f"MagAgent response: {response[:50]}...")
            return response
        except Exception as e:
            error_msg = f"Error processing question for task {task_id}: {str(e)}. Check API key or network connectivity."
            print(error_msg)
            return error_msg