from smolagents import CodeAgent, LiteLLMModel, Tool, DuckDuckGoSearchTool, WikipediaSearchTool
from token_bucket import Limiter, MemoryStorage
from tenacity import retry, stop_after_attempt, wait_exponential
from langchain_community.document_loaders import ArxivLoader
from sentence_transformers import SentenceTransformer
from markdownify import markdownify
from bs4 import BeautifulSoup
from datetime import datetime
import pandas as pd
import numpy as np
import requests
import asyncio
import whisper
import yaml
import os
import re
import json
from typing import Optional

# --------------------------
# Core Tools from Previous Implementation
# --------------------------
class VisitWebpageTool(Tool):
    name = "visit_webpage"
    description = "Visits a webpage and returns its content as markdown"
    inputs = {'url': {'type': 'string', 'description': 'The URL to visit'}}
    output_type = "string"

    def forward(self, url: str) -> str:
        try:
            response = requests.get(url, timeout=30)
            response.raise_for_status()
            return markdownify(response.text).strip()
        except Exception as e:
            return f"Error fetching webpage: {str(e)}"

class DownloadTaskAttachmentTool(Tool):
    name = "download_file"
    description = "Downloads files from the task API"
    inputs = {'task_id': {'type': 'string', 'description': 'The task ID to download'}}
    output_type = "string"

    def forward(self, task_id: str) -> str:
        api_url = os.getenv("TASK_API_URL", "https://agents-course-unit4-scoring.hf.space")
        file_url = f"{api_url}/files/{task_id}"
        try:
            response = requests.get(file_url, stream=True, timeout=30)
            response.raise_for_status()
            # File type detection
            content_type = response.headers.get('Content-Type', '')
            extension = self._get_extension(content_type)
            os.makedirs("downloads", exist_ok=True)
            file_path = f"downloads/{task_id}{extension}"
            with open(file_path, "wb") as f:
                for chunk in response.iter_content(chunk_size=8192):
                    f.write(chunk)
            return file_path
        except Exception as e:
            raise RuntimeError(f"Download failed: {str(e)}")

    def _get_extension(self, content_type: str) -> str:
        type_map = {
            'image/png': '.png',
            'image/jpeg': '.jpg',
            'application/vnd.openxmlformats-officedocument.spreadsheetml.sheet': '.xlsx',
            'audio/mpeg': '.mp3',
            'application/pdf': '.pdf',
            'text/x-python': '.py'
        }
        return type_map.get(content_type.split(';')[0], '.bin')

class ArxivSearchTool(Tool):
    name = "arxiv_search"
    description = "Searches academic papers on Arxiv"
    inputs = {'query': {'type': 'string', 'description': 'Search query'}}
    output_type = "string"

    def forward(self, query: str) -> str:
        try:
            loader = ArxivLoader(query=query, load_max_docs=3)
            docs = loader.load()
            return "\n\n".join([
                f"Title: {doc.metadata['Title']}\n"
                f"Authors: {doc.metadata['Authors']}\n"
                f"Summary: {doc.page_content[:500]}..."
                for doc in docs
            ])
        except Exception as e:
            return f"Arxiv search failed: {str(e)}"

class SpeechToTextTool(Tool):
    name = "speech_to_text"
    description = "Converts audio files to text"
    inputs = {'audio_path': {'type': 'string', 'description': 'Path to audio file'}}
    output_type = "string"

    def __init__(self):
        super().__init__()
        self.model = whisper.load_model("base")

    def forward(self, audio_path: str) -> str:
        if not os.path.exists(audio_path):
            return f"File not found: {audio_path}"
        return self.model.transcribe(audio_path).get("text", "")

# --------------------------
# Enhanced Tools with Validation
# --------------------------
class ValidatedExcelReader(Tool):
    name = "excel_reader"
    description = "Reads and validates Excel files"
    inputs = {
        'file_path': {'type': 'string', 'description': 'Path to Excel file'},
        'schema': {'type': 'object', 'description': 'Validation schema', 'nullable': True}
    }
    output_type = "string"

    def forward(self, file_path: str, schema: dict = None) -> str:
        df = pd.read_excel(file_path)
        if schema:
            validation = ValidationPipeline().validate(df, schema)
            if not validation['valid']:
                raise ValueError(f"Data validation failed: {validation['errors']}")
        return df.to_markdown()

# --------------------------
# Integrated Universal Loader
# --------------------------
class UniversalLoader(Tool):
    name = "universal_loader"
    description = "Loads various file types and web content using appropriate sub-tools."
    inputs = {
        'source': {
            'type': 'string',
            'description': 'Type of source to load (web/excel/audio/arxiv)'
        },
        'task_id': {
            'type': 'string',
            'description': 'Task ID for attachments',
            'nullable': True
        }
    }
    output_type = "string"

    def __init__(self):
        super().__init__()
        self.loaders = {
            'excel': ValidatedExcelReader(),
            'audio': SpeechToTextTool(),
            'arxiv': ArxivSearchTool(),
            'web': VisitWebpageTool()
        }

    def forward(self, source: str, task_id: str = None) -> str:
        try:
            if source == "attachment":
                file_path = DownloadTaskAttachmentTool()(task_id)
                return self._load_by_type(file_path)
            return self.loaders[source].forward(task_id)
        except Exception:
            return self._fallback(source, task_id)

    def _load_by_type(self, file_path: str) -> str:
        # Route downloaded attachments to a loader based on file extension
        ext = file_path.split('.')[-1].lower()
        loader_map = {
            'xlsx': 'excel',
            'mp3': 'audio',
            'pdf': 'arxiv'
        }
        return self.loaders[loader_map.get(ext, 'web')].forward(file_path)

    def _fallback(self, source: str, context: str) -> str:
        return CrossVerifiedSearch()(f"{source} {context}")

# --------------------------
# Validation Pipeline
# --------------------------
class ValidationPipeline:
    VALIDATORS = {
        'numeric': {
            'check': lambda x: pd.api.types.is_numeric_dtype(x),
            'error': "Non-numeric value found in numeric field"
        },
        'temporal': {
            'check': lambda x: pd.api.types.is_datetime64_any_dtype(x),
            'error': "Invalid date format detected"
        },
        'categorical': {
            # Reduced with .all() so the check yields a single boolean, not a Series
            'check': lambda x: bool(x.dropna().isin(x.dropna().unique()).all()),
            'error': "Invalid category value detected"
        }
    }

    def validate(self, data, schema: dict):
        errors = []
        for field, config in schema.items():
            validator = self.VALIDATORS.get(config['type'])
            if validator is None:
                # Skip fields whose declared type has no registered validator
                continue
            if not validator['check'](data[field]):
                errors.append(f"{field}: {validator['error']}")
        return {
            'valid': len(errors) == 0,
            'errors': errors,
            'confidence': (1.0 - len(errors) / len(schema)) if schema else 1.0
        }
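
# Illustrative sketch (not part of the original implementation): how ValidationPipeline is
# expected to be called. The column names "amount" and "recorded_at" are hypothetical.
def _example_validation_usage() -> dict:
    frame = pd.DataFrame({
        "amount": [12.5, 7.0, 3.2],
        "recorded_at": pd.to_datetime(["2024-01-01", "2024-02-01", "2024-03-01"])
    })
    schema = {
        "amount": {"type": "numeric"},
        "recorded_at": {"type": "temporal"}
    }
    # Returns {'valid': True, 'errors': [], 'confidence': 1.0} for this well-formed frame
    return ValidationPipeline().validate(frame, schema)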

# --------------------------
# Tool Router
# --------------------------
class ToolRouter:
    def __init__(self):
        self.encoder = SentenceTransformer('all-MiniLM-L6-v2')
        self.domain_embeddings = {
            'music': self.encoder.encode("music album release artist track"),
            'sports': self.encoder.encode("athlete team score tournament"),
            'science': self.encoder.encode("chemistry biology physics research")
        }
        self.ddg = DuckDuckGoSearchTool()
        self.wiki = WikipediaSearchTool()
        self.arxiv = ArxivSearchTool()

    def forward(self, query: str, domain: str = None) -> str:
        """Smart search with domain prioritization"""
        if domain == "academic":
            return self.arxiv(query)
        elif domain == "general":
            return self.ddg(query)
        elif domain == "encyclopedic":
            return self.wiki(query)
        # Fallback: Search all sources
        results = {
            "web": self.ddg(query),
            "wikipedia": self.wiki(query),
            "arxiv": self.arxiv(query)
        }
        return json.dumps(results)

    def route(self, question: str):
        query_embed = self.encoder.encode(question)
        scores = {
            domain: np.dot(query_embed, domain_embed)
            for domain, domain_embed in self.domain_embeddings.items()
        }
        return max(scores, key=scores.get)
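
# Illustrative sketch (not part of the original implementation): routing a question to a
# domain, then running the all-sources fallback search. The question text is made up.
def _example_router_usage() -> str:
    router = ToolRouter()
    question = "Which album did the band release in 1979?"
    domain = router.route(question)  # expected to score highest on 'music' for this phrasing
    print(f"Routed domain: {domain}")
    # The embedding domains ('music'/'sports'/'science') differ from forward()'s explicit
    # domains ('academic'/'general'/'encyclopedic'), so run the combined fallback search here.
    return router.forward(question)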

# --------------------------
# Temporal Search
# --------------------------
class HistoricalSearch:
    def get_historical_content(self, url: str, target_date: str):
        return requests.get(
            f"http://archive.org/wayback/available?url={url}&timestamp={target_date}"
        ).json()

# --------------------------
# Enhanced Excel Reader
# --------------------------
class EnhancedExcelReader(Tool):
    # Tool metadata added so the subclass passes smolagents' Tool validation;
    # the name and descriptions below are descriptive defaults.
    name = "enhanced_excel_reader"
    description = "Reads an Excel file and validates it against an auto-detected schema"
    inputs = {'path': {'type': 'string', 'description': 'Path to Excel file'}}
    output_type = "string"

    def forward(self, path: str):
        df = pd.read_excel(path)
        validation = ValidationPipeline().validate(df, self._detect_schema(df))
        if not validation['valid']:
            raise ValueError(f"Data validation failed: {validation['errors']}")
        return df.to_markdown()

    def _detect_schema(self, df: pd.DataFrame):
        # Infer a per-column validation schema from pandas dtypes
        schema = {}
        for col in df.columns:
            dtype = 'categorical'
            if pd.api.types.is_numeric_dtype(df[col]):
                dtype = 'numeric'
            elif pd.api.types.is_datetime64_any_dtype(df[col]):
                dtype = 'temporal'
            schema[col] = {'type': dtype}
        return schema

# --------------------------
# Cross-Verified Search
# --------------------------
class CrossVerifiedSearch(Tool):
    name = "cross_verified_search"
    description = "Searches multiple sources and returns consensus results."
    inputs = {'query': {'type': 'string', 'description': 'Search query'}}
    output_type = "string"

    SOURCES = [
        DuckDuckGoSearchTool(),
        WikipediaSearchTool(),
        ArxivSearchTool()
    ]

    def forward(self, query: str) -> str:
        # Query every source, skipping any that fail, then take the consensus.
        # (Tool.__call__ already delegates to forward, so no separate __call__ is needed.)
        results = []
        for source in self.SOURCES:
            try:
                results.append(source(query))
            except Exception:
                continue
        return self._consensus(results)

    def _consensus(self, results):
        # Simple majority voting implementation
        counts = {}
        for result in results:
            key = str(result)[:100]  # Simple hash for demo
            counts[key] = counts.get(key, 0) + 1
        if not counts:
            return "No results from any source"
        return max(counts, key=counts.get)
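
# Illustrative sketch (not part of the original implementation): the consensus step on its
# own, with made-up result strings, showing that the most frequent answer prefix wins.
def _example_consensus() -> str:
    sample_results = [
        "Paris is the capital of France.",
        "Paris is the capital of France.",
        "The capital of France is Paris."
    ]
    # Returns the (truncated) result text seen most often across sources
    return CrossVerifiedSearch()._consensus(sample_results)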

# --------------------------
# Main Agent Class (Integrated)
# --------------------------
class MagAgent:
    def __init__(self, rate_limiter: Optional[Limiter] = None):
        self.rate_limiter = rate_limiter
        self.model = LiteLLMModel(
            model_id="gemini/gemini-1.5-flash",
            api_key=os.environ.get("GEMINI_KEY"),
            max_tokens=8192
        )
        self.tools = [
            UniversalLoader(),
            CrossVerifiedSearch(),  # Replaces individual search tools
            ValidatedExcelReader(),
            VisitWebpageTool(),
            DownloadTaskAttachmentTool(),
            SpeechToTextTool(),
        ]
        try:
            with open("prompts.yaml") as f:
                self.prompt_templates = yaml.safe_load(f)
        except Exception:
            self.prompt_templates = {
                "base_prompt": "Default base prompt...",
                "task_prompt": "Default task template: {question}"
            }
        self.agent = CodeAgent(
            model=self.model,
            tools=self.tools,
            verbosity_level=2,
            prompt_templates=self.prompt_templates,
            max_steps=20,
            add_base_tools=False
        )

    async def __call__(self, question: str, task_id: str) -> str:
        # Build the context before the try block so error handling can always reference it
        context = self._create_context(question, task_id)
        try:
            result = await self._execute_agent(question, task_id)
            return self._validate_and_format(result, context)
        except Exception as e:
            return self._handle_error(e, context)

    # ... (keep other helper methods from previous implementation)

    def _create_context(self, question: str, task_id: str) -> dict:
        return {
            "question": question,
            "task_id": task_id,
            "timestamp": datetime.now().isoformat(),
            "validation_checks": []
        }

    def _build_task_prompt(self, question: str, task_id: str) -> str:
        """Constructs task-specific prompts using templates"""
        base_template = self.prompt_templates.get("base_prompt", "")
        task_template = self.prompt_templates.get("task_prompt", "").format(
            question=question,
            task_id=task_id,
            current_date=datetime.now().strftime("%Y-%m-%d")
        )
        return f"{base_template}\n\n{task_template}"

    async def _execute_agent(self, question: str, task_id: str) -> str:
        return await asyncio.to_thread(
            self.agent.run,
            task=self._build_task_prompt(question, task_id)
        )

    def _validate_and_format(self, result: str, context: dict) -> str:
        """Final validation and formatting pipeline"""
        try:
            # Basic result validation
            if not result:
                raise ValueError("Empty agent response")
            # Type checking
            if not isinstance(result, str):
                raise TypeError(f"Expected string response, got {type(result)}")
            # Length validation
            if len(result) > 4096:
                result = result[:4090] + "..."
            # Record successful validation
            context["validation_checks"].append({
                "type": "success",
                "timestamp": datetime.now().isoformat()
            })
            return result
        except Exception as e:
            return self._handle_error(e, context)

    def _handle_error(self, error: Exception, context: dict) -> str:
        """Central error handling with context-aware formatting"""
        error_type = error.__class__.__name__
        error_msg = str(error)
        # Store error in validation context
        context["validation_checks"].append({
            "type": "error",
            "error_type": error_type,
            "message": error_msg,
            "timestamp": datetime.now().isoformat()
        })
        # Format user-facing message
        return f"AGENT ERROR: {error_type} - {error_msg}"