SergeyO7 committed on
Commit 90940ae · verified · 1 Parent(s): 85f8adb

Create agent.py

Files changed (1)
  1. agent.py +399 -0
agent.py ADDED
@@ -0,0 +1,399 @@
from smolagents import CodeAgent, ToolCallingAgent, LiteLLMModel, tool, Tool, load_tool, WebSearchTool, DuckDuckGoSearchTool  # , WikipediaSearchTool
import asyncio
import os
import re
import pandas as pd
from typing import Optional
from token_bucket import Limiter, MemoryStorage
import yaml
from PIL import Image, ImageOps
import requests
from io import BytesIO
from markdownify import markdownify
import whisper
import time
import shutil
import traceback
from langchain_community.document_loaders import ArxivLoader
import logging

logger = logging.getLogger(__name__)

@tool
def search_arxiv(query: str) -> str:
    """Search Arxiv for a query and return at most 3 results.

    Args:
        query: The search query.
    Returns:
        str: Formatted search results.
    """
    search_docs = ArxivLoader(query=query, load_max_docs=3).load()
    formatted_search_docs = "\n\n---\n\n".join(
        [
            f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content[:1000]}\n</Document>'
            for doc in search_docs
        ])
    # Return the formatted string directly so the result matches the declared output type.
    return formatted_search_docs

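# Usage sketch (hypothetical query): the tool returns up to three matches, each wrapped in a
# <Document source="..." page="..."> block and separated by "---" dividers, e.g.:
#   search_arxiv("retrieval augmented generation")
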
from tenacity import retry, stop_after_attempt, wait_exponential, retry_if_exception
import requests

def is_429_error(exception):
    return isinstance(exception, requests.exceptions.HTTPError) and exception.response.status_code == 429

class VisitWebpageTool(Tool):
    name = "visit_webpage"
    description = "Visits a webpage at the given url and reads its content as a markdown string. Use this to browse webpages."
    inputs = {'url': {'type': 'string', 'description': 'The url of the webpage to visit.'}}
    output_type = "string"

    @retry(
        stop=stop_after_attempt(3),
        wait=wait_exponential(multiplier=1, min=4, max=10),
        retry=retry_if_exception(is_429_error)
    )
    def forward(self, url: str) -> str:
        try:
            response = requests.get(url, timeout=50)
            response.raise_for_status()
            markdown_content = markdownify(response.text).strip()
            markdown_content = re.sub(r"\n{3,}", "\n\n", markdown_content)
            #from smolagents.utils import truncate_content
            #return truncate_content(markdown_content, 10000)
            return markdown_content
        except requests.exceptions.HTTPError as e:
            if e.response.status_code == 429:
                raise  # Retry on 429
            return f"Error fetching the webpage: {str(e)}"
        except requests.exceptions.Timeout:
            return "The request timed out. Please try again later or check the URL."
        except requests.exceptions.RequestException as e:
            return f"Error fetching the webpage: {str(e)}"
        except Exception as e:
            return f"An unexpected error occurred: {str(e)}"

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)  # ensure the Tool base class is set up
        self.is_initialized = False

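# Note: with the retry policy above, a 429 response is re-raised inside forward() and retried
# up to 3 times with an exponential back-off clamped between 4 and 10 seconds; other HTTP
# errors are returned to the agent as plain error strings.
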
class SpeechToTextTool(Tool):
    name = "speech_to_text"
    description = (
        "Converts an audio file to text using OpenAI Whisper."
    )
    inputs = {
        "audio_path": {"type": "string", "description": "Path to audio file (.mp3, .wav)"},
    }
    output_type = "string"

    def __init__(self):
        super().__init__()
        self.model = whisper.load_model("base")

    def forward(self, audio_path: str) -> str:
        if not os.path.exists(audio_path):
            return f"Error: File not found at {audio_path}"
        try:
            print(f"Starting transcription for {audio_path}...")
            result = self.model.transcribe(audio_path)
            print(f"Transcription completed for {audio_path}.")
            return result.get("text", "")
        except Exception as e:
            return f"Error processing audio file: {str(e)}"

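# Note: whisper.load_model("base") downloads the model weights on first use, so the first
# SpeechToTextTool() instantiation can take a while in a cold environment.
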
class ExcelReaderTool(Tool):
    name = "excel_reader"
    description = """
    This tool reads and processes Excel files (.xlsx, .xls).
    It can extract data, calculate statistics, and perform data analysis on spreadsheets.
    """
    inputs = {
        "excel_path": {
            "type": "string",
            "description": "The path to the Excel file to read",
        },
        "sheet_name": {
            "type": "string",
            "description": "The name of the sheet to read (optional, defaults to first sheet)",
            "nullable": True
        }
    }
    output_type = "string"

    def forward(self, excel_path: str, sheet_name: str = None) -> str:
        try:
            if not os.path.exists(excel_path):
                return f"Error: Excel file not found at {excel_path}"
            import pandas as pd
            if sheet_name:
                df = pd.read_excel(excel_path, sheet_name=sheet_name)
            else:
                df = pd.read_excel(excel_path)
            info = {
                "shape": df.shape,
                "columns": list(df.columns),
                "dtypes": df.dtypes.to_dict(),
                "head": df.head(5).to_dict()
            }
            result = f"Excel file: {excel_path}\n"
            result += f"Shape: {info['shape'][0]} rows × {info['shape'][1]} columns\n\n"
            result += "Columns:\n"
            for col in info['columns']:
                result += f"- {col} ({info['dtypes'].get(col)})\n"
            result += "\nPreview (first 5 rows):\n"
            result += df.head(5).to_string()
            return result
        except Exception as e:
            return f"Error reading Excel file: {str(e)}"

class PythonCodeReaderTool(Tool):
    name = "read_python_code"
    description = "Reads a Python (.py) file and returns its content as a string."
    inputs = {
        "file_path": {"type": "string", "description": "The path to the Python file to read"}
    }
    output_type = "string"

    def forward(self, file_path: str) -> str:
        try:
            if not os.path.exists(file_path):
                return f"Error: Python file not found at {file_path}"
            with open(file_path, "r", encoding="utf-8") as file:
                content = file.read()
            return content
        except Exception as e:
            return f"Error reading Python file: {str(e)}"

from tenacity import retry, stop_after_attempt, wait_exponential, retry_if_exception_type

class RetryDuckDuckGoSearchTool(DuckDuckGoSearchTool):
    @retry(
        stop=stop_after_attempt(3),
        wait=wait_exponential(multiplier=1, min=4, max=10),
        retry=retry_if_exception_type(Exception)
    )
    def forward(self, query: str) -> str:
        return super().forward(query)

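# Note: this subclass retries on any exception (DuckDuckGo searches frequently fail with
# transient rate-limit errors); it is kept available here even though it is commented out
# of the tool lists below.
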
import cv2
import numpy as np
import os
from smolagents import Tool

class ChessboardToFENTool(Tool):
    name = "chessboard_to_fen"
    description = "Converts a PNG image of a chessboard to a FEN string describing the position."
    inputs = {'image_path': {'type': 'string', 'description': 'Path to the PNG image of the chessboard.'}}
    output_type = "string"

    def __init__(self, template_dir='templates'):
        super().__init__()  # initialize the Tool base class before loading templates
        self.template_dir = template_dir
        self.templates = {}
        for filename in os.listdir(template_dir):
            if filename.endswith('.png'):
                piece_name = filename.replace('.png', '')
                self.templates[piece_name] = cv2.imread(os.path.join(template_dir, filename), 0)

        self.piece_map = {
            'white_pawn': 'P', 'white_knight': 'N', 'white_bishop': 'B',
            'white_rook': 'R', 'white_queen': 'Q', 'white_king': 'K',
            'black_pawn': 'p', 'black_knight': 'n', 'black_bishop': 'b',
            'black_rook': 'r', 'black_queen': 'q', 'black_king': 'k',
            'empty': '1'
        }

    def forward(self, image_path: str) -> str:
        img = cv2.imread(image_path)
        if img is None:
            return "Error: Image not found."
        gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
        ret, corners = cv2.findChessboardCorners(gray, (7, 7), None)
        if not ret:
            return "Error: Chessboard not detected."

        # Define target grid for perspective transform
        square_size = 100  # Arbitrary size for the squares
        target_corners = np.array([[j*square_size, i*square_size] for i in range(7) for j in range(7)], dtype=np.float32)

        # Compute homography and warp the image
        h, status = cv2.findHomography(corners, target_corners)
        warped = cv2.warpPerspective(gray, h, (8*square_size, 8*square_size))

        # Determine square size
        square_size = warped.shape[0] // 8

        # Resize templates to match square size
        resized_templates = {}
        for name, temp in self.templates.items():
            resized_templates[name] = cv2.resize(temp, (square_size, square_size))

        # Classify each square and build FEN
        fen_rows = []
        for i in range(8):
            row = ''
            empty_count = 0
            for j in range(8):
                # Extract square
                square_img = warped[i*square_size:(i+1)*square_size, j*square_size:(j+1)*square_size]
                # Determine if light or dark square
                is_dark = (i + j) % 2 == 0
                relevant_templates = [name for name in resized_templates if ('dark' if is_dark else 'light') in name]

                best_score = -1
                best_piece = None
                for temp_name in relevant_templates:
                    temp = resized_templates[temp_name]
                    res = cv2.matchTemplate(square_img, temp, cv2.TM_CCOEFF_NORMED)
                    min_val, max_val, min_loc, max_loc = cv2.minMaxLoc(res)
                    if max_val > best_score:
                        best_score = max_val
                        best_piece = temp_name.split('_')[0] + '_' + temp_name.split('_')[1]  # e.g., 'white_pawn'

                if best_score > 0.8:  # Threshold for match confidence
                    if 'empty' in best_piece:
                        empty_count += 1
                    else:
                        if empty_count > 0:
                            row += str(empty_count)
                            empty_count = 0
                        row += self.piece_map[best_piece]
                else:
                    empty_count += 1

            if empty_count > 0:
                row += str(empty_count)
            fen_rows.append(row)

        fen = '/'.join(fen_rows)
        return fen

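# Minimal usage sketch (hypothetical paths): assumes a templates/ folder with one PNG per
# piece/square-colour combination (e.g. white_pawn_light.png, empty_dark.png) whose names match
# the keys derived above:
#   fen = ChessboardToFENTool(template_dir="templates").forward("board.png")
#   print(fen)  # piece-placement field only, e.g. "rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR"
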
##############################
# MAG Agent
##############################

class MagAgent:
    def __init__(self, rate_limiter: Optional[Limiter] = None):
        """Initialize the MagAgent with search tools."""
        logger.info("Initializing MagAgent")
        self.rate_limiter = rate_limiter

        print("Initializing MagAgent with search tools...")
        model = LiteLLMModel(
            model_id="gemini/gemini-2.0-flash",
            api_key=os.environ.get("GEMINI_KEY"),
            max_tokens=8192
        )

        self.imports = [
            "pandas",
            "numpy",
            "os",
            "requests",
            "tempfile",
            "datetime",
            "json",
            "time",
            "re",
            "openpyxl",
            "pathlib",
            "sys",
            "bs4",
            "arxiv",
            "whisper",
        ]

        self.tools = [
            # RetryDuckDuckGoSearchTool(),
            # WikipediaSearchTool(),
            SpeechToTextTool(),
            ExcelReaderTool(),
            # VisitWebpageTool(),
            PythonCodeReaderTool(),
            search_arxiv,
            ChessboardToFENTool(),
        ]

        self.prompt_template = (
            """
            You are an advanced AI assistant specialized in solving complex, real-world tasks that require multi-step reasoning, factual accuracy, and the use of external tools.

            Follow these principles:
            - Reason step-by-step. Think through the solution logically and plan your actions carefully before answering.
            - Validate information. Always verify facts when possible instead of guessing.
            - When processing external data (e.g., YouTube transcripts, web searches), expect potential issues like missing punctuation, inconsistent formatting, or conversational text.
            - When asked to transcribe a YouTube video, try searching for it on www.youtubetotranscript.com.
            - If the input is ambiguous, prioritize extracting key information relevant to the question.
            - Use code if needed. For calculations, parsing, or transformations, generate Python code and execute it. Be cautious, as some questions contain time-consuming tasks, so analyze the question and choose the most efficient solution.
            - Be precise and concise. The final answer must strictly match the required format with no extra commentary.
            - Use tools intelligently. If a question involves external information, structured data, images, or audio, call the appropriate tool to retrieve or process it.
            - If the question includes direct speech or quoted text (e.g., "Isn't that hot?"), treat it as a precise query and preserve the quoted structure in your response, including quotation marks for direct quotes (e.g., final_answer('"Extremely."')).
            - If asked about the name of a place or city, use the full name without abbreviations (e.g., use Saint Petersburg instead of St. Petersburg).
            - If asked to look up page numbers, make sure you don't mix them up with problem or exercise numbers.
            - If you cannot retrieve or process data (e.g., due to blocked requests), retry after a 15-second delay, then try another tool (wikipedia_search, then web_search, then search_arxiv). Otherwise, return a clear error message: "Unable to retrieve data. Search has failed."
            - Use `final_answer` to give the final answer.

            QUESTION: {question}

            {file_section}

            ANSWER:
            """
        )

        web_agent = ToolCallingAgent(
            tools=[
                # RetryDuckDuckGoSearchTool(),
                # WikipediaSearchTool(),
                # SpeechToTextTool(),
                WebSearchTool(),
                VisitWebpageTool(),
                # ExcelReaderTool(),
                # PythonCodeReaderTool(),
                search_arxiv,
            ],
            model=model,
            max_steps=15,
            name="web_search_agent",
            description="Runs web searches for you.",
        )

        self.agent = CodeAgent(
            model=model,
            managed_agents=[web_agent],
            tools=self.tools,
            add_base_tools=True,
            additional_authorized_imports=self.imports,
            verbosity_level=2,
            max_steps=10
        )
        print("MagAgent initialized.")

    async def __call__(self, question: str, file_path: Optional[str] = None) -> str:
        """Process a question asynchronously using the MagAgent."""
        print(f"MagAgent received question (first 50 chars): {question[:50]}... File path: {file_path}")
        try:
            if self.rate_limiter:
                while not self.rate_limiter.consume(1):
                    print("Rate limit reached. Waiting...")
                    await asyncio.sleep(4)
            # Conditionally include FILE: section only if file_path is provided
            file_section = f"FILE: {file_path}" if file_path else ""
            task = self.prompt_template.format(
                question=question,
                file_section=file_section
            )
            print("Calling agent.run...")
            response = await asyncio.to_thread(self.agent.run, task=task)
            print("Agent.run completed.")
            response = str(response) if response is not None else ""
            if not response:
                print("No answer found.")
                response = "No answer found."
            print(f"MagAgent response: {response[:50]}...")
            return response
        except Exception as e:
            error_msg = f"Error processing question: {str(e)}. Check API key or network connectivity."
            print(error_msg)
            return error_msg
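
# Minimal usage sketch (not exercised by this module itself): assumes GEMINI_KEY is set in the
# environment and that the optional file_path points at a local file the tools can read.
#
#   import asyncio
#   agent = MagAgent()
#   answer = asyncio.run(agent("What is the capital of France?"))
#   print(answer)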