SergeyO7 committed on
Commit
7418e84
·
verified ·
1 Parent(s): 64f8fd9

Update agent.py

Browse files
Files changed (1) hide show
  1. agent.py +139 -3
agent.py CHANGED
@@ -1,4 +1,4 @@
1
- from smolagents import CodeAgent, LiteLLMModel, Tool
2
  from token_bucket import Limiter, MemoryStorage
3
  from tenacity import retry, stop_after_attempt, wait_exponential
4
  from langchain_community.document_loaders import ArxivLoader
@@ -165,6 +165,141 @@ class UniversalLoader(Tool):
165
  def _fallback(self, source: str, context: str) -> str:
166
  return CrossVerifiedSearch()(f"{source} {context}")
167
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
168
  # --------------------------
169
  # Main Agent Class (Integrated)
170
  # --------------------------
@@ -180,11 +315,12 @@ class MagAgent:
180
 
181
  self.tools = [
182
  UniversalLoader(),
 
183
  ValidatedExcelReader(),
184
- ArxivSearchTool(),
185
  VisitWebpageTool(),
186
  DownloadTaskAttachmentTool(),
187
- SpeechToTextTool()
 
188
  ]
189
 
190
  with open("prompts.yaml") as f:
 
1
+ from smolagents import CodeAgent, LiteLLMModel, Tool, DuckDuckGoSearchTool, WikipediaSearchTool
2
  from token_bucket import Limiter, MemoryStorage
3
  from tenacity import retry, stop_after_attempt, wait_exponential
4
  from langchain_community.document_loaders import ArxivLoader
 
165
  def _fallback(self, source: str, context: str) -> str:
166
  return CrossVerifiedSearch()(f"{source} {context}")
167
 
168
+
169
+
170
+ # --------------------------
171
+ # Validation Pipeline
172
+ # --------------------------
173
class ValidationPipeline:
    """Validate columns of a table against a per-field type schema.

    Each schema entry names one of the keys of ``VALIDATORS``; the matching
    predicate is applied to the whole column (``data[field]``) and must
    return a single truthy/falsy value.
    """

    # Maps a schema type name to a column-level predicate ('check') and the
    # message reported when that predicate is false ('error').
    VALIDATORS = {
        'numeric': {
            'check': lambda x: pd.api.types.is_numeric_dtype(x),
            'error': "Non-numeric value found in numeric field"
        },
        'temporal': {
            'check': lambda x: pd.api.types.is_datetime64_any_dtype(x),
            'error': "Invalid date format detected"
        },
        'categorical': {
            # isin() yields an element-wise mask; it must be collapsed to one
            # bool because a pandas Series cannot be used in a plain ``if``.
            # The reference set is built after dropna(), so a column that
            # contains missing values fails this check.
            'check': lambda x: bool(x.isin(x.dropna().unique()).all()),
            'error': "Invalid category value detected"
        }
    }

    def validate(self, data, schema: dict) -> dict:
        """Check every field listed in *schema* and summarise the outcome.

        Args:
            data: Mapping-like table indexed by field name (e.g. a pandas
                DataFrame); ``data[field]`` is handed to the validator.
            schema: ``{field: {'type': <key of VALIDATORS>}}``.

        Returns:
            A dict with ``'valid'`` (True when no errors were recorded),
            ``'errors'`` (list of ``"field: message"`` strings) and
            ``'confidence'`` (fraction of schema fields that passed; 1.0 for
            an empty schema).
        """
        errors = []
        for field, config in schema.items():
            type_name = config.get('type')
            validator = self.VALIDATORS.get(type_name)
            if validator is None:
                # Report an unsupported type as a validation error instead of
                # failing with "'NoneType' object is not subscriptable".
                errors.append(f"{field}: Unknown validation type {type_name!r}")
                continue
            try:
                column = data[field]
            except KeyError:
                errors.append(f"{field}: Field not found in data")
                continue
            if not validator['check'](column):
                errors.append(f"{field}: {validator['error']}")

        # An empty schema has nothing that can fail; avoid dividing by zero.
        confidence = 1.0 - (len(errors) / len(schema)) if schema else 1.0
        return {
            'valid': len(errors) == 0,
            'errors': errors,
            'confidence': confidence
        }
200
+
201
+ # --------------------------
202
+ # Tool Router
203
+ # --------------------------
204
class ToolRouter:
    """Dispatch a query to one search backend, or fan it out to all of them.

    Also offers ``route`` to pick the closest of a few keyword-described
    domains for a question using sentence embeddings.
    """

    def __init__(self):
        """Load the sentence encoder, embed the domain keywords, build the search tools."""
        self.encoder = SentenceTransformer('all-MiniLM-L6-v2')
        self.domain_embeddings = {
            'music': self.encoder.encode("music album release artist track"),
            'sports': self.encoder.encode("athlete team score tournament"),
            'science': self.encoder.encode("chemistry biology physics research")
        }
        self.ddg = DuckDuckGoSearchTool()
        self.wiki = WikipediaSearchTool()
        self.arxiv = ArxivSearchTool()

    def forward(self, query: str, domain: str = None) -> str:
        """Smart search with domain prioritization.

        Args:
            query: Search text passed unchanged to the backend(s).
            domain: ``"academic"`` (arXiv), ``"general"`` (DuckDuckGo) or
                ``"encyclopedic"`` (Wikipedia). Any other value, including
                ``None``, queries all three backends.

        Returns:
            The selected backend's result, or for the fan-out case a JSON
            object string with keys ``"web"``, ``"wikipedia"`` and ``"arxiv"``.

        Raises:
            Exception: In the fan-out case, the first backend error is
                re-raised only when every backend failed.
        """
        if domain == "academic":
            return self.arxiv(query)
        elif domain == "general":
            return self.ddg(query)
        elif domain == "encyclopedic":
            return self.wiki(query)

        # Fallback: search all sources. One backend failing (rate limit,
        # network error, ...) must not discard what the others returned, so
        # each call is isolated and a failure is reported in that slot.
        backends = (
            ("web", self.ddg),
            ("wikipedia", self.wiki),
            ("arxiv", self.arxiv),
        )
        results = {}
        first_error = None
        failed = 0
        for label, search in backends:
            try:
                results[label] = search(query)
            except Exception as exc:
                failed += 1
                if first_error is None:
                    first_error = exc
                results[label] = f"Search failed: {exc}"

        if first_error is not None and failed == len(backends):
            # Nothing usable came back; surface the original error.
            raise first_error
        return json.dumps(results)

    def route(self, question: str):
        """Return the domain key whose keyword embedding scores highest against *question*.

        The score is the plain dot product of the two embeddings.
        """
        # NOTE(review): the keys returned here ('music', 'sports', 'science')
        # are not among the domain values forward() recognises - confirm how
        # callers are expected to combine the two.
        query_embed = self.encoder.encode(question)
        scores = {
            domain: np.dot(query_embed, domain_embed)
            for domain, domain_embed in self.domain_embeddings.items()
        }
        return max(scores, key=scores.get)
240
+
241
+ # --------------------------
242
+ # Temporal Search
243
+ # --------------------------
244
class HistoricalSearch:
    """Query the Internet Archive availability endpoint for archived snapshots."""

    @retry(stop=stop_after_attempt(3), wait=wait_exponential(multiplier=1, min=2, max=10))
    def get_historical_content(self, url: str, target_date: str, timeout: float = 10.0):
        """Return the availability API's JSON answer for *url* near *target_date*.

        The call is retried (at most 3 attempts, exponential wait bounded to
        2-10 seconds) whenever it raises.

        Args:
            url: Address of the page to look up in the archive.
            target_date: Value sent as the ``timestamp`` query parameter
                (presumably ``YYYYMMDD[hhmmss]`` - confirm against callers).
            timeout: Seconds to wait for the server before giving up, so a
                stalled connection fails and can be retried instead of
                blocking forever.

        Returns:
            The decoded JSON body of the response.
        """
        # Let requests build the query string: a target URL that itself
        # contains '?' or '&' would otherwise corrupt the request.
        response = requests.get(
            "https://archive.org/wayback/available",
            params={"url": url, "timestamp": target_date},
            timeout=timeout,
        )
        # Turn HTTP error statuses into exceptions so the retry policy applies.
        response.raise_for_status()
        return response.json()
250
+
251
+ # --------------------------
252
+ # Enhanced Excel Reader
253
+ # --------------------------
254
class EnhancedExcelReader(Tool):
    """Tool that loads an Excel file, validates its columns and renders it as Markdown."""

    # Metadata required on smolagents Tool subclasses; the keys of ``inputs``
    # must match the parameters of forward().
    name = "enhanced_excel_reader"
    description = (
        "Reads an Excel file, checks each column against its detected type "
        "and returns the table as Markdown."
    )
    inputs = {
        "path": {
            "type": "string",
            "description": "Path to the Excel file to read.",
        }
    }
    output_type = "string"

    def forward(self, path: str) -> str:
        """Read *path*, validate the data and return it as a Markdown table.

        Args:
            path: Location of the Excel file, passed to ``pd.read_excel``.

        Returns:
            The table rendered by ``DataFrame.to_markdown()``.

        Raises:
            ValueError: If ``ValidationPipeline`` reports the data as invalid.
        """
        df = pd.read_excel(path)
        validation = ValidationPipeline().validate(df, self._detect_schema(df))
        if not validation['valid']:
            raise ValueError(f"Data validation failed: {validation['errors']}")
        return df.to_markdown()

    def _detect_schema(self, df: pd.DataFrame) -> dict:
        """Build ``{column: {'type': ...}}`` from each column's dtype.

        Numeric dtypes map to ``'numeric'``, datetime dtypes to ``'temporal'``
        and everything else to ``'categorical'``.
        """
        schema = {}
        for col in df.columns:
            dtype = 'categorical'
            if pd.api.types.is_numeric_dtype(df[col]):
                dtype = 'numeric'
            elif pd.api.types.is_datetime64_any_dtype(df[col]):
                dtype = 'temporal'
            schema[col] = {'type': dtype}
        return schema
272
+
273
+ # --------------------------
274
+ # Cross-Verified Search
275
+ # --------------------------
276
class CrossVerifiedSearch:
    """Run one query against several search tools and return the most common answer."""

    # Search backends queried, in order, by __call__.
    SOURCES = [
        DuckDuckGoSearchTool(),
        WikipediaSearchTool(),
        ArxivSearchTool()
    ]

    def __call__(self, query: str):
        """Query every source and return the consensus result.

        A source that raises is skipped so the remaining sources can still
        answer.

        Args:
            query: Search text passed unchanged to each source.

        Returns:
            The full text of the winning result (see ``_consensus``).

        Raises:
            ValueError: If every source failed; the message lists each
                source's error.
        """
        results = []
        failures = []
        for source in self.SOURCES:
            try:
                results.append(source(query))
            except Exception as exc:
                # Best-effort: remember why the source failed and move on.
                failures.append(f"{type(source).__name__}: {exc}")

        if not results:
            raise ValueError(
                f"All search sources failed for query {query!r}: "
                + "; ".join(failures)
            )
        return self._consensus(results)

    def _consensus(self, results):
        """Return the result whose first 100 characters occur most often.

        Results are grouped by the first 100 characters of their string form.
        On a tie the group seen first wins. The full text of the first result
        in the winning group is returned, not the truncated grouping key.

        Raises:
            ValueError: If *results* is empty.
        """
        if not results:
            raise ValueError("No search results to build a consensus from")

        counts = {}
        full_text = {}
        for result in results:
            text = str(result)
            key = text[:100]  # Simple hash for demo
            counts[key] = counts.get(key, 0) + 1
            full_text.setdefault(key, text)
        return full_text[max(counts, key=counts.get)]
299
+
300
+
301
+
302
+
303
  # --------------------------
304
  # Main Agent Class (Integrated)
305
  # --------------------------
 
315
 
316
  self.tools = [
317
  UniversalLoader(),
318
+ EnhancedSearchTool(), # Replaces individual search tools
319
  ValidatedExcelReader(),
 
320
  VisitWebpageTool(),
321
  DownloadTaskAttachmentTool(),
322
+ SpeechToTextTool(),
323
+ CrossVerifiedSearch()
324
  ]
325
 
326
  with open("prompts.yaml") as f: