Update agent.py

agent.py CHANGED
@@ -1,6 +1,7 @@
-from smolagents import CodeAgent, LiteLLMModel, Tool
 from token_bucket import Limiter, MemoryStorage
 from tenacity import retry, stop_after_attempt, wait_exponential
 from sentence_transformers import SentenceTransformer
 from bs4 import BeautifulSoup
 from datetime import datetime
@@ -12,165 +13,161 @@ import whisper
 import yaml
 import os
 import re

 # --------------------------
-#
 # --------------------------
-
-    def forward(self, source: str, task_id: str = None):
         try:
-
-            elif source.startswith("http"):
-                return self._load_url(source)
         except Exception as e:
-            return

-    def
-

-    def
-

-    def
-
-    def
-

 # --------------------------
-# Validation
 # --------------------------
-
-        },
-        'temporal': {
-            'check': lambda x: pd.api.types.is_datetime64_any_dtype(x),
-            'error': "Invalid date format detected"
-        },
-        'categorical': {
-            'check': lambda x: x.isin(x.dropna().unique()),
-            'error': "Invalid category value detected"
-        }
     }

-    def
-            'confidence': 1.0 - (len(errors) / len(schema))
-        }
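Only fragments of the removed ValidationPipeline survive above. From the visible 'check'/'error' pairs, the per-column {'type': ...} schema that _detect_schema (further down) produced, and the confidence formula, a minimal sketch of how its validate() plausibly fit together — every name beyond the surviving fragments is an assumption, not the commit:

    import pandas as pd

    # Sketch only -- reconstructed from the surviving fragments above.
    RULES = {
        'numeric':     {'check': lambda x: pd.api.types.is_numeric_dtype(x),
                        'error': "Non-numeric value detected"},   # wording assumed
        'temporal':    {'check': lambda x: pd.api.types.is_datetime64_any_dtype(x),
                        'error': "Invalid date format detected"},
        'categorical': {'check': lambda x: x.isin(x.dropna().unique()).all(),
                        'error': "Invalid category value detected"},
    }

    def validate(df: pd.DataFrame, schema: dict) -> dict:
        # schema looks like {'col': {'type': 'numeric'}, ...}
        errors = []
        for col, spec in schema.items():
            rule = RULES[spec['type']]
            if not rule['check'](df[col]):
                errors.append(f"{col}: {rule['error']}")
        return {'valid': not errors,
                'errors': errors,
                'confidence': 1.0 - (len(errors) / len(schema))}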
 # --------------------------
-#
 # --------------------------
-
     def __init__(self):
-        self.
-
-            '
-            '
-            '
         }

-    def
-
         }
-        return
-
-# --------------------------
-# Temporal Search
-# --------------------------
-class HistoricalSearch:
-    @retry(stop=stop_after_attempt(3), wait=wait_exponential(multiplier=1, min=2, max=10))
-    def get_historical_content(self, url: str, target_date: str):
-        return requests.get(
-            f"http://archive.org/wayback/available?url={url}&timestamp={target_date}"
-        ).json()
-
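For reference, the Wayback Machine availability endpoint that get_historical_content hit returns JSON pointing at the closest snapshot; a typical consumer looks like this (the values are illustrative, and the response shape is the public API's, not something visible in the diff):

    data = HistoricalSearch().get_historical_content("example.com", "20200101")
    # data resembles: {"archived_snapshots": {"closest": {"available": true,
    #                  "url": "http://web.archive.org/web/...", "timestamp": "..."}}}
    closest = data.get("archived_snapshots", {}).get("closest")
    if closest and closest.get("available"):
        archived_url = closest["url"]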
-# --------------------------
-# Enhanced Excel Reader
-# --------------------------
-class EnhancedExcelReader(Tool):
-    def forward(self, path: str):
-        df = pd.read_excel(path)
-        validation = ValidationPipeline().validate(df, self._detect_schema(df))
-        if not validation['valid']:
-            raise ValueError(f"Data validation failed: {validation['errors']}")
-        return df.to_markdown()

-    def _detect_schema(self, df):
-        schema = {}
-        for col in df.columns:
-            dtype = 'categorical'
-            if pd.api.types.is_numeric_dtype(df[col]):
-                dtype = 'numeric'
-            elif pd.api.types.is_datetime64_any_dtype(df[col]):
-                dtype = 'temporal'
-            schema[col] = {'type': dtype}
-        return schema
 # --------------------------
-#
 # --------------------------
-class CrossVerifiedSearch:
-    SOURCES = [
-        DuckDuckGoSearchTool(),
-        WikipediaSearchTool(),
-        ArxivSearchTool()
-    ]
-
-    def __call__(self, query: str):
-        results = []
-        for source in self.SOURCES:
-            try:
-                results.append(source(query))
-            except Exception as e:
-                continue
-        return self._consensus(results)
-
-    def _consensus(self, results):
-        # Simple majority voting implementation
-        counts = {}
-        for result in results:
-            key = str(result)[:100]  # Simple hash for demo
-            counts[key] = counts.get(key, 0) + 1
-        return max(counts, key=counts.get)
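The removed _consensus votes on 100-character prefixes of each result; collections.Counter expresses the same thing and also exposes the winning margin, should the class be resurrected:

    from collections import Counter

    def consensus(results):
        # Majority vote over truncated string forms, as in the removed method.
        counts = Counter(str(r)[:100] for r in results)
        winner, votes = counts.most_common(1)[0]  # empty input fails here, as with max() on {}
        return winner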

-# --------------------------
-# Main Agent Class
-# --------------------------
 class MagAgent:
     def __init__(self, rate_limiter: Optional[Limiter] = None):
         self.rate_limiter = rate_limiter
@@ -182,98 +179,48 @@ class MagAgent:

         self.tools = [
             UniversalLoader(),
-
         ]

-
         self.agent = CodeAgent(
             model=self.model,
             tools=self.tools,
             verbosity_level=2,
-
-            max_steps=20
         )

     async def __call__(self, question: str, task_id: str) -> str:
         try:
-            context = {
-
-                "validation_checks": []
-            }
-
-            result = await asyncio.to_thread(
-                self.agent.run,
-                task=self._build_task_prompt(question, task_id)
-            )
-
-            validated = self._validate_result(result, context)
-            return self._format_output(validated)
-
         except Exception as e:
             return self._handle_error(e, context)

-    def _build_task_prompt(self, question: str, task_id: str) -> str:
-        base_prompt = self.prompt_templates['base']
-        domain = ToolRouter().route(question)
-        return f"""
-        {base_prompt}
-
-        **Domain Classification**: {domain}
-        **Required Validation**: {self._get_validation_requirements(domain)}
-
-        Question: {question}
-        {self._attachment_prompt(task_id)}
-        """
-    def _validate_result(self, result: str, context: dict) -> dict:
-        validation_rules = {
-            'numeric': r'\d+',
-            'temporal': r'\d{4}-\d{2}-\d{2}',
-            'categorical': r'^[A-Za-z]+$'
-        }
-
-        validations = {}
-        for v_type, pattern in validation_rules.items():
-            match = re.search(pattern, result)
-            validations[v_type] = bool(match)
-
-        confidence = sum(validations.values()) / len(validations)
-        context['validation_checks'] = validations
-
         return {
-
-

-    def _format_output(self, validated: dict) -> str:
-        if validated['confidence'] < 0.7:
-            return "Unable to verify answer with sufficient confidence"
-        return validated['result']

-    def _handle_error(self, error: Exception, context: dict) -> str:
-        error_info = {
-            "type": type(error).__name__,
-            "message": str(error),
-            "context": context
         }
-        return json.dumps(error_info)

-    def _get_validation_requirements(self, domain: str) -> str:
-        requirements = {
-
-        }
-        return requirements.get(domain, "Standard fact verification")

-    def _attachment_prompt(self, task_id: str) -> str:
-        if task_id:
-            return f"Attachment available with task_id: {task_id}"
-        return "No attachments provided"

+from smolagents import CodeAgent, LiteLLMModel, Tool
 from token_bucket import Limiter, MemoryStorage
 from tenacity import retry, stop_after_attempt, wait_exponential
+from langchain_community.document_loaders import ArxivLoader
 from sentence_transformers import SentenceTransformer
 from bs4 import BeautifulSoup
 from datetime import datetime

 import yaml
 import os
 import re
+import json

 # --------------------------
+# Core Tools from Previous Implementation
 # --------------------------
+
+class VisitWebpageTool(Tool):
+    name = "visit_webpage"
+    description = "Visits a webpage and returns its content as markdown"
+    inputs = {'url': {'type': 'string', 'description': 'The URL to visit'}}
+    output_type = "string"
+
+    @retry(stop=stop_after_attempt(3), wait=wait_exponential(multiplier=1, min=2, max=10))
+    def forward(self, url: str) -> str:
         try:
+            response = requests.get(url, timeout=30)
+            response.raise_for_status()
+            return markdownify(response.text).strip()
         except Exception as e:
+            return f"Error fetching webpage: {str(e)}"
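A quick smoke test for the new tool; the URL is a placeholder, and requests plus markdownify are assumed to come from the import block elided above (the hunk header shows import whisper living there):

    tool = VisitWebpageTool()
    print(tool.forward("https://example.com")[:200])   # placeholder URL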
+class DownloadTaskAttachmentTool(Tool):
+    name = "download_file"
+    description = "Downloads files from the task API"
+    inputs = {'task_id': {'type': 'string', 'description': 'The task ID to download'}}
+    output_type = "string"

+    def forward(self, task_id: str) -> str:
+        api_url = os.getenv("TASK_API_URL", "https://agents-course-unit4-scoring.hf.space")
+        file_url = f"{api_url}/files/{task_id}"
+
+        try:
+            response = requests.get(file_url, stream=True, timeout=30)
+            response.raise_for_status()
+
+            # File type detection
+            content_type = response.headers.get('Content-Type', '')
+            extension = self._get_extension(content_type)
+
+            os.makedirs("downloads", exist_ok=True)
+            file_path = f"downloads/{task_id}{extension}"
+
+            with open(file_path, "wb") as f:
+                for chunk in response.iter_content(chunk_size=8192):
+                    f.write(chunk)
+
+            return file_path
+        except Exception as e:
+            raise RuntimeError(f"Download failed: {str(e)}")
+
+    def _get_extension(self, content_type: str) -> str:
+        type_map = {
+            'image/png': '.png',
+            'image/jpeg': '.jpg',
+            'application/vnd.openxmlformats-officedocument.spreadsheetml.sheet': '.xlsx',
+            'audio/mpeg': '.mp3',
+            'application/pdf': '.pdf',
+            'text/x-python': '.py'
+        }
+        return type_map.get(content_type.split(';')[0], '.bin')
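The hand-rolled type_map covers exactly the formats the tasks ship. If broader coverage were ever wanted, the stdlib offers the same lookup — a possible alternative, not what the commit does, and note that older Pythons map image/jpeg to '.jpe' rather than '.jpg':

    import mimetypes

    def get_extension(content_type: str) -> str:
        # e.g. 'image/png' -> '.png'; fall back to '.bin' like _get_extension does
        return mimetypes.guess_extension(content_type.split(';')[0]) or '.bin'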
+class ArxivSearchTool(Tool):
+    name = "arxiv_search"
+    description = "Searches academic papers on Arxiv"
+    inputs = {'query': {'type': 'string', 'description': 'Search query'}}
+    output_type = "string"

+    def forward(self, query: str) -> str:
+        try:
+            loader = ArxivLoader(query=query, load_max_docs=3)
+            docs = loader.load()
+            return "\n\n".join([
+                f"Title: {doc.metadata['Title']}\n"
+                f"Authors: {doc.metadata['Authors']}\n"
+                f"Summary: {doc.page_content[:500]}..."
+                for doc in docs
+            ])
+        except Exception as e:
+            return f"Arxiv search failed: {str(e)}"
+class SpeechToTextTool(Tool):
+    name = "speech_to_text"
+    description = "Converts audio files to text"
+    inputs = {'audio_path': {'type': 'string', 'description': 'Path to audio file'}}
+    output_type = "string"

+    def __init__(self):
+        self.model = whisper.load_model("base")

+    def forward(self, audio_path: str) -> str:
+        if not os.path.exists(audio_path):
+            return f"File not found: {audio_path}"
+        return self.model.transcribe(audio_path).get("text", "")
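One caveat with the class as committed: smolagents expects Tool subclasses that define __init__ to call super().__init__(), and loading Whisper eagerly makes merely constructing the tool slow. A sketch of both adjustments (the lazy load is an assumption, not part of the commit):

    class SpeechToTextTool(Tool):
        name = "speech_to_text"
        description = "Converts audio files to text"
        inputs = {'audio_path': {'type': 'string', 'description': 'Path to audio file'}}
        output_type = "string"

        def __init__(self):
            super().__init__()   # required for smolagents' own Tool setup
            self.model = None    # defer the expensive whisper load

        def forward(self, audio_path: str) -> str:
            if not os.path.exists(audio_path):
                return f"File not found: {audio_path}"
            if self.model is None:
                self.model = whisper.load_model("base")
            return self.model.transcribe(audio_path).get("text", "")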

 # --------------------------
+# Enhanced Tools with Validation
 # --------------------------
+
+class ValidatedExcelReader(Tool):
+    name = "excel_reader"
+    description = "Reads and validates Excel files"
+    inputs = {
+        'file_path': {'type': 'string', 'description': 'Path to Excel file'},
+        'schema': {'type': 'object', 'description': 'Validation schema', 'nullable': True}
     }
+    output_type = "string"

+    def forward(self, file_path: str, schema: dict = None) -> str:
+        df = pd.read_excel(file_path)
+
+        if schema:
+            validation = ValidationPipeline().validate(df, schema)
+            if not validation['valid']:
+                raise ValueError(f"Data validation failed: {validation['errors']}")
+
+        return df.to_markdown()
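Usage sketch for the validated reader; the schema shape mirrors the per-column {'type': ...} dicts that the removed _detect_schema produced, and the path is hypothetical:

    reader = ValidatedExcelReader()
    table_md = reader.forward(
        "downloads/task123.xlsx",                      # hypothetical path
        schema={"amount": {"type": "numeric"},
                "date": {"type": "temporal"}},
    )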

 # --------------------------
+# Integrated Universal Loader
 # --------------------------
+
+class UniversalLoader(Tool):
     def __init__(self):
+        self.loaders = {
+            'excel': ValidatedExcelReader(),
+            'audio': SpeechToTextTool(),
+            'arxiv': ArxivSearchTool(),
+            'web': VisitWebpageTool()
         }

+    def forward(self, source: str, task_id: str = None) -> str:
+        try:
+            if source == "attachment":
+                file_path = DownloadTaskAttachmentTool()(task_id)
+                return self._load_by_type(file_path)
+            return self.loaders[source].forward(task_id)
+        except Exception as e:
+            return self._fallback(source, task_id)
+
+    def _load_by_type(self, file_path: str) -> str:
+        ext = file_path.split('.')[-1].lower()
+        loader_map = {
+            'xlsx': 'excel',
+            'mp3': 'audio',
+            'pdf': 'arxiv'
         }
+        return self.loaders[loader_map.get(ext, 'web')].forward(file_path)

+    def _fallback(self, source: str, context: str) -> str:
+        return CrossVerifiedSearch()(f"{source} {context}")
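Note that _fallback still instantiates CrossVerifiedSearch, which this same commit deletes from agent.py; unless the class is kept elsewhere and imported, the fallback path will raise a NameError at runtime. A defensive sketch (the guard is an assumption, not in the commit):

    def _fallback(self, source: str, context: str) -> str:
        try:
            return CrossVerifiedSearch()(f"{source} {context}")
        except NameError:
            return f"No fallback search available for '{source}'"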

 # --------------------------
+# Main Agent Class (Integrated)
 # --------------------------

 class MagAgent:
     def __init__(self, rate_limiter: Optional[Limiter] = None):
         self.rate_limiter = rate_limiter

         self.tools = [
             UniversalLoader(),
+            ValidatedExcelReader(),
+            ArxivSearchTool(),
+            VisitWebpageTool(),
+            DownloadTaskAttachmentTool(),
+            SpeechToTextTool()
         ]

+        with open("prompts.yaml") as f:
+            self.prompt_templates = yaml.safe_load(f)

         self.agent = CodeAgent(
             model=self.model,
             tools=self.tools,
             verbosity_level=2,
+            prompt_templates=self.prompt_templates,
+            max_steps=20,
+            add_base_tools=False
         )

     async def __call__(self, question: str, task_id: str) -> str:
         try:
+            context = self._create_context(question, task_id)
+            result = await self._execute_agent(question, task_id)
+            return self._validate_and_format(result, context)
         except Exception as e:
             return self._handle_error(e, context)

+    # ... (keep other helper methods from previous implementation)
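If _create_context itself raises, context is still unbound when the except clause calls _handle_error, so one failure turns into an UnboundLocalError. The usual guard is to bind it before the try (a sketch, not in the commit):

    async def __call__(self, question: str, task_id: str) -> str:
        context = {}   # bound before anything can raise
        try:
            context = self._create_context(question, task_id)
            result = await self._execute_agent(question, task_id)
            return self._validate_and_format(result, context)
        except Exception as e:
            return self._handle_error(e, context)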

+    def _create_context(self, question: str, task_id: str) -> dict:
         return {
+            "question": question,
+            "task_id": task_id,
+            "timestamp": datetime.now().isoformat(),
+            "validation_checks": []
         }

+    async def _execute_agent(self, question: str, task_id: str) -> str:
+        return await asyncio.to_thread(
+            self.agent.run,
+            task=self._build_task_prompt(question, task_id)
+        )

+    def _validate_and_format(self, result: str, context: dict) -> str:
+        validation = ValidationPipeline().validate