# Source: Test_Magus / agent.py
# Author: SergeyO7 — commit b86353b ("Update agent.py", verified)
from smolagents import CodeAgent, LiteLLMModel, Tool, DuckDuckGoSearchTool, WikipediaSearchTool
from token_bucket import Limiter, MemoryStorage
from tenacity import retry, stop_after_attempt, wait_exponential
from langchain_community.document_loaders import ArxivLoader
from sentence_transformers import SentenceTransformer
from bs4 import BeautifulSoup
from datetime import datetime
import pandas as pd
import numpy as np
import requests
import asyncio
import whisper
import yaml
import os
import re
import json
from typing import Optional
# --------------------------
# Core Tools from Previous Implementation
# --------------------------
class VisitWebpageTool(Tool):
    """Tool that fetches a web page over HTTP and returns its content as markdown."""

    name = "visit_webpage"
    description = "Visits a webpage and returns its content as markdown"
    inputs = {'url': {'type': 'string', 'description': 'The URL to visit'}}
    output_type = "string"

    def forward(self, url: str) -> str:
        """Fetch ``url`` and return its body converted to markdown.

        Network/HTTP failures are retried by ``_fetch``; whatever still fails
        is reported as an ``"Error fetching webpage: ..."`` string rather than
        raised, so the agent can react to it.
        """
        try:
            html = self._fetch(url)
            return self._to_markdown(html).strip()
        except Exception as e:
            return f"Error fetching webpage: {str(e)}"

    # The retry lives on the raising helper: when it decorated ``forward``,
    # the blanket ``except`` inside swallowed every error and tenacity never
    # saw a failure to retry. ``reraise=True`` surfaces the last real error.
    @retry(stop=stop_after_attempt(3), wait=wait_exponential(multiplier=1, min=2, max=10), reraise=True)
    def _fetch(self, url: str) -> str:
        """GET ``url`` (30 s timeout) and return the response text; raises on HTTP errors."""
        response = requests.get(url, timeout=30)
        response.raise_for_status()
        return response.text

    @staticmethod
    def _to_markdown(html: str) -> str:
        """Convert HTML to markdown, degrading to plain text if markdownify is absent."""
        try:
            # Previously called without ever being imported (NameError on
            # every request). Imported lazily so it stays an optional extra.
            from markdownify import markdownify
        except ImportError:
            return BeautifulSoup(html, "html.parser").get_text(separator="\n")
        return markdownify(html)
class DownloadTaskAttachmentTool(Tool):
    """Tool that downloads a task's attachment from the scoring API to ``downloads/``."""

    name = "download_file"
    description = "Downloads files from the task API"
    inputs = {'task_id': {'type': 'string', 'description': 'The task ID to download'}}
    output_type = "string"

    def forward(self, task_id: str) -> str:
        """Download ``{TASK_API_URL}/files/{task_id}`` and return the local file path.

        The base URL comes from the ``TASK_API_URL`` environment variable
        (default: the HF agents-course scoring space). The file is saved as
        ``downloads/<task_id><ext>``, where ``ext`` is derived from the
        response ``Content-Type``.

        Raises:
            RuntimeError: if the request, the HTTP status, or the write fails.
        """
        api_url = os.getenv("TASK_API_URL", "https://agents-course-unit4-scoring.hf.space")
        file_url = f"{api_url}/files/{task_id}"
        try:
            # ``with`` releases the streamed connection even when
            # raise_for_status() or the file write fails part-way.
            with requests.get(file_url, stream=True, timeout=30) as response:
                response.raise_for_status()
                extension = self._get_extension(response.headers.get('Content-Type', ''))
                os.makedirs("downloads", exist_ok=True)
                file_path = f"downloads/{self._safe_name(task_id)}{extension}"
                with open(file_path, "wb") as f:
                    for chunk in response.iter_content(chunk_size=8192):
                        f.write(chunk)
            return file_path
        except Exception as e:
            raise RuntimeError(f"Download failed: {str(e)}") from e

    @staticmethod
    def _safe_name(task_id: str) -> str:
        """Make ``task_id`` safe as a file name (no path separators).

        ``task_id`` is supplied by the LLM, so separators or ``..`` segments
        could otherwise write outside ``downloads/``. UUID-style IDs are
        returned unchanged.
        """
        return re.sub(r"[^\w.-]", "_", str(task_id))

    def _get_extension(self, content_type: str) -> str:
        """Map a ``Content-Type`` header value to a file extension (``.bin`` if unknown)."""
        type_map = {
            'image/png': '.png',
            'image/jpeg': '.jpg',
            'application/vnd.openxmlformats-officedocument.spreadsheetml.sheet': '.xlsx',
            'audio/mpeg': '.mp3',
            'audio/wav': '.wav',
            'audio/x-wav': '.wav',
            'application/pdf': '.pdf',
            'application/json': '.json',
            'text/x-python': '.py',
            'text/plain': '.txt',
            'text/csv': '.csv',
        }
        # Drop parameters such as "; charset=utf-8". Media types are
        # case-insensitive, so normalise before the lookup.
        media_type = content_type.split(';')[0].strip().lower()
        return type_map.get(media_type, '.bin')
class ArxivSearchTool(Tool):
    """Tool that searches Arxiv through LangChain's ``ArxivLoader``."""

    name = "arxiv_search"
    description = "Searches academic papers on Arxiv"
    inputs = {'query': {'type': 'string', 'description': 'Search query'}}
    output_type = "string"

    def forward(self, query: str) -> str:
        """Return title, authors and a short summary for up to 3 Arxiv hits.

        Loader failures are reported as an ``"Arxiv search failed: ..."``
        string rather than raised. An explicit message is returned when the
        search finds nothing, instead of an empty string.
        """
        try:
            docs = ArxivLoader(query=query, load_max_docs=3).load()
            if not docs:
                return f"No Arxiv results found for: {query}"
            return "\n\n".join(self._format_doc(doc) for doc in docs)
        except Exception as e:
            return f"Arxiv search failed: {str(e)}"

    @staticmethod
    def _format_doc(doc, max_chars: int = 500) -> str:
        """Render one loaded document as a Title/Authors/Summary block."""
        meta = doc.metadata
        # Prefer the abstract stored in metadata. ``page_content`` is the
        # loaded document text, which was previously mislabelled "Summary".
        summary = meta.get('Summary') or doc.page_content
        if len(summary) > max_chars:
            summary = summary[:max_chars] + "..."
        return (
            f"Title: {meta.get('Title', 'Unknown')}\n"
            f"Authors: {meta.get('Authors', 'Unknown')}\n"
            f"Summary: {summary}"
        )
class SpeechToTextTool(Tool):
    """Tool that transcribes audio files with a Whisper ``base`` model."""

    name = "speech_to_text"
    description = "Converts audio files to text"
    inputs = {'audio_path': {'type': 'string', 'description': 'Path to audio file'}}
    output_type = "string"

    def __init__(self):
        # smolagents' Tool.__init__ sets base-class state that Tool.__call__
        # relies on. Skipping it broke invoking this tool as a callable
        # (which is how the agent calls tools).
        super().__init__()
        self.model = whisper.load_model("base")

    def forward(self, audio_path: str) -> str:
        """Return the transcript of ``audio_path``, or a message if the file is missing."""
        if not os.path.exists(audio_path):
            return f"File not found: {audio_path}"
        # Strip surrounding whitespace from the transcript text.
        return self.model.transcribe(audio_path).get("text", "").strip()
# --------------------------
# Enhanced Tools with Validation
# --------------------------
class ValidatedExcelReader(Tool):
    """Excel loader tool that can optionally enforce a caller-supplied column schema."""

    name = "excel_reader"
    description = "Reads and validates Excel files"
    inputs = {
        'file_path': {'type': 'string', 'description': 'Path to Excel file'},
        'schema': {'type': 'object', 'description': 'Validation schema', 'nullable': True},
    }
    output_type = "string"

    def forward(self, file_path: str, schema: dict = None) -> str:
        """Return the sheet at ``file_path`` rendered as a markdown table.

        A non-empty ``schema`` is checked with ``ValidationPipeline``. If that
        check reports failure, ``ValueError`` is raised with the error list
        and no table is returned.
        """
        frame = pd.read_excel(file_path)
        report = ValidationPipeline().validate(frame, schema) if schema else None
        if report is not None and not report['valid']:
            raise ValueError(f"Data validation failed: {report['errors']}")
        return frame.to_markdown()
# --------------------------
# Integrated Universal Loader
# --------------------------
class UniversalLoader(Tool):
    """Dispatch tool that forwards a load request to the matching sub-tool."""

    name = "universal_loader"
    description = "Loads various file types and web content using appropriate sub-tools."
    inputs = {
        'source': {
            'type': 'string',
            'description': 'Type of source to load (web/excel/audio/arxiv/attachment)'
        },
        'task_id': {
            'type': 'string',
            'description': 'Task ID for attachments; for other sources, the URL, file path or query to load',
            'nullable': True
        }
    }
    output_type = "string"

    def __init__(self):
        # Base-class init is required so the tool works through Tool.__call__.
        super().__init__()
        self.loaders = {
            'excel': ValidatedExcelReader(),
            'audio': SpeechToTextTool(),
            'arxiv': ArxivSearchTool(),
            'web': VisitWebpageTool()
        }

    def forward(self, source: str, task_id: str = None) -> str:
        """Load ``task_id`` with the sub-tool selected by ``source``.

        ``"attachment"`` downloads the task file and picks a loader from its
        extension. Any other ``source`` must be a key of ``self.loaders``, and
        ``task_id`` is passed directly to that loader's ``forward``. Any
        failure, including an unknown ``source``, falls back to a
        cross-verified web search.
        """
        try:
            if source == "attachment":
                file_path = DownloadTaskAttachmentTool()(task_id)
                return self._load_by_type(file_path)
            return self.loaders[source].forward(task_id)
        except Exception:
            # Deliberate best-effort: degrade to search instead of failing.
            return self._fallback(source, task_id)

    def _load_by_type(self, file_path: str) -> str:
        """Load a downloaded file with a loader chosen from its extension."""
        ext = os.path.splitext(file_path)[1].lstrip('.').lower()
        # Plain-text attachments (e.g. Python sources) are read directly. The
        # web loader uses requests.get and cannot open local paths.
        if ext in {'py', 'txt', 'csv', 'json', 'md'}:
            with open(file_path, encoding="utf-8", errors="replace") as f:
                return f.read()
        loader_map = {
            'xlsx': 'excel',
            'mp3': 'audio',
            'wav': 'audio',
            # NOTE(review): the arxiv loader treats its argument as a search
            # query, so a local PDF path is searched for rather than read.
            'pdf': 'arxiv'
        }
        return self.loaders[loader_map.get(ext, 'web')].forward(file_path)

    def _fallback(self, source: str, context: str) -> str:
        """Answer with CrossVerifiedSearch when direct loading fails."""
        # Avoid searching for the literal text "None" when no task_id was given.
        query = f"{source} {context}" if context else source
        return CrossVerifiedSearch()(query)
# --------------------------
# Validation Pipeline
# --------------------------
class ValidationPipeline:
    """Checks DataFrame columns against a simple per-column type schema.

    A schema maps column name -> ``{'type': <validator name>}``. The bare
    validator name is also accepted as shorthand. Categorical fields may add
    ``'categories': [...]`` to restrict the allowed non-null values.
    """

    VALIDATORS = {
        'numeric': {
            'check': lambda x: pd.api.types.is_numeric_dtype(x),
            'error': "Non-numeric value found in numeric field"
        },
        'temporal': {
            'check': lambda x: pd.api.types.is_datetime64_any_dtype(x),
            'error': "Invalid date format detected"
        },
        'categorical': {
            # With no whitelist, every observed (or missing) value counts as a
            # valid category. The check is reduced to a single bool because
            # the old version returned a boolean Series, and ``not series``
            # raises ValueError in pandas.
            'check': lambda x: bool((x.isin(x.dropna().unique()) | x.isna()).all()),
            'error': "Invalid category value detected"
        }
    }

    def validate(self, data, schema: dict):
        """Validate ``data`` (a DataFrame) against ``schema``.

        Returns:
            dict with ``valid`` (bool), ``errors`` (list of "field: message"
            strings) and ``confidence`` (fraction of schema fields that
            passed; 1.0 for an empty schema).
        """
        errors = []
        for field, config in schema.items():
            is_mapping = isinstance(config, dict)
            field_type = config.get('type') if is_mapping else config
            validator = self.VALIDATORS.get(field_type)
            if validator is None:
                # Previously crashed with TypeError on None['check'].
                errors.append(f"{field}: Unknown validator type {field_type!r}")
                continue
            if field not in data:
                # Previously surfaced as a bare KeyError.
                errors.append(f"{field}: Column not found")
                continue
            column = data[field]
            allowed = config.get('categories') if is_mapping else None
            if field_type == 'categorical' and allowed is not None:
                passed = bool(column.dropna().isin(allowed).all())
            else:
                passed = bool(validator['check'](column))
            if not passed:
                errors.append(f"{field}: {validator['error']}")
        # Guard the division: an empty schema previously raised ZeroDivisionError.
        confidence = 1.0 - (len(errors) / len(schema)) if schema else 1.0
        return {
            'valid': not errors,
            'errors': errors,
            'confidence': confidence
        }
# --------------------------
# Tool Router
# --------------------------
class ToolRouter:
    """Routes queries to search back-ends, by explicit domain or by embedding similarity."""

    def __init__(self):
        self.encoder = SentenceTransformer('all-MiniLM-L6-v2')
        # One reference embedding per topical domain, compared against in route().
        self.domain_embeddings = {
            'music': self.encoder.encode("music album release artist track"),
            'sports': self.encoder.encode("athlete team score tournament"),
            'science': self.encoder.encode("chemistry biology physics research")
        }
        self.ddg = DuckDuckGoSearchTool()
        self.wiki = WikipediaSearchTool()
        self.arxiv = ArxivSearchTool()

    def forward(self, query: str, domain: str = None) -> str:
        """Smart search with domain prioritization.

        ``domain`` of ``"academic"``, ``"general"`` or ``"encyclopedic"``
        queries only Arxiv, DuckDuckGo or Wikipedia respectively, and returns
        that tool's result. Any other value queries all three and returns a
        JSON object keyed ``web``/``wikipedia``/``arxiv``. In that mode a
        failing back-end contributes a ``"Search failed: ..."`` entry instead
        of aborting the whole search.
        """
        targeted = {"academic": self.arxiv, "general": self.ddg, "encyclopedic": self.wiki}
        if domain in targeted:
            return targeted[domain](query)
        results = {}
        for label, tool in (("web", self.ddg), ("wikipedia", self.wiki), ("arxiv", self.arxiv)):
            try:
                results[label] = tool(query)
            except Exception as e:
                results[label] = f"Search failed: {e}"
        return json.dumps(results)

    def route(self, question: str):
        """Return the name of the domain whose reference embedding best matches ``question``."""
        query_embed = self.encoder.encode(question)
        # Dot product; this equals cosine similarity only if the encoder
        # returns unit-normalised vectors (presumably true for this model,
        # confirm).
        scores = {
            domain: np.dot(query_embed, domain_embed)
            for domain, domain_embed in self.domain_embeddings.items()
        }
        return max(scores, key=scores.get)
# --------------------------
# Temporal Search
# --------------------------
class HistoricalSearch:
    """Looks up archived page snapshots through the Wayback Machine availability API."""

    @retry(stop=stop_after_attempt(3), wait=wait_exponential(multiplier=1, min=2, max=10))
    def get_historical_content(self, url: str, target_date: str, timeout: float = 30):
        """Return the availability-API JSON for the snapshot of ``url`` nearest ``target_date``.

        Args:
            url: page to look up. It is sent as a query parameter, so URLs
                containing ``&`` or ``?`` are encoded correctly (they
                previously corrupted the request).
            target_date: Wayback timestamp, presumably ``YYYYMMDD[hhmmss]``
                per the archive.org API docs; confirm with callers.
            timeout: per-attempt request timeout in seconds. There was none
                before, so a stalled request could hang forever.

        Raises:
            tenacity.RetryError: after 3 failed attempts. HTTP error statuses
                now count as failures and are retried.
        """
        response = requests.get(
            "https://archive.org/wayback/available",
            params={"url": url, "timestamp": target_date},
            timeout=timeout,
        )
        response.raise_for_status()
        return response.json()
# --------------------------
# Enhanced Excel Reader
# --------------------------
class EnhancedExcelReader(Tool):
    """Excel loader tool that infers a per-column schema and validates the sheet against it."""

    # smolagents' Tool requires these attributes and checks them when the tool
    # is instantiated. None were defined before, so the class could not be
    # constructed.
    name = "enhanced_excel_reader"
    description = "Reads an Excel file, validates each column against its inferred type, and returns it as markdown"
    inputs = {'path': {'type': 'string', 'description': 'Path to the Excel file'}}
    output_type = "string"

    def forward(self, path: str) -> str:
        """Return the sheet at ``path`` as a markdown table.

        Raises:
            ValueError: if ``ValidationPipeline`` reports the inferred schema
                as failing.
        """
        df = pd.read_excel(path)
        validation = ValidationPipeline().validate(df, self._detect_schema(df))
        if not validation['valid']:
            raise ValueError(f"Data validation failed: {validation['errors']}")
        return df.to_markdown()

    def _detect_schema(self, df: pd.DataFrame):
        """Infer ``{column: {'type': 'numeric'|'temporal'|'categorical'}}`` from column dtypes."""
        schema = {}
        for col in df.columns:
            dtype = 'categorical'
            if pd.api.types.is_numeric_dtype(df[col]):
                dtype = 'numeric'
            elif pd.api.types.is_datetime64_any_dtype(df[col]):
                dtype = 'temporal'
            schema[col] = {'type': dtype}
        return schema
# --------------------------
# Cross-Verified Search
# --------------------------
class CrossVerifiedSearch(Tool):
    """Tool that queries several search back-ends and returns the most common answer."""

    name = "cross_verified_search"
    description = "Searches multiple sources and returns consensus results."
    inputs = {'query': {'type': 'string', 'description': 'Search query'}}
    output_type = "string"
    # Instantiated once at class-definition time and shared by all instances.
    SOURCES = [
        DuckDuckGoSearchTool(),
        WikipediaSearchTool(),
        ArxivSearchTool()
    ]

    # No ``__call__`` override: the old one duplicated ``forward`` and bypassed
    # Tool.__call__ (input handling, setup, extra keyword arguments).
    # ``CrossVerifiedSearch()(query)`` still works through the base class.

    def forward(self, query: str) -> str:
        """Run ``query`` against every source and return the consensus result."""
        results = []
        for source in self.SOURCES:
            try:
                results.append(source(query))
            except Exception:
                # Best-effort: one failing back-end must not sink the search.
                continue
        return self._consensus(results)

    def _consensus(self, results):
        """Pick the result whose first 100 characters occur most often.

        Ties go to the earliest result. The full result text is returned
        (previously only the 100-character key was). A message is returned
        when ``results`` is empty (previously ``max()`` raised ValueError).
        """
        if not results:
            return "No results found from any search source."
        counts = {}
        first_seen = {}
        for result in results:
            key = str(result)[:100]  # Simple hash for demo
            counts[key] = counts.get(key, 0) + 1
            first_seen.setdefault(key, result)
        winner = max(counts, key=counts.get)
        return str(first_seen[winner])
# --------------------------
# Main Agent Class (Integrated)
# --------------------------
class MagAgent:
    """Async front-end that answers a question with a smolagents ``CodeAgent``.

    The agent runs on ``gemini/gemini-1.5-flash`` via LiteLLM, with the API
    key taken from the ``GEMINI_KEY`` environment variable, and uses the file
    and search tools defined in this module.
    """

    def __init__(self, rate_limiter: Optional[Limiter] = None):
        """Build the model, tool list, prompt templates and ``CodeAgent``.

        Args:
            rate_limiter: optional token-bucket limiter. NOTE(review): it is
                stored on the instance but not consulted by any method here.
        """
        self.rate_limiter = rate_limiter
        self.model = LiteLLMModel(
            model_id="gemini/gemini-1.5-flash",
            api_key=os.environ.get("GEMINI_KEY"),
            max_tokens=8192
        )
        self.tools = [
            UniversalLoader(),
            CrossVerifiedSearch(),  # Replaces individual search tools
            ValidatedExcelReader(),
            VisitWebpageTool(),
            DownloadTaskAttachmentTool(),
            SpeechToTextTool(),
        ]
        self.prompt_templates = self._load_prompt_templates()
        # CodeAgent needs its own full template set, starting with
        # "system_prompt". Passing a dict that only holds our
        # base_prompt/task_prompt (e.g. the built-in fallback) made agent
        # construction fail, so use CodeAgent's defaults in that case.
        agent_templates = (
            self.prompt_templates if "system_prompt" in self.prompt_templates else None
        )
        self.agent = CodeAgent(
            model=self.model,
            tools=self.tools,
            verbosity_level=2,
            prompt_templates=agent_templates,
            max_steps=20,
            add_base_tools=False
        )

    @staticmethod
    def _load_prompt_templates(path: str = "prompts.yaml") -> dict:
        """Read prompt templates from ``path``, falling back to built-in defaults."""
        try:
            with open(path, encoding="utf-8") as f:
                templates = yaml.safe_load(f)
        except (OSError, UnicodeDecodeError, yaml.YAMLError):
            templates = None
        # safe_load yields None for an empty file, or a non-mapping for odd
        # YAML; either would break the ``.get`` calls in _build_task_prompt.
        if not isinstance(templates, dict):
            templates = {
                "base_prompt": "Default base prompt...",
                "task_prompt": "Default task template: {question}"
            }
        return templates

    async def __call__(self, question: str, task_id: str) -> str:
        """Answer ``question`` for ``task_id``; failures come back as an error string."""
        # Built outside the ``try`` so the error path always has a context
        # (previously it could be unbound in the ``except`` branch).
        context = self._create_context(question, task_id)
        try:
            result = await self._execute_agent(question, task_id)
            return self._validate_and_format(result, context)
        except Exception as e:
            return self._handle_error(e, context)

    def _create_context(self, question: str, task_id: str) -> dict:
        """Create the per-call record that validation/error steps append to."""
        return {
            "question": question,
            "task_id": task_id,
            "timestamp": datetime.now().isoformat(),
            "validation_checks": []
        }

    def _build_task_prompt(self, question: str, task_id: str) -> str:
        """Constructs task-specific prompts using templates.

        ``task_prompt`` is filled via ``str.format`` with ``question``,
        ``task_id`` and ``current_date`` (YYYY-MM-DD), then appended to
        ``base_prompt``.
        """
        base_template = self.prompt_templates.get("base_prompt", "")
        task_template = self.prompt_templates.get("task_prompt", "").format(
            question=question,
            task_id=task_id,
            current_date=datetime.now().strftime("%Y-%m-%d")
        )
        return f"{base_template}\n\n{task_template}"

    async def _execute_agent(self, question: str, task_id: str) -> str:
        """Run the blocking ``CodeAgent.run`` in a worker thread."""
        return await asyncio.to_thread(
            self.agent.run,
            task=self._build_task_prompt(question, task_id)
        )

    def _validate_and_format(self, result: object, context: dict) -> str:
        """Final validation and formatting pipeline.

        Non-string final answers (e.g. numbers) are converted with ``str``
        instead of being rejected. ``None`` or blank output is an error.
        Output longer than 4096 characters is truncated.
        """
        try:
            if result is None:
                raise ValueError("Empty agent response")
            text = result if isinstance(result, str) else str(result)
            if not text.strip():
                raise ValueError("Empty agent response")
            # Length validation
            if len(text) > 4096:
                text = text[:4090] + "..."
            context["validation_checks"].append({
                "type": "success",
                "timestamp": datetime.now().isoformat()
            })
            return text
        except Exception as e:
            return self._handle_error(e, context)

    def _handle_error(self, error: Exception, context: dict) -> str:
        """Record ``error`` in ``context`` and return a user-facing error string."""
        error_type = error.__class__.__name__
        error_msg = str(error)
        context["validation_checks"].append({
            "type": "error",
            "error_type": error_type,
            "message": error_msg,
            "timestamp": datetime.now().isoformat()
        })
        return f"AGENT ERROR: {error_type} - {error_msg}"