SergeyO7 committed on
Commit
5c64d65
·
verified ·
1 Parent(s): c7b45bd

Create agent.py

Browse files
Files changed (1) hide show
  1. agent.py +279 -0
agent.py ADDED
@@ -0,0 +1,279 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import asyncio
import json
import os
import re
from datetime import datetime
from typing import Optional

import numpy as np
import pandas as pd
import requests
import whisper
import yaml
from bs4 import BeautifulSoup
from sentence_transformers import SentenceTransformer
from smolagents import CodeAgent, LiteLLMModel, Tool
from tenacity import retry, stop_after_attempt, wait_exponential
from token_bucket import Limiter, MemoryStorage
15
+
16
+ # --------------------------
17
+ # Universal Data Loader
18
+ # --------------------------
19
class UniversalLoader(Tool):
    """Load task data from an attachment or a URL, falling back to search.

    ``forward`` dispatches on ``source``:

    * ``"attachment"`` -- download the task attachment and load it with the
      loader registered for its file extension (plain text by default);
    * a string starting with ``"http"`` -- fetch the page and return its text;
    * anything else, or any failure in the two cases above -- run a
      cross-verified search using ``source`` as the query.
    """

    # Metadata that smolagents' ``Tool`` base class requires on every
    # subclass; without it the tool cannot be instantiated or described to
    # the model.
    name = "universal_loader"
    description = (
        "Loads the content of a task attachment (pass source='attachment' "
        "together with task_id) or of an http(s) URL. Falls back to a web "
        "search using the source text when loading is not possible."
    )
    inputs = {
        "source": {
            "type": "string",
            "description": "Either the literal 'attachment' or an http(s) URL.",
        },
        "task_id": {
            "type": "string",
            "description": "Task identifier used to download the attachment.",
            "nullable": True,
        },
    }
    output_type = "any"

    # Seconds to wait for a remote page before giving up.
    REQUEST_TIMEOUT = 30

    def __init__(self):
        super().__init__()
        # Extension (lower-case, no dot) -> loader method.
        self.file_loaders = {
            'xlsx': self._load_excel,
            'csv': self._load_csv,
            'png': self._load_image,
            'mp3': self._load_audio,
        }

    def forward(self, source: str, task_id: str = None):
        """Return the loaded content for ``source``.

        Args:
            source: ``"attachment"``, an http(s) URL, or free text.
            task_id: Identifier of the task whose attachment is downloaded;
                only used when ``source == "attachment"``.

        Returns:
            Whatever the selected loader returns, or the fallback search
            result when loading fails or ``source`` is neither an attachment
            marker nor a URL.
        """
        try:
            if source == "attachment":
                file_path = self._download_attachment(task_id)
                return self._load_by_extension(file_path)
            if source.startswith("http"):
                return self._load_url(source)
        except Exception:
            # Deliberate best-effort: any loader failure degrades to search.
            return self._fallback_search(source, task_id)
        # Previously this case silently returned None; treat the text as a
        # search query, the same way a failed load is handled.
        return self._fallback_search(source, task_id)

    def _download_attachment(self, task_id: str):
        """Download the attachment for ``task_id`` via DownloadTaskAttachmentTool."""
        # NOTE(review): DownloadTaskAttachmentTool is not defined or imported
        # in the visible part of this file; presumably it returns a local
        # file path -- confirm.
        return DownloadTaskAttachmentTool()(task_id)

    def _load_by_extension(self, path: str):
        """Dispatch ``path`` to the loader registered for its extension."""
        # splitext only inspects the final path component, so dots in
        # directory names cannot be mistaken for an extension.
        ext = os.path.splitext(path)[1].lstrip('.').lower()
        loader = self.file_loaders.get(ext, self._load_text)
        return loader(path)

    def _load_excel(self, path: str):
        """Delegate to ExcelReaderTool (defined outside the visible source)."""
        return ExcelReaderTool().forward(path)

    def _load_csv(self, path: str):
        """Read a CSV file and render it as a Markdown table."""
        return pd.read_csv(path).to_markdown()

    def _load_image(self, path: str):
        """Delegate to ImageAnalyzerTool (defined outside the visible source)."""
        return ImageAnalyzerTool().forward(path)

    def _load_audio(self, path: str):
        """Delegate to SpeechToTextTool (defined outside the visible source)."""
        return SpeechToTextTool().forward(path)

    def _load_text(self, path: str):
        """Return the file content as text; the default for unknown extensions."""
        # errors="replace" keeps the loader usable on files that are not
        # valid UTF-8 instead of failing the whole load.
        with open(path, encoding="utf-8", errors="replace") as handle:
            return handle.read()

    def _load_url(self, url: str):
        """Fetch ``url`` and return the visible text of the page."""
        response = requests.get(url, timeout=self.REQUEST_TIMEOUT)
        response.raise_for_status()
        page = BeautifulSoup(response.text, "html.parser")
        return page.get_text(separator="\n", strip=True)

    def _fallback_search(self, query: str, context: str):
        """Search for ``query`` across the cross-verified sources.

        ``context`` is accepted for interface compatibility but is not
        forwarded: ``CrossVerifiedSearch.__call__`` takes the query only.
        """
        return CrossVerifiedSearch()(query)
60
+
61
+ # --------------------------
62
+ # Validation Pipeline
63
+ # --------------------------
64
class ValidationPipeline:
    """Check the columns of a table against a per-field type schema."""

    # Validation type -> predicate over a column plus the message reported
    # when the predicate is false.
    VALIDATORS = {
        'numeric': {
            'check': lambda x: pd.api.types.is_numeric_dtype(x),
            'error': "Non-numeric value found in numeric field"
        },
        'temporal': {
            'check': lambda x: pd.api.types.is_datetime64_any_dtype(x),
            'error': "Invalid date format detected"
        },
        'categorical': {
            # ``isin`` yields one boolean per row; reduce it with ``all`` so
            # the check produces a single truth value. Missing values are
            # not members of the observed categories and therefore fail.
            'check': lambda x: bool(x.isin(x.dropna().unique()).all()),
            'error': "Invalid category value detected"
        }
    }

    def validate(self, data, schema: dict):
        """Validate ``data`` against ``schema``.

        Args:
            data: Mapping-like table (e.g. a DataFrame) indexed by field name.
            schema: ``{field: {'type': <key of VALIDATORS>}}``.

        Returns:
            A dict with ``valid`` (bool), ``errors`` (list of
            ``"<field>: <message>"`` strings) and ``confidence`` (the
            fraction of fields without errors; 1.0 for an empty schema).
        """
        errors = []
        for field, config in schema.items():
            type_name = config.get('type')
            validator = self.VALIDATORS.get(type_name)
            if validator is None:
                errors.append(f"{field}: Unknown validation type {type_name!r}")
                continue
            if field not in data:
                errors.append(f"{field}: Field missing from data")
                continue
            if not validator['check'](data[field]):
                errors.append(f"{field}: {validator['error']}")

        # An empty schema has nothing to fail, so it is fully confident.
        confidence = 1.0 - (len(errors) / len(schema)) if schema else 1.0
        return {
            'valid': len(errors) == 0,
            'errors': errors,
            'confidence': confidence
        }
91
+
92
+ # --------------------------
93
+ # Tool Router
94
+ # --------------------------
95
class ToolRouter:
    """Choose the domain whose keyword embedding best matches a question."""

    def __init__(self):
        """Load the sentence encoder and embed each domain's keyword string."""
        self.encoder = SentenceTransformer('all-MiniLM-L6-v2')
        keyword_sets = (
            ('music', "music album release artist track"),
            ('sports', "athlete team score tournament"),
            ('science', "chemistry biology physics research"),
        )
        self.domain_embeddings = {}
        for label, keywords in keyword_sets:
            self.domain_embeddings[label] = self.encoder.encode(keywords)

    def route(self, question: str):
        """Return the domain name with the largest dot product to ``question``.

        Ties go to the domain that appears first in ``domain_embeddings``.
        """
        question_vec = self.encoder.encode(question)

        def similarity(label):
            return np.dot(question_vec, self.domain_embeddings[label])

        return max(self.domain_embeddings, key=similarity)
111
+
112
+ # --------------------------
113
+ # Temporal Search
114
+ # --------------------------
115
class HistoricalSearch:
    """Look up archived snapshots of a page through the Wayback Machine."""

    # Seconds to wait for archive.org before the attempt counts as failed.
    REQUEST_TIMEOUT = 30

    @retry(stop=stop_after_attempt(3), wait=wait_exponential(multiplier=1, min=2, max=10))
    def get_historical_content(self, url: str, target_date: str):
        """Query the Wayback availability API for ``url`` near ``target_date``.

        Args:
            url: Address of the page whose snapshot is wanted.
            target_date: Timestamp passed through unchanged as the API's
                ``timestamp`` parameter.

        Returns:
            The decoded JSON body of the availability response.

        Raises:
            Any request, HTTP-status or JSON-decoding error; the ``retry``
            decorator re-runs the call for up to three attempts first.
        """
        # ``params`` URL-encodes the values, so a target URL that itself
        # contains ``?`` or ``&`` no longer corrupts the query string.
        response = requests.get(
            "https://archive.org/wayback/available",
            params={"url": url, "timestamp": target_date},
            timeout=self.REQUEST_TIMEOUT,
        )
        # Fail loudly on HTTP errors so the retry policy can act on them.
        response.raise_for_status()
        return response.json()
121
+
122
+ # --------------------------
123
+ # Enhanced Excel Reader
124
+ # --------------------------
125
class EnhancedExcelReader(Tool):
    """Read an Excel file, validate its columns and return a Markdown table."""

    # Metadata that smolagents' ``Tool`` base class requires on every
    # subclass; without it the tool cannot be instantiated or described to
    # the model.
    name = "enhanced_excel_reader"
    description = (
        "Reads an Excel workbook from a local path, validates each column "
        "against its detected type and returns the data as a Markdown table."
    )
    inputs = {
        "path": {
            "type": "string",
            "description": "Path of the Excel file to read.",
        },
    }
    output_type = "string"

    def forward(self, path: str):
        """Load ``path`` and return its content as a Markdown table.

        Raises:
            ValueError: If the validation pipeline reports any error for the
                detected schema.
        """
        df = pd.read_excel(path)
        validation = ValidationPipeline().validate(df, self._detect_schema(df))
        if not validation['valid']:
            raise ValueError(f"Data validation failed: {validation['errors']}")
        return df.to_markdown()

    def _detect_schema(self, df: pd.DataFrame):
        """Classify every column as 'numeric', 'temporal' or 'categorical'.

        Returns:
            ``{column: {'type': <classification>}}``; columns that are
            neither numeric nor datetime are classified as 'categorical'.
        """
        schema = {}
        for col in df.columns:
            if pd.api.types.is_numeric_dtype(df[col]):
                dtype = 'numeric'
            elif pd.api.types.is_datetime64_any_dtype(df[col]):
                dtype = 'temporal'
            else:
                dtype = 'categorical'
            schema[col] = {'type': dtype}
        return schema
143
+
144
+ # --------------------------
145
+ # Cross-Verified Search
146
+ # --------------------------
147
class CrossVerifiedSearch:
    """Run one query against several search tools and return the consensus."""

    # NOTE(review): DuckDuckGoSearchTool, WikipediaSearchTool and
    # ArxivSearchTool are neither defined nor imported in the visible part
    # of this file; they must be in scope before this class body executes.
    SOURCES = [
        DuckDuckGoSearchTool(),
        WikipediaSearchTool(),
        ArxivSearchTool()
    ]

    # Returned when every source failed, so callers always get a string
    # instead of an exception from an empty vote.
    NO_RESULTS_MESSAGE = "No results found from any source."

    def __call__(self, query: str, context: str = None):
        """Query every source and return the most common result.

        Args:
            query: Text sent to each source.
            context: Accepted so callers that pass extra task context keep
                working; it is currently not used.

        Returns:
            The consensus result, or ``NO_RESULTS_MESSAGE`` when no source
            produced one.
        """
        results = []
        for source in self.SOURCES:
            try:
                results.append(source(query))
            except Exception:
                # Best effort: one failing source must not block the others.
                continue
        return self._consensus(results)

    def _consensus(self, results):
        """Return the result whose 100-character prefix occurs most often.

        Results are grouped by the first 100 characters of their string
        form. The full first result of the winning group is returned; ties
        go to the group seen first.
        """
        if not results:
            return self.NO_RESULTS_MESSAGE

        votes = {}
        first_seen = {}
        for result in results:
            key = str(result)[:100]  # Cheap grouping key, not a real hash
            votes[key] = votes.get(key, 0) + 1
            first_seen.setdefault(key, result)
        winner = max(votes, key=votes.get)
        # Hand back the complete result, not the truncated grouping key.
        return first_seen[winner]
170
+
171
+ # --------------------------
172
+ # Main Agent Class
173
+ # --------------------------
174
class MagAgent:
    """Question-answering agent built on a smolagents ``CodeAgent``.

    Calling the instance builds a prompt for the question, runs the agent in
    a worker thread, checks the shape of the answer and returns either the
    answer, a low-confidence notice, or a JSON error report.
    """

    # Minimum fraction of shape checks an answer must pass. The checks for
    # 'numeric' (contains a digit) and 'categorical' (letters only) can never
    # both pass, so the score is at most 2/3; the former hard-coded 0.7
    # threshold therefore rejected every answer. 0.3 means "at least one
    # check passed".
    MIN_CONFIDENCE = 0.3

    def __init__(self, rate_limiter: "Optional[Limiter]" = None):
        """Create the model, tools, prompt templates and underlying agent.

        Args:
            rate_limiter: Optional limiter stored on the instance; nothing in
                this class consumes it.
        """
        self.rate_limiter = rate_limiter
        self.model = LiteLLMModel(
            model_id="gemini/gemini-1.5-flash",
            api_key=os.environ.get("GEMINI_KEY"),
            max_tokens=8192
        )

        # One shared router: building it loads the embedding model, so it is
        # reused for every question instead of being rebuilt per prompt.
        self.router = ToolRouter()

        self.tools = [
            UniversalLoader(),
            EnhancedExcelReader(),
            CrossVerifiedSearch(),
            HistoricalSearch(),
            self.router
        ]

        with open("prompts.yaml", encoding="utf-8") as f:
            self.prompt_templates = yaml.safe_load(f)

        # NOTE(review): CodeAgent only accepts ``Tool`` instances, so the
        # plain helper objects in ``self.tools`` are not registered with it.
        # Wrap them as Tool subclasses if the model should call them.
        self.agent = CodeAgent(
            model=self.model,
            tools=[tool for tool in self.tools if isinstance(tool, Tool)],
            verbosity_level=2,
            prompt_templates=self.prompt_templates,
            max_steps=20
        )

    async def __call__(self, question: str, task_id: str) -> str:
        """Answer ``question`` and return the answer or an error report.

        Args:
            question: The question to answer.
            task_id: Identifier of the task; mentioned in the prompt when
                truthy.

        Returns:
            The agent's answer, the low-confidence notice, or a JSON string
            describing the exception that occurred.
        """
        # Built outside ``try`` so the error handler can always rely on it.
        context = {
            "question": question,
            "task_id": task_id,
            "validation_checks": []
        }
        try:
            # agent.run is blocking; run it in a thread to keep the event
            # loop responsive.
            result = await asyncio.to_thread(
                self.agent.run,
                task=self._build_task_prompt(question, task_id)
            )

            validated = self._validate_result(result, context)
            return self._format_output(validated)

        except Exception as e:
            return self._handle_error(e, context)

    def _build_task_prompt(self, question: str, task_id: str) -> str:
        """Assemble the prompt: base template, domain hints, question, attachment."""
        base_prompt = self.prompt_templates['base']
        domain = self.router.route(question)
        sections = [
            "",
            f"{base_prompt}",
            "",
            f"**Domain Classification**: {domain}",
            f"**Required Validation**: {self._get_validation_requirements(domain)}",
            "",
            f"Question: {question}",
            f"{self._attachment_prompt(task_id)}",
            "",
        ]
        return "\n".join(sections)

    def _validate_result(self, result: str, context: dict) -> dict:
        """Score the shape of ``result`` and record the checks in ``context``.

        Returns:
            A dict with ``result`` (the answer as text), ``confidence`` (the
            fraction of patterns that matched) and ``validations`` (pattern
            name -> bool).
        """
        validation_rules = {
            'numeric': r'\d+',
            'temporal': r'\d{4}-\d{2}-\d{2}',
            'categorical': r'^[A-Za-z]+$'
        }

        # The agent may return a non-string value (e.g. a number); the regex
        # checks need text.
        text = result if isinstance(result, str) else str(result)

        validations = {
            v_type: bool(re.search(pattern, text))
            for v_type, pattern in validation_rules.items()
        }

        confidence = sum(validations.values()) / len(validations)
        context['validation_checks'] = validations

        return {
            'result': text,
            'confidence': confidence,
            'validations': validations
        }

    def _format_output(self, validated: dict) -> str:
        """Return the answer, or a notice when confidence is below the minimum."""
        if validated['confidence'] < self.MIN_CONFIDENCE:
            return "Unable to verify answer with sufficient confidence"
        return validated['result']

    def _handle_error(self, error: Exception, context: dict) -> str:
        """Serialize ``error`` and ``context`` into a JSON error report."""
        error_info = {
            "type": type(error).__name__,
            "message": str(error),
            "context": context
        }
        # default=str: the error path must not fail on an unserializable
        # value in the context.
        return json.dumps(error_info, default=str)

    def _get_validation_requirements(self, domain: str) -> str:
        """Return the verification instruction for ``domain``."""
        requirements = {
            'music': "Verify release dates against multiple sources",
            'sports': "Cross-check athlete statistics with official records",
            'science': "Validate against peer-reviewed sources"
        }
        return requirements.get(domain, "Standard fact verification")

    def _attachment_prompt(self, task_id: str) -> str:
        """Return the prompt line stating whether an attachment is available."""
        if task_id:
            return f"Attachment available with task_id: {task_id}"
        return "No attachments provided"