import asyncio
import json
import os
import re
from datetime import datetime
from typing import Optional

import numpy as np
import pandas as pd
import requests
import whisper
import yaml
from bs4 import BeautifulSoup
from langchain_community.document_loaders import ArxivLoader
from sentence_transformers import SentenceTransformer
from smolagents import CodeAgent, LiteLLMModel, Tool, DuckDuckGoSearchTool, WikipediaSearchTool
from tenacity import retry, stop_after_attempt, wait_exponential
from token_bucket import Limiter, MemoryStorage


# --------------------------
# Core Tools from Previous Implementation
# --------------------------
class VisitWebpageTool(Tool):
    """Fetch a URL and return its body converted to markdown (or plain text).

    Network errors never propagate: after the retried fetch gives up, the
    error is returned as an ``"Error fetching webpage: ..."`` string.
    """

    name = "visit_webpage"
    description = "Visits a webpage and returns its content as markdown"
    inputs = {'url': {'type': 'string', 'description': 'The URL to visit'}}
    output_type = "string"

    def forward(self, url: str) -> str:
        try:
            html = self._fetch(url)
        except Exception as e:
            return f"Error fetching webpage: {str(e)}"
        return self._to_markdown(html)

    # The retry lives on the raw fetch, not on ``forward``: ``forward`` turns
    # every exception into a string, which previously meant the decorator
    # never saw a failure and never retried. ``reraise=True`` surfaces the last
    # real exception instead of tenacity's opaque ``RetryError``.
    @retry(stop=stop_after_attempt(3), wait=wait_exponential(multiplier=1, min=2, max=10), reraise=True)
    def _fetch(self, url: str) -> str:
        """GET ``url`` with a 30 s timeout and return the response text; raise on HTTP errors."""
        response = requests.get(url, timeout=30)
        response.raise_for_status()
        return response.text

    @staticmethod
    def _to_markdown(html: str) -> str:
        """Convert HTML to markdown, falling back to BeautifulSoup plain text.

        ``markdownify`` was called but never imported, so every call failed with
        a NameError; it is now an optional dependency imported on demand.
        """
        try:
            from markdownify import markdownify
        except ImportError:
            return BeautifulSoup(html, "html.parser").get_text("\n").strip()
        return markdownify(html).strip()


class DownloadTaskAttachmentTool(Tool):
    """Download a task's attachment from the task API into ``downloads/``.

    Returns the local file path. The extension is inferred from the response
    ``Content-Type`` (``.bin`` when unknown).

    Raises:
        RuntimeError: if the request, HTTP status or file write fails.
    """

    name = "download_file"
    description = "Downloads files from the task API"
    inputs = {'task_id': {'type': 'string', 'description': 'The task ID to download'}}
    output_type = "string"

    def forward(self, task_id: str) -> str:
        api_url = os.getenv("TASK_API_URL", "https://agents-course-unit4-scoring.hf.space")
        file_url = f"{api_url}/files/{task_id}"
        try:
            # ``with`` closes the streamed connection even if writing fails.
            with requests.get(file_url, stream=True, timeout=30) as response:
                response.raise_for_status()

                # File type detection
                extension = self._get_extension(response.headers.get('Content-Type', ''))

                os.makedirs("downloads", exist_ok=True)
                # basename() keeps an ID such as "../x" from escaping downloads/.
                safe_id = os.path.basename(str(task_id))
                file_path = f"downloads/{safe_id}{extension}"
                with open(file_path, "wb") as f:
                    for chunk in response.iter_content(chunk_size=8192):
                        f.write(chunk)
        except Exception as e:
            raise RuntimeError(f"Download failed: {str(e)}") from e
        return file_path

    def _get_extension(self, content_type: str) -> str:
        """Map a Content-Type header (parameters such as charset ignored) to a file extension."""
        type_map = {
            'image/png': '.png',
            'image/jpeg': '.jpg',
            'application/vnd.openxmlformats-officedocument.spreadsheetml.sheet': '.xlsx',
            'application/vnd.ms-excel': '.xls',
            'audio/mpeg': '.mp3',
            'audio/wav': '.wav',
            'audio/x-wav': '.wav',
            'application/pdf': '.pdf',
            'text/x-python': '.py',
            'text/plain': '.txt',
            'text/csv': '.csv',
            'application/json': '.json',
        }
        # Header values are case-insensitive and may carry surrounding spaces.
        media_type = content_type.split(';')[0].strip().lower()
        return type_map.get(media_type, '.bin')


class ArxivSearchTool(Tool):
    """Search Arxiv and summarise up to three matching papers.

    Failures are returned as ``"Arxiv search failed: ..."`` strings rather than raised.
    """

    name = "arxiv_search"
    description = "Searches academic papers on Arxiv"
    inputs = {'query': {'type': 'string', 'description': 'Search query'}}
    output_type = "string"

    def forward(self, query: str) -> str:
        try:
            docs = ArxivLoader(query=query, load_max_docs=3).load()
        except Exception as e:
            return f"Arxiv search failed: {str(e)}"
        if not docs:
            return f"No Arxiv results found for: {query}"
        return "\n\n".join(self._format_doc(doc) for doc in docs)

    @staticmethod
    def _format_doc(doc) -> str:
        """Render title, authors and the first 500 characters of one loaded document."""
        # .get(): a missing metadata key should not discard the other results.
        title = doc.metadata.get('Title', 'Unknown title')
        authors = doc.metadata.get('Authors', 'Unknown authors')
        return (
            f"Title: {title}\n"
            f"Authors: {authors}\n"
            f"Summary: {doc.page_content[:500]}..."
        )


class SpeechToTextTool(Tool):
    """Transcribe an audio file with the Whisper "base" model."""

    name = "speech_to_text"
    description = "Converts audio files to text"
    inputs = {'audio_path': {'type': 'string', 'description': 'Path to audio file'}}
    output_type = "string"

    def __init__(self):
        # The Tool base initialiser sets up state that Tool.__call__ relies on;
        # skipping it broke calls made through the agent.
        super().__init__()
        # Loaded on first transcription: several instances of this tool can be
        # created, and most runs never need audio.
        self.model = None

    def forward(self, audio_path: str) -> str:
        if not os.path.exists(audio_path):
            return f"File not found: {audio_path}"
        if self.model is None:
            self.model = whisper.load_model("base")
        return self.model.transcribe(audio_path).get("text", "")


# --------------------------
# Enhanced Tools with Validation
# --------------------------
class ValidatedExcelReader(Tool):
    """Read an Excel file and return it as markdown, optionally validating columns.

    When ``schema`` is given, it is checked with ``ValidationPipeline`` (defined
    later in this module).

    Raises:
        ValueError: if schema validation fails.
    """

    name = "excel_reader"
    description = "Reads and validates Excel files"
    inputs = {
        'file_path': {'type': 'string', 'description': 'Path to Excel file'},
        'schema': {'type': 'object', 'description': 'Validation schema', 'nullable': True}
    }
    output_type = "string"

    def forward(self, file_path: str, schema: Optional[dict] = None) -> str:
        df = pd.read_excel(file_path)
        if schema:
            validation = ValidationPipeline().validate(df, schema)
            if not validation['valid']:
                raise ValueError(f"Data validation failed: {validation['errors']}")
        return df.to_markdown()
# --------------------------
# Integrated Universal Loader
# --------------------------
class UniversalLoader(Tool):
    """Dispatch a load request to a sub-tool, falling back to a cross-verified search.

    ``source`` selects a loader from ``self.loaders`` ('excel', 'audio', 'arxiv',
    'web'), and ``task_id`` is passed through as that loader's argument.
    ``source == "attachment"`` downloads the task file first and picks a loader
    from its extension. Any exception falls back to ``CrossVerifiedSearch``.
    """

    name = "universal_loader"
    description = "Loads various file types and web content using appropriate sub-tools."
    inputs = {
        'source': {
            'type': 'string',
            'description': 'Type of source to load (web/excel/audio/arxiv/attachment)'
        },
        'task_id': {
            'type': 'string',
            'description': 'Task ID for attachments',
            'nullable': True
        }
    }
    output_type = "string"

    # Attachment extension (lower-case, no dot) -> key in ``self.loaders``.
    EXTENSION_LOADERS = {
        'xlsx': 'excel',
        'xls': 'excel',
        'mp3': 'audio',
        'wav': 'audio',
        'm4a': 'audio',
    }
    # Attachments that are already text and can be returned verbatim.
    TEXT_EXTENSIONS = {'py', 'txt', 'csv', 'json', 'md'}

    def __init__(self):
        # The Tool base initialiser sets up state that Tool.__call__ relies on;
        # skipping it broke every call the agent made to this tool.
        super().__init__()
        self.loaders = {
            'excel': ValidatedExcelReader(),
            'audio': SpeechToTextTool(),
            'arxiv': ArxivSearchTool(),
            'web': VisitWebpageTool()
        }

    def forward(self, source: str, task_id: str = None) -> str:
        try:
            if source == "attachment":
                file_path = DownloadTaskAttachmentTool()(task_id)
                return self._load_by_type(file_path)
            return self.loaders[source].forward(task_id)
        except Exception:
            # Unknown source, failed download or loader error: degrade to search.
            return self._fallback(source, task_id)

    def _load_by_type(self, file_path: str) -> str:
        """Load a downloaded attachment with the loader matching its extension."""
        ext = os.path.splitext(file_path)[1].lstrip('.').lower()
        loader_key = self.EXTENSION_LOADERS.get(ext)
        if loader_key is not None:
            return self.loaders[loader_key].forward(file_path)
        if ext in self.TEXT_EXTENSIONS:
            with open(file_path, encoding='utf-8', errors='replace') as handle:
                return handle.read()
        # Previously PDFs went to the Arxiv *search* tool and everything else to
        # the web fetcher, both of which treated the local path as a query/URL.
        return f"Attachment downloaded to {file_path}; no loader available for '.{ext}' files."

    def _fallback(self, source: str, context: str) -> str:
        """Search for ``source`` plus ``context``, omitting a missing context."""
        query = f"{source} {context}" if context else source
        return CrossVerifiedSearch()(query)


# --------------------------
# Validation Pipeline
# --------------------------
class ValidationPipeline:
    """Validate DataFrame columns against a simple type schema.

    A schema maps column name -> config dict. The ``'type'`` key names one of
    ``VALIDATORS``. Categorical columns may also carry an optional
    ``'categories'`` iterable of allowed values.
    """

    VALIDATORS = {
        'numeric': {
            'check': lambda x: pd.api.types.is_numeric_dtype(x),
            'error': "Non-numeric value found in numeric field"
        },
        'temporal': {
            'check': lambda x: pd.api.types.is_datetime64_any_dtype(x),
            'error': "Invalid date format detected"
        },
        'categorical': {
            # Reduced to one bool: the old element-wise ``isin`` returned a Series,
            # and ``not Series`` raised "truth value is ambiguous". Without an
            # allow-list every non-null value is a valid category; membership in
            # ``config['categories']`` is checked in ``validate``.
            'check': lambda x: bool(x.dropna().isin(x.dropna().unique()).all()),
            'error': "Invalid category value detected"
        }
    }

    def validate(self, data, schema: dict):
        """Check every schema field in ``data``.

        Returns:
            dict with ``valid`` (bool), ``errors`` (list of "field: message"
            strings, at most one per field) and ``confidence`` (fraction of
            fields without errors; 1.0 for an empty schema).
        """
        errors = []
        for field, config in schema.items():
            type_name = config.get('type')
            validator = self.VALIDATORS.get(type_name)
            if validator is None:
                errors.append(f"{field}: Unknown validator type '{type_name}'")
                continue
            if field not in data:
                errors.append(f"{field}: Missing field")
                continue
            column = data[field]
            if not validator['check'](column):
                errors.append(f"{field}: {validator['error']}")
                continue
            allowed = config.get('categories')
            if type_name == 'categorical' and allowed is not None:
                if not column.dropna().isin(list(allowed)).all():
                    errors.append(f"{field}: {validator['error']}")
        return {
            'valid': len(errors) == 0,
            'errors': errors,
            # Guard: an empty schema previously raised ZeroDivisionError.
            'confidence': 1.0 - (len(errors) / len(schema)) if schema else 1.0
        }


# --------------------------
# Tool Router
# --------------------------
class ToolRouter:
    """Search back-ends by domain and classify questions by embedding similarity.

    NOTE(review): ``route`` returns 'music' / 'sports' / 'science', but ``forward``
    only special-cases 'academic' / 'general' / 'encyclopedic'. A routed domain
    therefore always takes the search-everything path; confirm the intended mapping.
    """

    def __init__(self):
        self.encoder = SentenceTransformer('all-MiniLM-L6-v2')
        # Unit-normalised so the dot product in ``route`` is cosine similarity.
        self.domain_embeddings = {
            domain: self.encoder.encode(keywords, normalize_embeddings=True)
            for domain, keywords in (
                ('music', "music album release artist track"),
                ('sports', "athlete team score tournament"),
                ('science', "chemistry biology physics research"),
            )
        }
        self.ddg = DuckDuckGoSearchTool()
        self.wiki = WikipediaSearchTool()
        self.arxiv = ArxivSearchTool()

    def forward(self, query: str, domain: str = None) -> str:
        """Smart search with domain prioritization.

        A known ``domain`` queries one source and returns its raw output.
        Otherwise the result is a JSON object mapping 'web' / 'wikipedia' /
        'arxiv' to each source's output, or to an "Error: ..." string for
        sources that raised.
        """
        single_source = {
            "academic": self.arxiv,
            "general": self.ddg,
            "encyclopedic": self.wiki,
        }.get(domain)
        if single_source is not None:
            return single_source(query)

        # Fallback: Search all sources; one failing back-end no longer discards the rest.
        results = {}
        for label, tool in (("web", self.ddg), ("wikipedia", self.wiki), ("arxiv", self.arxiv)):
            try:
                results[label] = tool(query)
            except Exception as exc:
                results[label] = f"Error: {exc}"
        return json.dumps(results)

    def route(self, question: str):
        """Return the domain whose keyword embedding is most similar to ``question``."""
        query_embed = self.encoder.encode(question, normalize_embeddings=True)
        scores = {
            domain: float(np.dot(query_embed, domain_embed))
            for domain, domain_embed in self.domain_embeddings.items()
        }
        return max(scores, key=scores.get)


# --------------------------
# Temporal Search
# --------------------------
class HistoricalSearch:
    """Query the Internet Archive Wayback Machine availability API."""

    WAYBACK_API = "https://archive.org/wayback/available"

    @retry(stop=stop_after_attempt(3), wait=wait_exponential(multiplier=1, min=2, max=10))
    def get_historical_content(self, url: str, target_date: str):
        """Return the API's JSON answer for the snapshot of ``url`` closest to ``target_date``.

        ``params=`` URL-encodes both values. The old hand-built query string
        contained a mis-decoded ``×tamp`` (from ``&timestamp``), so the date was
        never sent. HTTP errors now raise and are retried.
        """
        response = requests.get(
            self.WAYBACK_API,
            params={"url": url, "timestamp": target_date},
            timeout=30,
        )
        response.raise_for_status()
        return response.json()


# --------------------------
# Enhanced Excel Reader
# --------------------------
class EnhancedExcelReader(Tool):
    """Read an Excel file, infer a per-column schema, and validate it.

    Raises:
        ValueError: if validation of the inferred schema fails.
    """

    # Tool subclasses must declare these attributes; without them the tool
    # could not be instantiated.
    name = "enhanced_excel_reader"
    description = "Reads an Excel file, infers column types, validates them and returns markdown"
    inputs = {'path': {'type': 'string', 'description': 'Path to Excel file'}}
    output_type = "string"

    def forward(self, path: str):
        df = pd.read_excel(path)
        validation = ValidationPipeline().validate(df, self._detect_schema(df))
        if not validation['valid']:
            raise ValueError(f"Data validation failed: {validation['errors']}")
        return df.to_markdown()

    def _detect_schema(self, df: pd.DataFrame):
        """Classify each column as 'numeric', 'temporal' or (otherwise) 'categorical'."""
        schema = {}
        for col in df.columns:
            if pd.api.types.is_numeric_dtype(df[col]):
                dtype = 'numeric'
            elif pd.api.types.is_datetime64_any_dtype(df[col]):
                dtype = 'temporal'
            else:
                dtype = 'categorical'
            schema[col] = {'type': dtype}
        return schema


# --------------------------
# Cross-Verified Search
# --------------------------
class CrossVerifiedSearch(Tool):
    """Query several search tools and return the most frequent answer.

    ``SOURCES`` is instantiated once, when the class is defined, and shared by
    all instances.
    """

    name = "cross_verified_search"
    description = "Searches multiple sources and returns consensus results."
    inputs = {'query': {'type': 'string', 'description': 'Search query'}}
    output_type = "string"

    SOURCES = [
        DuckDuckGoSearchTool(),
        WikipediaSearchTool(),
        ArxivSearchTool()
    ]

    def __call__(self, query: str):
        # Kept for callers that invoke the instance directly; shares forward's logic.
        return self.forward(query)

    def forward(self, query: str) -> str:
        results = []
        failures = []
        for source in self.SOURCES:
            try:
                results.append(source(query))
            except Exception as e:
                # Recorded (not silently dropped) so a total failure is explained.
                failures.append(f"{getattr(source, 'name', type(source).__name__)}: {e}")
        if not results:
            return "All search sources failed: " + "; ".join(failures)
        return self._consensus(results)

    def _consensus(self, results):
        """Return the full text of the result whose 100-character prefix occurs most often.

        Ties go to the earliest result. The previous version returned the
        truncated 100-character key instead of the result itself.
        """
        counts = {}
        representative = {}
        for result in results:
            text = str(result)
            key = text[:100]  # Simple hash for demo
            counts[key] = counts.get(key, 0) + 1
            representative.setdefault(key, text)
        if not counts:
            return "No results found."
        return representative[max(counts, key=counts.get)]


# --------------------------
# Main Agent Class (Integrated)
# --------------------------
class MagAgent:
    """Async wrapper around a smolagents ``CodeAgent`` configured with this module's tools."""

    def __init__(self, rate_limiter: Optional[Limiter] = None):
        # NOTE(review): stored but not consulted anywhere in this class.
        self.rate_limiter = rate_limiter
        self.model = LiteLLMModel(
            model_id="gemini/gemini-1.5-flash",
            api_key=os.environ.get("GEMINI_KEY"),
            max_tokens=8192
        )
        self.tools = [
            UniversalLoader(),
            CrossVerifiedSearch(),  # Replaces individual search tools
            ValidatedExcelReader(),
            VisitWebpageTool(),
            DownloadTaskAttachmentTool(),
            SpeechToTextTool(),
        ]
        self.prompt_templates = self._load_prompt_templates("prompts.yaml")
        self.agent = CodeAgent(
            model=self.model,
            tools=self.tools,
            verbosity_level=2,
            # No custom templates -> None, so CodeAgent uses its built-in prompts.
            prompt_templates=self.prompt_templates or None,
            max_steps=20,
            add_base_tools=False
        )

    @staticmethod
    def _load_prompt_templates(path: str) -> dict:
        """Load prompt templates from YAML; return {} if the file is missing or empty.

        Malformed YAML still raises so configuration mistakes are not hidden.
        """
        try:
            with open(path, 'r', encoding='utf-8') as stream:
                return yaml.safe_load(stream) or {}
        except FileNotFoundError:
            return {}

    async def __call__(self, question: str, task_id: str) -> str:
        # Built before the try so the error handler always has a context
        # (previously it could be unbound inside ``except``).
        context = self._create_context(question, task_id)
        try:
            result = await self._execute_agent(question, task_id)
        except Exception as e:
            return self._handle_error(e, context)
        return self._validate_and_format(result, context)

    def _create_context(self, question: str, task_id: str) -> dict:
        """Create the per-call record that validation and error handling append to."""
        return {
            "question": question,
            "task_id": task_id,
            "timestamp": datetime.now().isoformat(),
            "validation_checks": []
        }

    def _build_task_prompt(self, question: str, task_id: str) -> str:
        """Constructs task-specific prompts using templates.

        Falls back to the bare question when no ``task_prompt`` template is
        configured; previously the question was silently dropped.
        """
        base_template = self.prompt_templates.get("base_prompt", "")
        task_template = (self.prompt_templates.get("task_prompt") or "{question}").format(
            question=question,
            task_id=task_id,
            current_date=datetime.now().strftime("%Y-%m-%d")
        )
        return f"{base_template}\n\n{task_template}"

    async def _execute_agent(self, question: str, task_id: str) -> str:
        """Run the blocking ``CodeAgent.run`` in a worker thread."""
        return await asyncio.to_thread(
            self.agent.run,
            task=self._build_task_prompt(question, task_id)
        )

    def _validate_and_format(self, result, context: dict) -> str:
        """Final validation and formatting pipeline.

        Non-string answers (e.g. numbers) are converted to ``str``. Answers longer
        than 4096 characters are truncated. ``None`` or an empty string is
        reported through ``_handle_error``.
        """
        try:
            # Basic result validation: ``None`` check, not truthiness, so 0 is valid.
            if result is None:
                raise ValueError("Empty agent response")

            # Agent final answers are not always strings; coerce instead of rejecting.
            if not isinstance(result, str):
                result = str(result)
            if not result:
                raise ValueError("Empty agent response")

            # Length validation
            if len(result) > 4096:
                result = result[:4090] + "..."

            # Record successful validation
            context["validation_checks"].append({
                "type": "success",
                "timestamp": datetime.now().isoformat()
            })

            return result

        except Exception as e:
            return self._handle_error(e, context)

    def _handle_error(self, error: Exception, context: dict) -> str:
        """Record ``error`` in ``context`` and return a user-facing error string."""
        error_type = error.__class__.__name__
        error_msg = str(error)

        # Store error in validation context
        context["validation_checks"].append({
            "type": "error",
            "error_type": error_type,
            "message": error_msg,
            "timestamp": datetime.now().isoformat()
        })

        # Format user-facing message
        return f"AGENT ERROR: {error_type} - {error_msg}"