# Test_Magus/agent.py
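"""Agent definition for Test_Magus.

Wires a smolagents CodeAgent (Gemini via LiteLLM) together with helpers for
data loading, schema validation, domain routing, Wayback Machine lookups and
cross-verified search. Project-specific tools referenced below
(DownloadTaskAttachmentTool, ExcelReaderTool, ImageAnalyzerTool,
SpeechToTextTool, DuckDuckGoSearchTool, WikipediaSearchTool, ArxivSearchTool)
are assumed to be defined or imported elsewhere in the project.
"""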
from smolagents import CodeAgent, LiteLLMModel, Tool
from token_bucket import Limiter, MemoryStorage
from tenacity import retry, stop_after_attempt, wait_exponential
from sentence_transformers import SentenceTransformer
from bs4 import BeautifulSoup
from datetime import datetime
from typing import Optional
import pandas as pd
import numpy as np
import requests
import asyncio
import whisper
import json
import yaml
import os
import re
# --------------------------
# Universal Data Loader
# --------------------------
class UniversalLoader(Tool):
    # smolagents tools need these class attributes to be registered with the agent.
    name = "universal_loader"
    description = "Loads data from a task attachment, a URL, or a local file and returns it as text/markdown."
    inputs = {
        "source": {"type": "string", "description": "'attachment', a URL, or a local file path"},
        "task_id": {"type": "string", "description": "Task id used to fetch the attachment", "nullable": True},
    }
    output_type = "string"

    def __init__(self):
        super().__init__()
        # Dispatch table: file extension -> loader method.
        self.file_loaders = {
            'xlsx': self._load_excel,
            'csv': self._load_csv,
            'png': self._load_image,
            'mp3': self._load_audio
        }

    def forward(self, source: str, task_id: Optional[str] = None):
        try:
            if source == "attachment":
                file_path = self._download_attachment(task_id)
                return self._load_by_extension(file_path)
            elif source.startswith("http"):
                return self._load_url(source)
            # Anything else is treated as a local file path.
            return self._load_by_extension(source)
        except Exception:
            return self._fallback_search(source)

    def _download_attachment(self, task_id: str):
        return DownloadTaskAttachmentTool()(task_id)

    def _load_by_extension(self, path: str):
        ext = path.split('.')[-1].lower()
        loader = self.file_loaders.get(ext, self._load_text)
        return loader(path)

    def _load_url(self, url: str):
        # Minimal URL loader: fetch the page and strip markup with BeautifulSoup.
        response = requests.get(url, timeout=30)
        response.raise_for_status()
        return BeautifulSoup(response.text, "html.parser").get_text(separator="\n", strip=True)

    def _load_text(self, path: str):
        # Default loader for unrecognised extensions: read the file as plain text.
        with open(path, encoding="utf-8", errors="replace") as f:
            return f.read()

    def _load_excel(self, path: str):
        return ExcelReaderTool().forward(path)

    def _load_csv(self, path: str):
        return pd.read_csv(path).to_markdown()

    def _load_image(self, path: str):
        return ImageAnalyzerTool().forward(path)

    def _load_audio(self, path: str):
        return SpeechToTextTool().forward(path)

    def _fallback_search(self, query: str):
        # CrossVerifiedSearch takes a single query argument.
        return CrossVerifiedSearch()(query)
# --------------------------
# Validation Pipeline
# --------------------------
class ValidationPipeline:
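    """Column-level validation for tabular data.

    ``validate`` takes a DataFrame and a schema of the form
    ``{"column": {"type": "numeric" | "temporal" | "categorical"}}`` and returns
    ``valid``, ``errors`` and a confidence score of ``1 - errors / len(schema)``.
    """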
    VALIDATORS = {
        'numeric': {
            'check': lambda x: pd.api.types.is_numeric_dtype(x),
            'error': "Non-numeric value found in numeric field"
        },
        'temporal': {
            'check': lambda x: pd.api.types.is_datetime64_any_dtype(x),
            'error': "Invalid date format detected"
        },
        'categorical': {
            # Reduce the element-wise isin() Series to a single bool so it can be
            # used in a truth test (the original Series raised an ambiguity error).
            'check': lambda x: bool(x.dropna().isin(x.dropna().unique()).all()),
            'error': "Invalid category value detected"
        }
    }

    def validate(self, data, schema: dict):
        errors = []
        for field, config in schema.items():
            validator = self.VALIDATORS.get(config['type'])
            if validator is None:
                errors.append(f"{field}: Unknown validation type '{config['type']}'")
                continue
            if not validator['check'](data[field]):
                errors.append(f"{field}: {validator['error']}")
        return {
            'valid': len(errors) == 0,
            'errors': errors,
            'confidence': 1.0 - (len(errors) / max(len(schema), 1))
        }
# --------------------------
# Tool Router
# --------------------------
class ToolRouter:
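    """Routes a question to a coarse domain by embedding similarity.

    The question and a few per-domain keyword strings are encoded with
    all-MiniLM-L6-v2; the domain whose embedding has the largest dot product
    with the question embedding is returned.
    """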
def __init__(self):
self.encoder = SentenceTransformer('all-MiniLM-L6-v2')
self.domain_embeddings = {
'music': self.encoder.encode("music album release artist track"),
'sports': self.encoder.encode("athlete team score tournament"),
'science': self.encoder.encode("chemistry biology physics research")
}
def route(self, question: str):
query_embed = self.encoder.encode(question)
scores = {
domain: np.dot(query_embed, domain_embed)
for domain, domain_embed in self.domain_embeddings.items()
}
return max(scores, key=scores.get)
# --------------------------
# Temporal Search
# --------------------------
class HistoricalSearch:
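    """Looks up archived page snapshots via the Wayback Machine availability API."""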
    @retry(stop=stop_after_attempt(3), wait=wait_exponential(multiplier=1, min=2, max=10))
    def get_historical_content(self, url: str, target_date: str):
        # target_date follows the API's YYYYMMDD (optionally YYYYMMDDhhmmss) format.
        return requests.get(
            "http://archive.org/wayback/available",
            params={"url": url, "timestamp": target_date},
            timeout=30,
        ).json()
# --------------------------
# Enhanced Excel Reader
# --------------------------
class EnhancedExcelReader(Tool):
    # smolagents tools need these class attributes to be registered with the agent.
    name = "enhanced_excel_reader"
    description = "Reads an Excel file, validates it against an auto-detected schema, and returns it as markdown."
    inputs = {"path": {"type": "string", "description": "Path to the .xlsx file"}}
    output_type = "string"

    def forward(self, path: str):
        df = pd.read_excel(path)
        validation = ValidationPipeline().validate(df, self._detect_schema(df))
        if not validation['valid']:
            raise ValueError(f"Data validation failed: {validation['errors']}")
        return df.to_markdown()
def _detect_schema(self, df: pd.DataFrame):
schema = {}
for col in df.columns:
dtype = 'categorical'
if pd.api.types.is_numeric_dtype(df[col]):
dtype = 'numeric'
elif pd.api.types.is_datetime64_any_dtype(df[col]):
dtype = 'temporal'
schema[col] = {'type': dtype}
return schema
# --------------------------
# Cross-Verified Search
# --------------------------
class CrossVerifiedSearch:
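    """Queries several search tools and returns the most common answer.

    Results are bucketed by their first 100 characters (a crude similarity key)
    and the largest bucket wins.
    """

    # These search tools are assumed to be imported or defined elsewhere in the
    # project; DuckDuckGoSearchTool, for example, ships with smolagents.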
SOURCES = [
DuckDuckGoSearchTool(),
WikipediaSearchTool(),
ArxivSearchTool()
]
def __call__(self, query: str):
results = []
for source in self.SOURCES:
try:
results.append(source(query))
            except Exception:
                # Skip sources that error out; consensus is taken over the rest.
                continue
return self._consensus(results)
def _consensus(self, results):
# Simple majority voting implementation
counts = {}
for result in results:
key = str(result)[:100] # Simple hash for demo
counts[key] = counts.get(key, 0) + 1
return max(counts, key=counts.get)
# --------------------------
# Main Agent Class
# --------------------------
class MagAgent:
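    """Top-level wrapper: builds the prompt, runs the CodeAgent, validates output.

    Requires a ``prompts.yaml`` file with at least a ``base`` template and a
    ``GEMINI_KEY`` environment variable for the Gemini backend used by LiteLLM.
    """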
def __init__(self, rate_limiter: Optional[Limiter] = None):
self.rate_limiter = rate_limiter
self.model = LiteLLMModel(
model_id="gemini/gemini-1.5-flash",
api_key=os.environ.get("GEMINI_KEY"),
max_tokens=8192
)
        # CodeAgent only accepts smolagents Tool instances; the remaining helpers
        # are kept as plain attributes and used by this wrapper directly.
        self.tools = [
            UniversalLoader(),
            EnhancedExcelReader(),
        ]
        self.router = ToolRouter()
        self.cross_search = CrossVerifiedSearch()
        self.historical_search = HistoricalSearch()
with open("prompts.yaml") as f:
self.prompt_templates = yaml.safe_load(f)
self.agent = CodeAgent(
model=self.model,
tools=self.tools,
verbosity_level=2,
prompt_templates=self.prompt_templates,
max_steps=20
)
    async def __call__(self, question: str, task_id: str) -> str:
        # Build the context up front so the error handler can always reference it.
        context = {
            "question": question,
            "task_id": task_id,
            "validation_checks": []
        }
        try:
result = await asyncio.to_thread(
self.agent.run,
task=self._build_task_prompt(question, task_id)
)
validated = self._validate_result(result, context)
return self._format_output(validated)
except Exception as e:
return self._handle_error(e, context)
    def _build_task_prompt(self, question: str, task_id: str) -> str:
        base_prompt = self.prompt_templates['base']
        # Reuse the router created in __init__ rather than re-loading the encoder.
        domain = self.router.route(question)
        return f"""
{base_prompt}
**Domain Classification**: {domain}
**Required Validation**: {self._get_validation_requirements(domain)}
Question: {question}
{self._attachment_prompt(task_id)}
"""
def _validate_result(self, result: str, context: dict) -> dict:
validation_rules = {
'numeric': r'\d+',
'temporal': r'\d{4}-\d{2}-\d{2}',
'categorical': r'^[A-Za-z]+$'
}
        validations = {}
        result_text = str(result)  # agent.run may return non-string results
        for v_type, pattern in validation_rules.items():
            validations[v_type] = bool(re.search(pattern, result_text))
confidence = sum(validations.values()) / len(validations)
context['validation_checks'] = validations
return {
'result': result,
'confidence': confidence,
'validations': validations
}
def _format_output(self, validated: dict) -> str:
if validated['confidence'] < 0.7:
return "Unable to verify answer with sufficient confidence"
return validated['result']
def _handle_error(self, error: Exception, context: dict) -> str:
error_info = {
"type": type(error).__name__,
"message": str(error),
"context": context
}
return json.dumps(error_info)
def _get_validation_requirements(self, domain: str) -> str:
requirements = {
'music': "Verify release dates against multiple sources",
'sports': "Cross-check athlete statistics with official records",
'science': "Validate against peer-reviewed sources"
}
return requirements.get(domain, "Standard fact verification")
def _attachment_prompt(self, task_id: str) -> str:
if task_id:
return f"Attachment available with task_id: {task_id}"
return "No attachments provided"