SergeyO7 committed on
Commit
1c22788
·
verified ·
1 Parent(s): 7192624

Update agent.py

Browse files
Files changed (1) hide show
  1. agent.py +154 -207
agent.py CHANGED
@@ -1,6 +1,7 @@
1
- from smolagents import CodeAgent, LiteLLMModel, Tool, DuckDuckGoSearchTool, WikipediaSearchTool
2
  from token_bucket import Limiter, MemoryStorage
3
  from tenacity import retry, stop_after_attempt, wait_exponential
 
4
  from sentence_transformers import SentenceTransformer
5
  from bs4 import BeautifulSoup
6
  from datetime import datetime
@@ -12,165 +13,161 @@ import whisper
12
  import yaml
13
  import os
14
  import re
 
15
 
16
  # --------------------------
17
- # Universal Data Loader
18
  # --------------------------
19
- class UniversalLoader(Tool):
20
- def __init__(self):
21
- self.file_loaders = {
22
- 'xlsx': self._load_excel,
23
- 'csv': self._load_csv,
24
- 'png': self._load_image,
25
- 'mp3': self._load_audio
26
- }
27
-
28
- def forward(self, source: str, task_id: str = None):
29
  try:
30
- if source == "attachment":
31
- file_path = self._download_attachment(task_id)
32
- return self._load_by_extension(file_path)
33
- elif source.startswith("http"):
34
- return self._load_url(source)
35
  except Exception as e:
36
- return self._fallback_search(source, task_id)
37
 
38
- def _download_attachment(self, task_id: str):
39
- return DownloadTaskAttachmentTool()(task_id)
 
 
 
40
 
41
- def _load_by_extension(self, path: str):
42
- ext = path.split('.')[-1].lower()
43
- loader = self.file_loaders.get(ext, self._load_text)
44
- return loader(path)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
45
 
46
- def _load_excel(self, path: str):
47
- return ExcelReaderTool().forward(path)
 
 
 
48
 
49
- def _load_csv(self, path: str):
50
- return pd.read_csv(path).to_markdown()
 
 
 
 
 
 
 
 
 
 
51
 
52
- def _load_image(self, path: str):
53
- return ImageAnalyzerTool().forward(path)
 
 
 
54
 
55
- def _load_audio(self, path: str):
56
- return SpeechToTextTool().forward(path)
57
 
58
- def _fallback_search(self, query: str, context: str):
59
- return CrossVerifiedSearch()(query, context)
 
 
60
 
61
  # --------------------------
62
- # Validation Pipeline
63
  # --------------------------
64
- class ValidationPipeline:
65
- VALIDATORS = {
66
- 'numeric': {
67
- 'check': lambda x: pd.api.types.is_numeric_dtype(x),
68
- 'error': "Non-numeric value found in numeric field"
69
- },
70
- 'temporal': {
71
- 'check': lambda x: pd.api.types.is_datetime64_any_dtype(x),
72
- 'error': "Invalid date format detected"
73
- },
74
- 'categorical': {
75
- 'check': lambda x: x.isin(x.dropna().unique()),
76
- 'error': "Invalid category value detected"
77
- }
78
  }
 
79
 
80
- def validate(self, data, schema: dict):
81
- errors = []
82
- for field, config in schema.items():
83
- validator = self.VALIDATORS.get(config['type'])
84
- if not validator['check'](data[field]):
85
- errors.append(f"{field}: {validator['error']}")
86
- return {
87
- 'valid': len(errors) == 0,
88
- 'errors': errors,
89
- 'confidence': 1.0 - (len(errors) / len(schema))
90
- }
91
 
92
  # --------------------------
93
- # Tool Router
94
  # --------------------------
95
- class ToolRouter:
 
96
  def __init__(self):
97
- self.encoder = SentenceTransformer('all-MiniLM-L6-v2')
98
- self.domain_embeddings = {
99
- 'music': self.encoder.encode("music album release artist track"),
100
- 'sports': self.encoder.encode("athlete team score tournament"),
101
- 'science': self.encoder.encode("chemistry biology physics research")
102
  }
103
 
104
- def route(self, question: str):
105
- query_embed = self.encoder.encode(question)
106
- scores = {
107
- domain: np.dot(query_embed, domain_embed)
108
- for domain, domain_embed in self.domain_embeddings.items()
 
 
 
 
 
 
 
 
 
 
109
  }
110
- return max(scores, key=scores.get)
111
-
112
- # --------------------------
113
- # Temporal Search
114
- # --------------------------
115
- class HistoricalSearch:
116
- @retry(stop=stop_after_attempt(3), wait=wait_exponential(multiplier=1, min=2, max=10))
117
- def get_historical_content(self, url: str, target_date: str):
118
- return requests.get(
119
- f"http://archive.org/wayback/available?url={url}&timestamp={target_date}"
120
- ).json()
121
-
122
- # --------------------------
123
- # Enhanced Excel Reader
124
- # --------------------------
125
- class EnhancedExcelReader(Tool):
126
- def forward(self, path: str):
127
- df = pd.read_excel(path)
128
- validation = ValidationPipeline().validate(df, self._detect_schema(df))
129
- if not validation['valid']:
130
- raise ValueError(f"Data validation failed: {validation['errors']}")
131
- return df.to_markdown()
132
 
133
- def _detect_schema(self, df: pd.DataFrame):
134
- schema = {}
135
- for col in df.columns:
136
- dtype = 'categorical'
137
- if pd.api.types.is_numeric_dtype(df[col]):
138
- dtype = 'numeric'
139
- elif pd.api.types.is_datetime64_any_dtype(df[col]):
140
- dtype = 'temporal'
141
- schema[col] = {'type': dtype}
142
- return schema
143
 
144
  # --------------------------
145
- # Cross-Verified Search
146
  # --------------------------
147
- class CrossVerifiedSearch:
148
- SOURCES = [
149
- DuckDuckGoSearchTool(),
150
- WikipediaSearchTool(),
151
- ArxivSearchTool()
152
- ]
153
-
154
- def __call__(self, query: str):
155
- results = []
156
- for source in self.SOURCES:
157
- try:
158
- results.append(source(query))
159
- except Exception as e:
160
- continue
161
- return self._consensus(results)
162
-
163
- def _consensus(self, results):
164
- # Simple majority voting implementation
165
- counts = {}
166
- for result in results:
167
- key = str(result)[:100] # Simple hash for demo
168
- counts[key] = counts.get(key, 0) + 1
169
- return max(counts, key=counts.get)
170
 
171
- # --------------------------
172
- # Main Agent Class
173
- # --------------------------
174
  class MagAgent:
175
  def __init__(self, rate_limiter: Optional[Limiter] = None):
176
  self.rate_limiter = rate_limiter
@@ -182,98 +179,48 @@ class MagAgent:
182
 
183
  self.tools = [
184
  UniversalLoader(),
185
- EnhancedExcelReader(),
186
- CrossVerifiedSearch(),
187
- HistoricalSearch(),
188
- ToolRouter()
 
189
  ]
190
 
191
- # with open("prompts.yaml") as f:
192
- # self.prompt_templates = yaml.safe_load(f)
193
 
194
  self.agent = CodeAgent(
195
  model=self.model,
196
  tools=self.tools,
197
  verbosity_level=2,
198
- # prompt_templates=self.prompt_templates,
199
- max_steps=20
 
200
  )
201
 
202
  async def __call__(self, question: str, task_id: str) -> str:
203
  try:
204
- context = {
205
- "question": question,
206
- "task_id": task_id,
207
- "validation_checks": []
208
- }
209
-
210
- result = await asyncio.to_thread(
211
- self.agent.run,
212
- task=self._build_task_prompt(question, task_id)
213
- )
214
-
215
- validated = self._validate_result(result, context)
216
- return self._format_output(validated)
217
-
218
  except Exception as e:
219
  return self._handle_error(e, context)
220
 
221
- def _build_task_prompt(self, question: str, task_id: str) -> str:
222
- base_prompt = self.prompt_templates['base']
223
- domain = ToolRouter().route(question)
224
- return f"""
225
- {base_prompt}
226
-
227
- **Domain Classification**: {domain}
228
- **Required Validation**: {self._get_validation_requirements(domain)}
229
-
230
- Question: {question}
231
- {self._attachment_prompt(task_id)}
232
- """
233
 
234
- def _validate_result(self, result: str, context: dict) -> dict:
235
- validation_rules = {
236
- 'numeric': r'\d+',
237
- 'temporal': r'\d{4}-\d{2}-\d{2}',
238
- 'categorical': r'^[A-Za-z]+$'
239
- }
240
-
241
- validations = {}
242
- for v_type, pattern in validation_rules.items():
243
- match = re.search(pattern, result)
244
- validations[v_type] = bool(match)
245
-
246
- confidence = sum(validations.values()) / len(validations)
247
- context['validation_checks'] = validations
248
-
249
  return {
250
- 'result': result,
251
- 'confidence': confidence,
252
- 'validations': validations
253
- }
254
-
255
- def _format_output(self, validated: dict) -> str:
256
- if validated['confidence'] < 0.7:
257
- return "Unable to verify answer with sufficient confidence"
258
- return validated['result']
259
-
260
- def _handle_error(self, error: Exception, context: dict) -> str:
261
- error_info = {
262
- "type": type(error).__name__,
263
- "message": str(error),
264
- "context": context
265
  }
266
- return json.dumps(error_info)
267
 
268
- def _get_validation_requirements(self, domain: str) -> str:
269
- requirements = {
270
- 'music': "Verify release dates against multiple sources",
271
- 'sports': "Cross-check athlete statistics with official records",
272
- 'science': "Validate against peer-reviewed sources"
273
- }
274
- return requirements.get(domain, "Standard fact verification")
275
 
276
- def _attachment_prompt(self, task_id: str) -> str:
277
- if task_id:
278
- return f"Attachment available with task_id: {task_id}"
279
- return "No attachments provided"
 
1
+ from smolagents import CodeAgent, LiteLLMModel, Tool
2
  from token_bucket import Limiter, MemoryStorage
3
  from tenacity import retry, stop_after_attempt, wait_exponential
4
+ from langchain_community.document_loaders import ArxivLoader
5
  from sentence_transformers import SentenceTransformer
6
  from bs4 import BeautifulSoup
7
  from datetime import datetime
 
13
  import yaml
14
  import os
15
  import re
16
+ import json
17
 
18
  # --------------------------
19
+ # Core Tools from Previous Implementation
20
  # --------------------------
21
+
22
class VisitWebpageTool(Tool):
    """Fetch a web page over HTTP and return its content as markdown.

    The HTTP request is retried (up to 3 attempts, exponential backoff)
    before giving up. If every attempt fails, ``forward`` returns an error
    string instead of raising, so the calling agent can read the failure.
    """

    name = "visit_webpage"
    description = "Visits a webpage and returns its content as markdown"
    inputs = {'url': {'type': 'string', 'description': 'The URL to visit'}}
    output_type = "string"

    # The retry decorator lives on the helper, not on ``forward``: ``forward``
    # swallows every exception, so decorating it would never trigger a retry.
    @retry(
        stop=stop_after_attempt(3),
        wait=wait_exponential(multiplier=1, min=2, max=10),
        reraise=True,  # surface the underlying error, not tenacity's RetryError
    )
    def _fetch(self, url: str) -> str:
        """Download ``url`` and convert the response body to markdown.

        Raises:
            Exception: whatever ``requests`` raises (connection problems,
                timeouts, non-2xx status via ``raise_for_status``) once the
                retry budget is exhausted.
        """
        response = requests.get(url, timeout=30)
        response.raise_for_status()
        return markdownify(response.text).strip()

    def forward(self, url: str) -> str:
        """Return the page at ``url`` as markdown, or an error message.

        Args:
            url: The URL to visit.

        Returns:
            The markdown-converted page body, or a string starting with
            ``"Error fetching webpage: "`` when all attempts failed.
        """
        try:
            return self._fetch(url)
        except Exception as e:
            # Best-effort tool: report the failure as text rather than raising.
            return f"Error fetching webpage: {str(e)}"
36
 
37
class DownloadTaskAttachmentTool(Tool):
    """Download the file attached to a task and save it under ``downloads/``.

    The base URL is read from the ``TASK_API_URL`` environment variable and
    defaults to the course scoring endpoint. The file extension is chosen
    from the response's ``Content-Type`` header.
    """

    name = "download_file"
    description = "Downloads files from the task API"
    inputs = {'task_id': {'type': 'string', 'description': 'The task ID to download'}}
    output_type = "string"

    def forward(self, task_id: str) -> str:
        """Fetch the attachment for ``task_id`` and write it to disk.

        Args:
            task_id: Identifier of the task whose file should be downloaded.

        Returns:
            Path of the saved file (``downloads/<task_id><extension>``).

        Raises:
            RuntimeError: if the request or the file write fails; the
                original exception is attached as ``__cause__``.
        """
        api_url = os.getenv("TASK_API_URL", "https://agents-course-unit4-scoring.hf.space")
        file_url = f"{api_url}/files/{task_id}"

        try:
            # A streamed response keeps its connection open until it is
            # closed, so it is used as a context manager.
            with requests.get(file_url, stream=True, timeout=30) as response:
                response.raise_for_status()

                # File type detection
                content_type = response.headers.get('Content-Type', '')
                extension = self._get_extension(content_type)

                os.makedirs("downloads", exist_ok=True)
                # Keep only the final path component so a task_id containing
                # path separators cannot write outside ``downloads/``.
                safe_name = os.path.basename(str(task_id))
                file_path = f"downloads/{safe_name}{extension}"

                with open(file_path, "wb") as f:
                    for chunk in response.iter_content(chunk_size=8192):
                        f.write(chunk)

            return file_path
        except Exception as e:
            raise RuntimeError(f"Download failed: {str(e)}") from e

    def _get_extension(self, content_type: str) -> str:
        """Map a ``Content-Type`` header value to a file extension.

        Parameters after ``;`` (for example ``charset``) are ignored, and the
        comparison ignores case and surrounding whitespace. Unknown types
        map to ``'.bin'``.
        """
        type_map = {
            'image/png': '.png',
            'image/jpeg': '.jpg',
            'application/vnd.openxmlformats-officedocument.spreadsheetml.sheet': '.xlsx',
            'audio/mpeg': '.mp3',
            'application/pdf': '.pdf',
            'text/x-python': '.py',
        }
        mime_type = content_type.split(';')[0].strip().lower()
        return type_map.get(mime_type, '.bin')
76
 
77
class ArxivSearchTool(Tool):
    """Search arXiv through ``ArxivLoader`` and summarise the top results."""

    name = "arxiv_search"
    description = "Searches academic papers on Arxiv"
    inputs = {'query': {'type': 'string', 'description': 'Search query'}}
    output_type = "string"

    # Maximum number of characters of each document's text to include.
    _SUMMARY_CHARS = 500

    def forward(self, query: str) -> str:
        """Return title, authors and a short excerpt for up to 3 papers.

        Args:
            query: Search query passed to ``ArxivLoader``.

        Returns:
            One paragraph per paper separated by blank lines, a
            "no results" message when nothing was found, or a string
            starting with ``"Arxiv search failed: "`` if loading raised.
        """
        try:
            docs = ArxivLoader(query=query, load_max_docs=3).load()
        except Exception as e:
            # Best-effort tool: report the failure as text rather than raising.
            return f"Arxiv search failed: {str(e)}"

        if not docs:
            # An empty string would be indistinguishable from a blank answer.
            return f"No Arxiv results found for: {query}"

        return "\n\n".join(self._format_doc(doc) for doc in docs)

    def _format_doc(self, doc) -> str:
        """Render one loaded document as a Title/Authors/Summary paragraph."""
        metadata = doc.metadata or {}
        content = doc.page_content or ""
        excerpt = content[:self._SUMMARY_CHARS]
        # Only mark the excerpt as truncated when text was actually cut off.
        if len(content) > self._SUMMARY_CHARS:
            excerpt += "..."
        return (
            f"Title: {metadata.get('Title', 'Unknown')}\n"
            f"Authors: {metadata.get('Authors', 'Unknown')}\n"
            f"Summary: {excerpt}"
        )
95
 
96
class SpeechToTextTool(Tool):
    """Transcribe an audio file to text with the Whisper "base" model."""

    name = "speech_to_text"
    description = "Converts audio files to text"
    inputs = {'audio_path': {'type': 'string', 'description': 'Path to audio file'}}
    output_type = "string"

    def __init__(self):
        # Run the base Tool initializer, which a custom __init__ would
        # otherwise bypass.
        super().__init__()
        # Loaded on first use: this tool is instantiated in more than one
        # place, and loading the model at construction repeats that cost.
        self.model = None

    def _get_model(self):
        """Return the Whisper model, loading it on the first call."""
        if self.model is None:
            self.model = whisper.load_model("base")
        return self.model

    def forward(self, audio_path: str) -> str:
        """Return the transcription of ``audio_path``.

        Args:
            audio_path: Path to the audio file on disk.

        Returns:
            The transcribed text (empty string if the result has no
            ``"text"`` entry), or ``"File not found: <path>"`` when the path
            does not exist.
        """
        if not os.path.exists(audio_path):
            return f"File not found: {audio_path}"
        return self._get_model().transcribe(audio_path).get("text", "")
109
 
110
  # --------------------------
111
+ # Enhanced Tools with Validation
112
  # --------------------------
113
+
114
class ValidatedExcelReader(Tool):
    """Read an Excel file and optionally validate it against a schema."""

    name = "excel_reader"
    description = "Reads and validates Excel files"
    inputs = {
        'file_path': {'type': 'string', 'description': 'Path to Excel file'},
        'schema': {'type': 'object', 'description': 'Validation schema', 'nullable': True},
    }
    output_type = "string"

    def forward(self, file_path: str, schema: dict = None) -> str:
        """Load ``file_path`` and return its contents as a markdown table.

        Args:
            file_path: Path to the Excel file.
            schema: Optional mapping of column name to validation config.
                When empty or ``None`` no validation is performed.

        Returns:
            The sheet rendered with ``DataFrame.to_markdown()``.

        Raises:
            ValueError: if the schema names columns that are not in the
                sheet, or if ``ValidationPipeline`` reports the data invalid.
        """
        df = pd.read_excel(file_path)

        if schema:
            # Report schema fields the sheet does not contain up front, with
            # a clear message, before handing the frame to the validator.
            missing = [field for field in schema if field not in df.columns]
            if missing:
                raise ValueError(f"Data validation failed: missing columns {missing}")

            # NOTE(review): ValidationPipeline is not defined in the visible
            # part of this file after this commit - confirm it still exists.
            validation = ValidationPipeline().validate(df, schema)
            if not validation['valid']:
                raise ValueError(f"Data validation failed: {validation['errors']}")

        return df.to_markdown()
 
 
132
 
133
  # --------------------------
134
+ # Integrated Universal Loader
135
  # --------------------------
136
+
137
class UniversalLoader(Tool):
    """Route a source to the matching loader tool.

    ``source`` may be the literal ``"attachment"`` (download the task's file
    and load it by extension), an http(s) URL (fetched with the web loader),
    or the name of one of the registered loaders, in which case ``task_id``
    is passed to that loader as its argument. Any failure falls through to
    a search-based fallback.
    """

    # Tool metadata, declared the same way as on the sibling tools.
    name = "universal_loader"
    description = (
        "Loads content from a task attachment, a URL, or a named loader "
        "('excel', 'audio', 'arxiv', 'web')"
    )
    inputs = {
        'source': {
            'type': 'string',
            'description': "'attachment', an http(s) URL, or a loader name",
        },
        'task_id': {
            'type': 'string',
            'description': 'Task ID for attachments, or the argument for a named loader',
            'nullable': True,
        },
    }
    output_type = "string"

    # File extension -> key in ``self.loaders``. Only loaders that accept a
    # local file path belong here.
    _EXTENSION_LOADERS = {'xlsx': 'excel', 'mp3': 'audio'}
    # Extensions that are returned as plain text.
    _TEXT_EXTENSIONS = frozenset({'py', 'txt', 'csv', 'md', 'json'})

    def __init__(self):
        # Run the base Tool initializer, which a custom __init__ would
        # otherwise bypass.
        super().__init__()
        self.loaders = {
            'excel': ValidatedExcelReader(),
            'audio': SpeechToTextTool(),
            'arxiv': ArxivSearchTool(),
            'web': VisitWebpageTool(),
        }

    def forward(self, source: str, task_id: str = None) -> str:
        """Load content for ``source`` and return it as text.

        Args:
            source: ``"attachment"``, an http(s) URL, or a loader name.
            task_id: Task ID for attachments; otherwise the argument passed
                to the named loader.

        Returns:
            The loaded content, a short message when an attachment was
            requested without a task ID, or the fallback search result when
            loading failed.
        """
        try:
            if source == "attachment":
                if not task_id:
                    return "No task_id provided for attachment download"
                file_path = DownloadTaskAttachmentTool()(task_id)
                return self._load_by_type(file_path)
            if source.startswith(("http://", "https://")):
                # URLs are not loader names; send them to the web loader.
                return self.loaders['web'].forward(source)
            return self.loaders[source].forward(task_id)
        except Exception:
            # Deliberate best-effort: any failure is answered by the fallback.
            return self._fallback(source, task_id)

    def _load_by_type(self, file_path: str) -> str:
        """Load a downloaded file using the loader that fits its extension.

        Spreadsheets and audio go to their loaders, text-like files are read
        directly, and anything else yields a message naming the saved path.
        A local path is never sent to the arXiv search or the HTTP fetcher,
        neither of which can read files from disk.
        """
        ext = os.path.splitext(file_path)[1].lstrip('.').lower()

        loader_key = self._EXTENSION_LOADERS.get(ext)
        if loader_key is not None:
            return self.loaders[loader_key].forward(file_path)

        if ext in self._TEXT_EXTENSIONS:
            with open(file_path, encoding="utf-8", errors="replace") as f:
                return f.read()

        return f"Downloaded attachment to {file_path}; no loader available for '.{ext}' files"

    def _fallback(self, source: str, context: str) -> str:
        """Answer a failed load with a search over ``source`` and ``context``."""
        # Skip missing parts so the query never contains the text "None".
        query = " ".join(part for part in (source, context) if part)
        # NOTE(review): CrossVerifiedSearch is not defined in the visible
        # part of this file after this commit - confirm it still exists.
        return CrossVerifiedSearch()(query)
 
 
 
 
 
 
 
 
166
 
167
  # --------------------------
168
+ # Main Agent Class (Integrated)
169
  # --------------------------
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
170
 
 
 
 
171
  class MagAgent:
172
  def __init__(self, rate_limiter: Optional[Limiter] = None):
173
  self.rate_limiter = rate_limiter
 
179
 
180
  self.tools = [
181
  UniversalLoader(),
182
+ ValidatedExcelReader(),
183
+ ArxivSearchTool(),
184
+ VisitWebpageTool(),
185
+ DownloadTaskAttachmentTool(),
186
+ SpeechToTextTool()
187
  ]
188
 
189
+ with open("prompts.yaml") as f:
190
+ self.prompt_templates = yaml.safe_load(f)
191
 
192
  self.agent = CodeAgent(
193
  model=self.model,
194
  tools=self.tools,
195
  verbosity_level=2,
196
+ prompt_templates=self.prompt_templates,
197
+ max_steps=20,
198
+ add_base_tools=False
199
  )
200
 
201
  async def __call__(self, question: str, task_id: str) -> str:
202
  try:
203
+ context = self._create_context(question, task_id)
204
+ result = await self._execute_agent(question, task_id)
205
+ return self._validate_and_format(result, context)
 
 
 
 
 
 
 
 
 
 
 
206
  except Exception as e:
207
  return self._handle_error(e, context)
208
 
209
+ # ... (keep other helper methods from previous implementation)
 
 
 
 
 
 
 
 
 
 
 
210
 
211
+ def _create_context(self, question: str, task_id: str) -> dict:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
212
  return {
213
+ "question": question,
214
+ "task_id": task_id,
215
+ "timestamp": datetime.now().isoformat(),
216
+ "validation_checks": []
 
 
 
 
 
 
 
 
 
 
 
217
  }
 
218
 
219
+ async def _execute_agent(self, question: str, task_id: str) -> str:
220
+ return await asyncio.to_thread(
221
+ self.agent.run,
222
+ task=self._build_task_prompt(question, task_id)
223
+ )
 
 
224
 
225
+ def _validate_and_format(self, result: str, context: dict) -> str:
226
+ validation = ValidationPipeline().validate