SergeyO7 committed on
Commit
93e7a8c
·
verified ·
1 Parent(s): 5a45044

Create agent.py

Browse files
Files changed (1) hide show
  1. agent.py +422 -0
agent.py ADDED
@@ -0,0 +1,422 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from smolagents import CodeAgent, LiteLLMModel, tool, Tool, load_tool, DuckDuckGoSearchTool, WikipediaSearchTool #, HfApiModel, OpenAIServerModel
2
+ import asyncio
3
+ import os
4
+ import re
5
+ import pandas as pd
6
+ from typing import Optional
7
+ from token_bucket import Limiter, MemoryStorage
8
+ import yaml
9
+ from PIL import Image, ImageOps
10
+ import requests
11
+ from io import BytesIO
12
+ from markdownify import markdownify
13
+ import whisper
14
+
15
+ import time
16
+ import shutil
17
+ import traceback
18
+
19
+
20
@tool
def GoogleSearchTool(query: str) -> str:
    """Tool for performing Google searches using Custom Search JSON API
    Args:
        query (str): Search query string
    Returns:
        str: Formatted search results
    """
    # Both credentials come from the environment; the original never assigned
    # api_key, so every call failed with NameError before reaching the API.
    api_key = os.environ.get("GOOGLE_API_KEY")
    cse_id = os.environ.get("GOOGLE_CSE_ID")
    if not api_key or not cse_id:
        raise ValueError("GOOGLE_API_KEY and GOOGLE_CSE_ID must be set in environment variables.")
    url = "https://www.googleapis.com/customsearch/v1"
    params = {
        "key": api_key,
        "cx": cse_id,
        "q": query,
        "num": 5,  # Number of results to return
    }
    try:
        # A timeout keeps a stalled connection from blocking the agent forever;
        # a timeout error is reported through the generic handler below.
        response = requests.get(url, params=params, timeout=20)
        response.raise_for_status()
        results = response.json().get("items", [])
        # One "title: link" line per hit; an empty join falls back to the notice.
        return "\n".join(f"{item['title']}: {item['link']}" for item in results) or "No results found."
    except Exception as e:
        # Tool errors are returned as text so the calling agent can react to them.
        return f"Error performing Google search: {str(e)}"
45
+
46
+ #@tool
47
+
48
+ #def ImageAnalysisTool(question: str, model: LiteLLMModel) -> str:
49
+ # """Tool for analyzing images mentioned in the question.
50
+ # Args:
51
+ # question (str): The question text which may contain an image URL.
52
+ # Returns:
53
+ # str: Image description or error message.
54
+ # """
55
+ # # Extract URL from question using regex
56
+ # url_pattern = r'https?://\S+'
57
+ #
58
+ # match = re.search(url_pattern, question)
59
+ # if not match:
60
+ # return "No image URL found in the question."
61
+ # image_url = match.group(0)
62
+ #
63
+ # headers = {
64
+ # "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/114.0.0.0 Safari/537.36"
65
+ # }
66
+ # try:
67
+ # response = requests.get(image_url, headers=headers)
68
+
69
+ # response.raise_for_status()
70
+ # image = Image.open(BytesIO(response.content)).convert("RGB")
71
+ # except Exception as e:
72
+ # return f"Error fetching image: {e}"
73
+ #
74
+ # agent = CodeAgent(
75
+ # tools=[],
76
+ # model=model,
77
+ # max_steps=10,
78
+ # verbosity_level=2
79
+ # )
80
+ #
81
+ # response = agent.run(
82
+ # "Describe in details the chess position you see in the image.",
83
+ # images=[image]
84
+ # )
85
+ #
86
+ # return f"The image description: '{response}'"
87
+
88
class VisitWebpageTool(Tool):
    """Tool that fetches a URL and returns the page converted to markdown."""

    name = "visit_webpage"
    description = "Visits a webpage at the given url and reads its content as a markdown string. Use this to browse webpages."
    inputs = {'url': {'type': 'string', 'description': 'The url of the webpage to visit.'}}
    output_type = "string"

    def __init__(self, *args, **kwargs):
        # Positional and keyword arguments are accepted but not used.
        self.is_initialized = False

    def forward(self, url: str) -> str:
        """Download ``url`` and return its markdown rendering, capped in length.

        Every failure is reported as a message string instead of being raised.
        """
        try:
            page = requests.get(url, timeout=20)
            page.raise_for_status()
            as_markdown = markdownify(page.text).strip()
            # Collapse runs of three or more newlines into a single blank line.
            as_markdown = re.sub(r"\n{3,}", "\n\n", as_markdown)
            # Imported inside the try so an import failure is also reported as text.
            from smolagents.utils import truncate_content
            return truncate_content(as_markdown, 10000)
        except requests.exceptions.Timeout:
            return "The request timed out. Please try again later or check the URL."
        except requests.exceptions.RequestException as exc:
            return f"Error fetching the webpage: {str(exc)}"
        except Exception as exc:
            return f"An unexpected error occurred: {str(exc)}"
111
+
112
class DownloadTaskAttachmentTool(Tool):
    """Tool that downloads the attachment of a task and stores it under ``downloads/``."""

    name = "download_file"
    description = "Downloads the file attached to the task ID and returns the local file path. Supports Excel (.xlsx), image (.png, .jpg), audio (.mp3), PDF (.pdf), and Python (.py) files."
    inputs = {'task_id': {'type': 'string', 'description': 'The task id to download attachment from.'}}
    output_type = "string"
    DEFAULT_API_URL = "https://agents-course-unit4-scoring.hf.space"

    def __init__(self, rate_limiter: Optional[Limiter] = None, default_api_url: str = DEFAULT_API_URL, *args, **kwargs):
        """Store the optional rate limiter and the base URL of the files API.

        Args:
            rate_limiter: Optional limiter consulted before each download.
            default_api_url: Base URL; files are fetched from ``<url>/files/<task_id>``.
        """
        # Delegate to the base initialiser, as SpeechToTextTool does.
        super().__init__(*args, **kwargs)
        self.is_initialized = False
        self.rate_limiter = rate_limiter
        self.default_api_url = default_api_url

    def forward(self, task_id: str) -> str:
        """Download the attachment for ``task_id``.

        Returns:
            The local file path on success, otherwise an error message string.
        """
        # Content-Type fragment -> file extension, checked in this order.
        extension_by_type = (
            ('image/png', '.png'),
            ('image/jpeg', '.jpg'),
            ('application/vnd.openxmlformats-officedocument.spreadsheetml.sheet', '.xlsx'),
            ('audio/mpeg', '.mp3'),
            ('application/pdf', '.pdf'),
            ('text/x-python', '.py'),
        )
        file_url = f"{self.default_api_url}/files/{task_id}"
        print(f"Downloading file for task ID {task_id} from {file_url}...")
        try:
            if self.rate_limiter:
                while not self.rate_limiter.consume(1):
                    print(f"Rate limit reached for downloading file for task {task_id}. Waiting...")
                    time.sleep(60 / 15)  # Assuming 15 RPM
            # The context manager releases the streamed connection on every
            # path, including the early return for unsupported content types.
            with requests.get(file_url, stream=True, timeout=15) as response:
                response.raise_for_status()

                # Determine file extension based on Content-Type
                content_type = response.headers.get('Content-Type', '').lower()
                extension = next(
                    (ext for marker, ext in extension_by_type if marker in content_type),
                    None,
                )
                if extension is None:
                    return f"Error: Unsupported file type {content_type} for task {task_id}. Try using visit_webpage or web_search if the content is online."

                local_file_path = f"downloads/{task_id}{extension}"
                os.makedirs("downloads", exist_ok=True)
                with open(local_file_path, "wb") as file:
                    for chunk in response.iter_content(chunk_size=8192):
                        file.write(chunk)
            print(f"File downloaded successfully: {local_file_path}")
            return local_file_path
        except requests.exceptions.HTTPError as e:
            # Guard against an HTTPError that carries no response object.
            if e.response is not None and e.response.status_code == 429:
                return f"Error: Rate limit exceeded for task {task_id}. Try again later."
            return f"Error downloading file for task {task_id}: {str(e)}"
        except requests.exceptions.RequestException as e:
            return f"Error downloading file for task {task_id}: {str(e)}"
165
+
166
class SpeechToTextTool(Tool):
    """Tool that transcribes an audio file with a Whisper model."""

    name = "speech_to_text"
    description = (
        "Converts an audio file to text using OpenAI Whisper."
    )
    inputs = {
        "audio_path": {"type": "string", "description": "Path to audio file (.mp3, .wav)"},
    }
    output_type = "string"

    def __init__(self):
        super().__init__()
        # The "base" Whisper model is loaded once, at construction time.
        self.model = whisper.load_model("base")

    def forward(self, audio_path: str) -> str:
        """Return the transcription of ``audio_path``, or an error message if it is missing."""
        if os.path.exists(audio_path):
            transcription = self.model.transcribe(audio_path)
            return transcription.get("text", "")
        return f"Error: File not found at {audio_path}"
185
+
186
class ExcelReaderTool(Tool):
    """Tool that loads an Excel sheet and returns a textual summary of it."""

    name = "excel_reader"

    description = """
    This tool reads and processes Excel files (.xlsx, .xls).
    It can extract data, calculate statistics, and perform data analysis on spreadsheets.
    """
    inputs = {
        "excel_path": {
            "type": "string",
            "description": "The path to the Excel file to read",
        },
        "sheet_name": {
            "type": "string",
            "description": "The name of the sheet to read (optional, defaults to first sheet)",
            "nullable": True,
        },
    }
    output_type = "string"

    def forward(self, excel_path: str, sheet_name: str = None) -> str:
        """Read the given Excel file and summarise it.

        Args:
            excel_path: Path of the workbook to read.
            sheet_name: Sheet to load; when empty or None the first sheet is read.

        Returns:
            A report with the shape, the columns with their dtypes, and a
            preview of the first five rows, or an error message string.
        """
        try:
            # Check if the file exists
            if not os.path.exists(excel_path):
                return f"Error: Excel file not found at {excel_path}"

            # sheet_name=None would make pandas return a dict of every sheet,
            # so the no-name case must use read_excel's default (first sheet).
            if sheet_name:
                df = pd.read_excel(excel_path, sheet_name=sheet_name)
            else:
                df = pd.read_excel(excel_path)

            row_count, column_count = df.shape
            report = [
                f"Excel file: {excel_path}",
                f"Shape: {row_count} rows × {column_count} columns",
                "",
                "Columns:",
            ]
            # Iterating the dtypes Series keeps each column paired with its own
            # dtype, even when column names repeat.
            report.extend(f"- {col} ({dtype})" for col, dtype in df.dtypes.items())
            report.extend(["", "Preview (first 5 rows):", df.head(5).to_string()])
            return "\n".join(report)

        except Exception as e:
            return f"Error reading Excel file: {str(e)}"
247
+
248
+
249
+
250
+
251
class DownloadImageTool(Tool):
    """Tool that downloads a task's chess position image into the working directory."""

    name = "download_chess_image"
    description = "Downloads chess position image from task ID"
    # smolagents requires both "type" and "description" for every input;
    # without the description the tool cannot be instantiated.
    inputs = {'task_id': {'type': 'string', 'description': 'The task id whose chess image should be downloaded.'}}
    output_type = "string"

    def forward(self, task_id: str) -> str:
        """Download the image for ``task_id`` and return the local path.

        Returns:
            Path of the written file, ``chess_<task_id>.png``.

        Raises:
            RuntimeError: If the request or the file write fails; the original
                exception is attached as the cause.
        """
        try:
            # The timeout bounds a stalled request; the context manager closes
            # the streamed connection once the body has been written.
            with requests.get(
                f"https://agents-course-unit4-scoring.hf.space/files/{task_id}",
                stream=True,
                timeout=15,
            ) as response:
                response.raise_for_status()

                img_path = f"chess_{task_id}.png"
                with open(img_path, "wb") as f:
                    for chunk in response.iter_content(8192):
                        f.write(chunk)
            return img_path
        except Exception as e:
            raise RuntimeError(f"Image download failed: {str(e)}") from e
272
+
273
+
274
+
275
class ChessEngineTool(Tool):
    """Tool that asks a local Stockfish engine for a move in a given position."""

    name = "stockfish_analysis"
    description = "Analyzes chess position using Stockfish"
    # smolagents requires both "type" and "description" for every input.
    inputs = {'fen': {'type': 'string', 'description': 'The chess position to analyze, in FEN notation.'}}
    output_type = "string"

    def forward(self, fen: str) -> str:
        """Return the engine's chosen move for ``fen`` in SAN, or an error message.

        The engine is given a 2 second time limit.
        """
        try:
            # Imported here because names bound by an import in the class body
            # are class attributes and are not visible as bare names in methods.
            import chess
            import chess.engine

            board = chess.Board(fen)
            engine = chess.engine.SimpleEngine.popen_uci("stockfish")
            try:
                result = engine.play(board, chess.engine.Limit(time=2.0))
            finally:
                # Always shut the engine process down, even if play() raises.
                engine.quit()
            return board.san(result.move)
        except Exception as e:
            return f"Engine error: {str(e)}"

    async def analyze_position(self, task_id: str):
        """Download the task image and ask the model for the winning move.

        Returns the model's reply text, or "Analysis failed: ..." on any error.

        NOTE(review): this method reads ``self.tools`` and ``self.model``, which
        this class never assigns; MagAgent.__init__ assigns both, so the method
        looks like it was meant to live on MagAgent - confirm and move it.
        """
        try:
            # Step 1: Download image
            # NOTE(review): the tool call is awaited here; confirm the tool
            # actually returns an awaitable.
            img_path = await self.tools[0](task_id)

            # Step 2: Get multimodal analysis
            response = await self.model.acreate(
                messages=[{
                    "role": "user",
                    "content": [
                        {"type": "text", "text": """Analyze this chess position.
It's black's turn. Provide the winning move in algebraic notation.
Respond ONLY with the move, nothing else."""},
                        {"type": "image_url", "image_url": {"url": f"file://{img_path}"}}
                    ]
                }],
                temperature=0.1
            )

            return response.choices[0].message.content

        except Exception as e:
            return f"Analysis failed: {str(e)}"
316
+
317
+
318
+
319
+
320
class PythonCodeReaderTool(Tool):
    """Tool that returns the text of a Python source file."""

    name = "read_python_code"
    description = "Reads a Python (.py) file and returns its content as a string."
    inputs = {
        "file_path": {"type": "string", "description": "The path to the Python file to read"}
    }
    output_type = "string"

    def forward(self, file_path: str) -> str:
        """Return the UTF-8 decoded content of ``file_path``, or an error message."""
        try:
            if os.path.exists(file_path):
                with open(file_path, "r", encoding="utf-8") as handle:
                    source_text = handle.read()
                return source_text
            return f"Error: Python file not found at {file_path}"
        except Exception as exc:
            return f"Error reading Python file: {str(exc)}"
337
+
338
class MagAgent:
    """Wraps a smolagents CodeAgent and exposes it as an async callable."""

    def __init__(self, rate_limiter: Optional["Limiter"] = None):
        """Initialize the MagAgent with search tools.

        Args:
            rate_limiter: Optional limiter; consulted before each agent run and
                passed to the attachment download tool.
        """
        self.rate_limiter = rate_limiter

        print("Initializing MagAgent with search tools...")

        self.model = LiteLLMModel(
            model_id="gemini/gemini-1.5-flash",
            api_key=os.environ.get("GEMINI_KEY"),
            api_base="https://generativelanguage.googleapis.com/v1beta",
            max_tokens=2048
        )

        # The tool classes are module-level names, not attributes of this class.
        self.tools = [
            DownloadImageTool(),
            ChessEngineTool()
        ]

        # Load prompt templates
        with open("prompts.yaml", 'r', encoding="utf-8") as stream:
            prompt_templates = yaml.safe_load(stream)

        self.agent = CodeAgent(
            # Use the model built above; no local name `model` exists here.
            model=self.model,
            tools=[
                DownloadTaskAttachmentTool(rate_limiter=rate_limiter),
                SpeechToTextTool(),
                ExcelReaderTool(),
                VisitWebpageTool(),
                PythonCodeReaderTool(),
                # NOTE(review): the original list also named PNG2FENTool, which
                # is neither defined nor imported in this module and raised
                # NameError; re-add it once that tool exists.
                ChessEngineTool(),
                # Disabled: DuckDuckGoSearchTool(), WikipediaSearchTool(),
                # GoogleSearchTool, ImageAnalysisTool
            ],
            verbosity_level=2,
            prompt_templates=prompt_templates,
            add_base_tools=True,
            max_steps=15
        )
        print("MagAgent initialized.")

    async def __call__(self, question: str, task_id: str) -> str:
        """Process a question asynchronously using the MagAgent.

        Args:
            question: The question text.
            task_id: Task identifier, included in the prompt so the agent can
                download an attachment.

        Returns:
            The agent's answer as a string, "No answer found." when the agent
            returns nothing, or an error message if processing raised.
        """
        print(f"MagAgent received question (first 50 chars): {question[:50]}... Task ID: {task_id}")
        try:
            if self.rate_limiter:
                while not self.rate_limiter.consume(1):
                    print(f"Rate limit reached for task {task_id}. Waiting...")
                    await asyncio.sleep(60 / 15)  # Assuming 15 RPM
            # Include task_id in the task prompt to guide the agent
            task = (
                f"{question} \n"
                f"If the question references an attachment, use tool to download it with task_id: {task_id}\n"
            )
            print(f"Calling agent.run for task {task_id}...")
            # agent.run is blocking; run it in a worker thread so the event
            # loop stays responsive.
            response = await asyncio.to_thread(
                self.agent.run,
                task=task
            )
            print(f"Agent.run completed for task {task_id}.")
            # Map None to "" before converting: str(None) is the truthy text
            # "None", which would otherwise be returned as the answer.
            response = "" if response is None else str(response)
            if not response:
                print(f"No answer found for task {task_id}.")
                response = "No answer found."
            print(f"MagAgent response: {response[:50]}...")
            return response
        except Exception as e:
            error_msg = f"Error processing question for task {task_id}: {str(e)}. Check API key or network connectivity."
            print(error_msg)
            return error_msg