# NOTE: removed a Hugging Face Space status banner ("Spaces: Sleeping") that
# was captured along with this file during extraction; it is not valid Python.
import asyncio
import json
import os
import re
from datetime import datetime
from typing import Optional

import numpy as np
import pandas as pd
import requests
import whisper
import yaml
from bs4 import BeautifulSoup
from sentence_transformers import SentenceTransformer
from smolagents import CodeAgent, LiteLLMModel, Tool
from tenacity import retry, stop_after_attempt, wait_exponential
from token_bucket import Limiter, MemoryStorage
# --------------------------
# Universal Data Loader
# --------------------------
class UniversalLoader(Tool):
    """Dispatching loader: routes a source (task attachment, URL, or plain
    query) to the matching loader, degrading to a web search on any failure.

    NOTE(review): smolagents ``Tool`` subclasses normally declare ``name``,
    ``description``, ``inputs`` and ``output_type`` class attributes and call
    ``super().__init__()`` — confirm against the Tool base class in use.
    """

    def __init__(self):
        # Map file extensions to their dedicated loader methods; anything
        # unlisted falls back to _load_text in _load_by_extension.
        self.file_loaders = {
            'xlsx': self._load_excel,
            'csv': self._load_csv,
            'png': self._load_image,
            'mp3': self._load_audio
        }

    def forward(self, source: str, task_id: str = None):
        """Load content from *source*.

        source: the literal string ``"attachment"`` (download via *task_id*),
        an ``http(s)`` URL, or any other text (treated as a search query).
        Returns the loaded content; on any loader error the fallback search
        result is returned instead of raising.
        """
        try:
            if source == "attachment":
                file_path = self._download_attachment(task_id)
                return self._load_by_extension(file_path)
            elif source.startswith("http"):
                return self._load_url(source)
            # BUG FIX: non-attachment, non-URL sources previously fell
            # through and returned None; treat them as search queries.
            return self._fallback_search(source, task_id)
        except Exception:
            # Best-effort: any loader failure degrades to a web search.
            return self._fallback_search(source, task_id)

    def _download_attachment(self, task_id: str):
        # Delegates to the project's attachment tool (defined elsewhere).
        return DownloadTaskAttachmentTool()(task_id)

    def _load_by_extension(self, path: str):
        """Pick a loader from the file extension; default to plain text."""
        ext = path.split('.')[-1].lower()
        loader = self.file_loaders.get(ext, self._load_text)
        return loader(path)

    def _load_url(self, url: str):
        # BUG FIX: forward() called _load_url but no such method existed, so
        # every URL load fell through to the fallback search. Fetch the page
        # and strip markup down to readable text.
        response = requests.get(url, timeout=30)
        response.raise_for_status()
        return BeautifulSoup(response.text, "html.parser").get_text(
            separator="\n", strip=True
        )

    def _load_excel(self, path: str):
        return ExcelReaderTool().forward(path)

    def _load_csv(self, path: str):
        return pd.read_csv(path).to_markdown()

    def _load_image(self, path: str):
        return ImageAnalyzerTool().forward(path)

    def _load_audio(self, path: str):
        return SpeechToTextTool().forward(path)

    def _fallback_search(self, query: str, context: str):
        # BUG FIX: CrossVerifiedSearch.__call__ accepts a single query;
        # passing context as a second positional argument raised TypeError.
        return CrossVerifiedSearch()(query)
# --------------------------
# Validation Pipeline
# --------------------------
class ValidationPipeline:
    """Schema-driven validation for pandas DataFrames.

    A schema maps field names to ``{'type': <validator key>}``; supported
    keys are ``numeric``, ``temporal`` and ``categorical``.
    """

    VALIDATORS = {
        'numeric': {
            'check': lambda x: pd.api.types.is_numeric_dtype(x),
            'error': "Non-numeric value found in numeric field"
        },
        'temporal': {
            'check': lambda x: pd.api.types.is_datetime64_any_dtype(x),
            'error': "Invalid date format detected"
        },
        'categorical': {
            # BUG FIX: isin() returns a boolean Series, whose truth value is
            # ambiguous inside `if not ...`; reduce with .all(). With NaNs
            # present the field is flagged (NaN is never "in" the observed
            # category set).
            'check': lambda x: bool(x.isin(x.dropna().unique()).all()),
            'error': "Invalid category value detected"
        }
    }

    def validate(self, data, schema: dict):
        """Validate *data* (a DataFrame) against *schema*.

        Returns a dict with ``valid`` (bool), ``errors`` (list of
        "field: message" strings) and ``confidence`` (fraction of fields that
        passed; 1.0 for an empty schema).
        """
        errors = []
        for field, config in schema.items():
            validator = self.VALIDATORS.get(config['type'])
            if validator is None:
                # BUG FIX: an unknown validator type used to raise TypeError
                # (None['check']); record it as a validation error instead.
                errors.append(f"{field}: unknown validator type {config['type']!r}")
                continue
            if not validator['check'](data[field]):
                errors.append(f"{field}: {validator['error']}")
        return {
            'valid': len(errors) == 0,
            'errors': errors,
            # BUG FIX: guard against ZeroDivisionError on an empty schema.
            'confidence': 1.0 - (len(errors) / len(schema)) if schema else 1.0
        }
# --------------------------
# Tool Router
# --------------------------
class ToolRouter:
    """Classifies a question into a topical domain ('music', 'sports' or
    'science') by dot-product similarity of sentence embeddings."""

    def __init__(self):
        self.encoder = SentenceTransformer('all-MiniLM-L6-v2')
        seed_phrases = {
            'music': "music album release artist track",
            'sports': "athlete team score tournament",
            'science': "chemistry biology physics research",
        }
        # Pre-compute one reference embedding per domain; incoming questions
        # are compared against these in route().
        self.domain_embeddings = {
            name: self.encoder.encode(phrase)
            for name, phrase in seed_phrases.items()
        }

    def route(self, question: str):
        """Return the domain whose seed embedding best matches *question*.

        Ties keep the first (insertion-order) domain, matching dict argmax.
        """
        embedded = self.encoder.encode(question)
        best_domain, best_score = None, float("-inf")
        for name, reference in self.domain_embeddings.items():
            score = np.dot(embedded, reference)
            if score > best_score:
                best_domain, best_score = name, score
        return best_domain
# --------------------------
# Temporal Search
# --------------------------
class HistoricalSearch:
    """Thin client for the Internet Archive Wayback "available" API."""

    # Wayback Machine availability endpoint.
    WAYBACK_API = "http://archive.org/wayback/available"

    def get_historical_content(self, url: str, target_date: str):
        """Look up the archived snapshot of *url* closest to *target_date*.

        target_date: Wayback-format timestamp (YYYYMMDD, optionally longer).
        Returns the decoded JSON response; raises requests.HTTPError on a
        failed request.
        """
        # BUG FIX: the hand-built query string contained the mojibake
        # '×tamp' — '&times' of '&timestamp' was swallowed by HTML-entity
        # decoding, so the API never received the timestamp parameter.
        # Build the query via params= so names survive and values are
        # properly URL-escaped.
        response = requests.get(
            self.WAYBACK_API,
            params={"url": url, "timestamp": target_date},
            timeout=30,  # don't hang the agent on a slow archive.org
        )
        response.raise_for_status()
        return response.json()
# --------------------------
# Enhanced Excel Reader
# --------------------------
class EnhancedExcelReader(Tool):
    """Excel reader that validates the sheet against an auto-detected schema
    before rendering it as markdown."""

    def forward(self, path: str):
        """Read the workbook at *path*, validate it, and return markdown.

        Raises ValueError when any column fails schema validation.
        """
        frame = pd.read_excel(path)
        report = ValidationPipeline().validate(frame, self._detect_schema(frame))
        if not report['valid']:
            raise ValueError(f"Data validation failed: {report['errors']}")
        return frame.to_markdown()

    def _detect_schema(self, df: pd.DataFrame):
        """Infer a per-column schema: numeric, temporal, or categorical
        (the fallback for anything else)."""
        def column_kind(series):
            # Order matters: numeric dtypes are checked before datetimes.
            if pd.api.types.is_numeric_dtype(series):
                return 'numeric'
            if pd.api.types.is_datetime64_any_dtype(series):
                return 'temporal'
            return 'categorical'
        return {col: {'type': column_kind(df[col])} for col in df.columns}
# --------------------------
# Cross-Verified Search
# --------------------------
class CrossVerifiedSearch:
    """Queries several search backends and returns the consensus answer.

    NOTE(review): the backends are instantiated at class-definition time, and
    DuckDuckGoSearchTool / WikipediaSearchTool / ArxivSearchTool are not
    imported anywhere in this file — confirm they are provided elsewhere.
    """

    SOURCES = [
        DuckDuckGoSearchTool(),
        WikipediaSearchTool(),
        ArxivSearchTool()
    ]

    def __call__(self, query: str, context: str = None):
        """Run *query* against every backend and return the majority answer.

        context: accepted (currently unused) so callers such as
        UniversalLoader._fallback_search can pass task context — previously a
        second positional argument raised TypeError.
        Returns None when every backend fails.
        """
        results = []
        for source in self.SOURCES:
            try:
                results.append(source(query))
            except Exception:
                # Best-effort aggregation: a failing backend is skipped.
                continue
        return self._consensus(results)

    def _consensus(self, results):
        """Majority vote over stringified result prefixes (demo-grade)."""
        # BUG FIX: max() over an empty dict raised ValueError when every
        # backend failed; report "no result" as None instead.
        if not results:
            return None
        counts = {}
        for result in results:
            key = str(result)[:100]  # cheap bucketing key, not a real hash
            counts[key] = counts.get(key, 0) + 1
        return max(counts, key=counts.get)
# --------------------------
# Main Agent Class
# --------------------------
class MagAgent:
    """Question-answering agent wrapping a smolagents CodeAgent with
    domain routing and heuristic answer validation."""

    def __init__(self, rate_limiter: "Optional[Limiter]" = None):
        """rate_limiter: optional token-bucket limiter. It is only stored
        here; presumably consumed by callers — TODO confirm."""
        self.rate_limiter = rate_limiter
        self.model = LiteLLMModel(
            model_id="gemini/gemini-1.5-flash",
            api_key=os.environ.get("GEMINI_KEY"),
            max_tokens=8192
        )
        # FIX: build the router once and reuse it; _build_task_prompt used to
        # construct a fresh ToolRouter (re-loading the sentence-transformer
        # model) on every single question.
        self.router = ToolRouter()
        # NOTE(review): several entries are not smolagents Tool subclasses
        # (CrossVerifiedSearch, HistoricalSearch, ToolRouter) — confirm
        # CodeAgent accepts them.
        self.tools = [
            UniversalLoader(),
            EnhancedExcelReader(),
            CrossVerifiedSearch(),
            HistoricalSearch(),
            self.router
        ]
        with open("prompts.yaml") as f:
            self.prompt_templates = yaml.safe_load(f)
        self.agent = CodeAgent(
            model=self.model,
            tools=self.tools,
            verbosity_level=2,
            prompt_templates=self.prompt_templates,
            max_steps=20
        )

    async def __call__(self, question: str, task_id: str) -> str:
        """Answer *question* (running the blocking agent off-thread).

        Returns the validated answer, a low-confidence notice, or a JSON
        error blob from _handle_error.
        """
        # FIX: build the context before the try block so the except handler
        # can never see an unbound `context` (previously a NameError risk).
        context = {
            "question": question,
            "task_id": task_id,
            "validation_checks": []
        }
        try:
            result = await asyncio.to_thread(
                self.agent.run,
                task=self._build_task_prompt(question, task_id)
            )
            validated = self._validate_result(result, context)
            return self._format_output(validated)
        except Exception as e:
            return self._handle_error(e, context)

    def _build_task_prompt(self, question: str, task_id: str) -> str:
        """Assemble the full prompt: base template, domain hints, question,
        and attachment notice."""
        base_prompt = self.prompt_templates['base']
        domain = self.router.route(question)
        return f"""
{base_prompt}
**Domain Classification**: {domain}
**Required Validation**: {self._get_validation_requirements(domain)}
Question: {question}
{self._attachment_prompt(task_id)}
"""

    def _validate_result(self, result: str, context: dict) -> dict:
        """Heuristic surface validation of the answer text.

        Scores the presence of numeric, ISO-date, and word tokens; the
        matched fraction becomes the confidence. Also records the per-check
        outcomes into context['validation_checks'].
        """
        validation_rules = {
            'numeric': r'\d+',
            'temporal': r'\d{4}-\d{2}-\d{2}',
            # BUG FIX: the old pattern ^[A-Za-z]+$ only matched an answer
            # that was a single bare word — which can never co-occur with the
            # numeric check, capping confidence at 2/3 < the 0.7 output
            # threshold, so EVERY answer was rejected. Match any word token.
            'categorical': r'\b[A-Za-z]+\b'
        }
        validations = {
            v_type: bool(re.search(pattern, result))
            for v_type, pattern in validation_rules.items()
        }
        confidence = sum(validations.values()) / len(validations)
        context['validation_checks'] = validations
        return {
            'result': result,
            'confidence': confidence,
            'validations': validations
        }

    def _format_output(self, validated: dict) -> str:
        """Gate the answer on the heuristic confidence score."""
        if validated['confidence'] < 0.7:
            return "Unable to verify answer with sufficient confidence"
        return validated['result']

    def _handle_error(self, error: Exception, context: dict) -> str:
        """Serialize an error plus the question context as a JSON string."""
        error_info = {
            "type": type(error).__name__,
            "message": str(error),
            "context": context
        }
        return json.dumps(error_info)

    def _get_validation_requirements(self, domain: str) -> str:
        """Domain-specific verification instructions injected into the
        prompt; generic fact-checking for unknown domains."""
        requirements = {
            'music': "Verify release dates against multiple sources",
            'sports': "Cross-check athlete statistics with official records",
            'science': "Validate against peer-reviewed sources"
        }
        return requirements.get(domain, "Standard fact verification")

    def _attachment_prompt(self, task_id: str) -> str:
        """Tell the model whether a downloadable task attachment exists."""
        if task_id:
            return f"Attachment available with task_id: {task_id}"
        return "No attachments provided"