real-jiakai commited on
Commit
06293e9
Β·
verified Β·
1 Parent(s): 66a0e23

Update agent.py

Browse files
Files changed (1) hide show
  1. agent.py +168 -565
agent.py CHANGED
@@ -1,10 +1,7 @@
1
  """
2
- agent.py – Simplified Claude implementation for GAIA challenge
3
  -----------------------------------------------------------
4
- Environment
5
- -----------
6
- ANTHROPIC_API_KEY – API key from Anthropic (set in Hugging Face space secrets)
7
- GAIA_API_URL – (optional) override for the GAIA scoring endpoint
8
  """
9
 
10
  import base64
@@ -13,10 +10,11 @@ import os
13
  import re
14
  import tempfile
15
  import time
16
- from typing import List, Dict, Any, Optional
17
  import random
 
18
  import requests
19
  from urllib.parse import urlparse
 
20
  from smolagents import CodeAgent, DuckDuckGoSearchTool, PythonInterpreterTool, tool
21
 
22
  # --------------------------------------------------------------------------- #
@@ -35,156 +33,93 @@ def _download_file(file_id: str) -> bytes:
35
  return resp.content
36
 
37
  # --------------------------------------------------------------------------- #
38
- # Rate limiting helper
39
- # --------------------------------------------------------------------------- #
40
- class RateLimiter:
41
- """Simple rate limiter to prevent Anthropic API rate limit errors"""
42
- def __init__(self, requests_per_minute=20, burst=3):
43
- self.requests_per_minute = requests_per_minute
44
- self.burst = burst
45
- self.request_times = []
46
-
47
- def wait(self):
48
- """Wait if needed to avoid exceeding rate limits"""
49
- now = time.time()
50
- # Remove timestamps older than 1 minute
51
- self.request_times = [t for t in self.request_times if now - t < 60]
52
-
53
- # If we've made too many requests in the last minute, wait
54
- if len(self.request_times) >= self.requests_per_minute:
55
- oldest = min(self.request_times)
56
- sleep_time = 60 - (now - oldest) + 1 # +1 for safety
57
- print(f"Rate limit approaching. Waiting {sleep_time:.2f} seconds before next request...")
58
- time.sleep(sleep_time)
59
-
60
- # Add current timestamp to the list
61
- self.request_times.append(time.time())
62
-
63
- # Add a small random delay to avoid bursts of requests
64
- if len(self.request_times) > self.burst:
65
- time.sleep(random.uniform(0.2, 1.0))
66
-
67
- # Global rate limiter instance
68
- RATE_LIMITER = RateLimiter(requests_per_minute=15) # Reduced to be extra cautious
69
-
70
- # --------------------------------------------------------------------------- #
71
- # Direct function to call Claude via LiteLLM
72
- # --------------------------------------------------------------------------- #
73
- def call_claude(
74
- prompt: str,
75
- system_prompt: Optional[str] = None,
76
- temperature: float = 0.1,
77
- max_tokens: int = 1024,
78
- model_name: str = "anthropic/claude-3-5-sonnet-20240620"
79
- ) -> str:
80
- """
81
- Call Claude through LiteLLM directly, following official LiteLLM documentation
82
-
83
- Args:
84
- prompt: The user's question
85
- system_prompt: Optional system prompt
86
- temperature: Temperature for generation
87
- max_tokens: Max tokens to generate
88
- model_name: Claude model to use
89
-
90
- Returns:
91
- The response text from Claude
92
- """
93
- from litellm import completion
94
-
95
- # Respect rate limits
96
- RATE_LIMITER.wait()
97
-
98
- try:
99
- # Build messages following exactly LiteLLM's documented format
100
- messages = []
101
-
102
- # Add system message if provided
103
- if system_prompt:
104
- messages.append({"role": "system", "content": system_prompt})
105
-
106
- # Add user message - this is simple text only format
107
- messages.append({"role": "user", "content": prompt})
108
-
109
- # Make the API call exactly as documented
110
- response = completion(
111
- model=model_name,
112
- messages=messages,
113
- temperature=temperature,
114
- max_tokens=max_tokens
115
- )
116
-
117
- # Extract just the text content from the response
118
- return response.choices[0].message.content
119
-
120
- except Exception as e:
121
- if "rate_limit" in str(e).lower():
122
- print(f"Rate limit hit: {e}")
123
- # Wait 60 seconds and try again
124
- time.sleep(60)
125
- return call_claude(prompt, system_prompt, temperature, max_tokens, model_name)
126
- else:
127
- print(f"Error calling Claude API: {e}")
128
- raise
129
-
130
- # --------------------------------------------------------------------------- #
131
- # Simple Claude Model wrapper for smolagents
132
  # --------------------------------------------------------------------------- #
133
- class SimpleClaudeModel:
134
  """
135
- A minimal wrapper around LiteLLM's direct call to Anthropic that works with smolagents
 
136
  """
137
 
138
  def __init__(
139
- self,
140
- model_id: str = "anthropic/claude-3-5-sonnet-20240620",
141
  api_key: Optional[str] = None,
142
- temperature: float = 0.1,
143
- max_tokens: int = 1024,
144
- system_prompt: Optional[str] = None,
145
  ):
146
- """Initialize a minimal Claude model wrapper"""
147
- # Get API key from env if not provided
148
- if api_key is None:
149
- api_key = os.getenv("ANTHROPIC_API_KEY")
150
- if not api_key:
151
- raise ValueError("No Anthropic API key provided. Set ANTHROPIC_API_KEY env var.")
152
-
153
- self.model_id = model_id
154
- self.api_key = api_key
155
  self.temperature = temperature
156
- self.max_tokens = max_tokens
157
 
158
- # Store the system prompt
159
- self.system_prompt = system_prompt or """You are a concise, highly accurate assistant specialized in solving challenges.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
160
  Your answers should be precise, direct, and exactly match the expected format.
161
  All answers are graded by exact string match, so format carefully!"""
 
 
 
 
 
 
162
 
163
- print(f"Initialized SimpleClaudeModel with {model_id}")
164
-
165
- def __call__(self, prompt: str, **kwargs) -> str:
166
- """Call method to make this class callable by smolagents CodeAgent"""
167
- # Directly use the call_claude function
168
- return call_claude(
169
- prompt=prompt,
170
- system_prompt=self.system_prompt,
171
- temperature=self.temperature,
172
- max_tokens=self.max_tokens,
173
- model_name=self.model_id
174
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
175
 
176
  # --------------------------------------------------------------------------- #
177
- # custom tool: fetch GAIA attachments
178
  # --------------------------------------------------------------------------- #
179
  @tool
180
  def gaia_file_reader(file_id: str) -> str:
181
  """
182
- Download a GAIA attachment and return its contents.
183
- Args:
184
- file_id: identifier that appears inside a <file:...> placeholder.
185
- Returns:
186
- base64-encoded string for binary files (images, PDFs, …) or decoded
187
- UTF-8 text for textual files.
188
  """
189
  try:
190
  raw = _download_file(file_id)
@@ -196,21 +131,11 @@ def gaia_file_reader(file_id: str) -> str:
196
  return f"ERROR downloading {file_id}: {exc}"
197
 
198
  # --------------------------------------------------------------------------- #
199
- # additional tool functions
200
  # --------------------------------------------------------------------------- #
201
  @tool
202
  def save_and_read_file(content: str, filename: Optional[str] = None) -> str:
203
- """
204
- Save content to a temporary file and return the path.
205
- Useful for processing files from the GAIA API.
206
-
207
- Args:
208
- content: The content to save to the file
209
- filename: Optional filename, will generate a random name if not provided
210
-
211
- Returns:
212
- Path to the saved file
213
- """
214
  temp_dir = tempfile.gettempdir()
215
  if filename is None:
216
  temp_file = tempfile.NamedTemporaryFile(delete=False)
@@ -218,495 +143,173 @@ def save_and_read_file(content: str, filename: Optional[str] = None) -> str:
218
  else:
219
  filepath = os.path.join(temp_dir, filename)
220
 
221
- # Write content to the file
222
  with open(filepath, 'w') as f:
223
  f.write(content)
224
 
225
- return f"File saved to {filepath}. You can read this file to process its contents."
226
-
227
- @tool
228
- def download_file_from_url(url: str, filename: Optional[str] = None) -> str:
229
- """
230
- Download a file from a URL and save it to a temporary location.
231
-
232
- Args:
233
- url: The URL to download from
234
- filename: Optional filename, will generate one based on URL if not provided
235
-
236
- Returns:
237
- Path to the downloaded file
238
- """
239
- try:
240
- # Parse URL to get filename if not provided
241
- if not filename:
242
- path = urlparse(url).path
243
- filename = os.path.basename(path)
244
- if not filename:
245
- # Generate a random name if we couldn't extract one
246
- import uuid
247
- filename = f"downloaded_{uuid.uuid4().hex[:8]}"
248
-
249
- # Create temporary file
250
- temp_dir = tempfile.gettempdir()
251
- filepath = os.path.join(temp_dir, filename)
252
-
253
- # Download the file
254
- response = requests.get(url, stream=True)
255
- response.raise_for_status()
256
-
257
- # Save the file
258
- with open(filepath, 'wb') as f:
259
- for chunk in response.iter_content(chunk_size=8192):
260
- f.write(chunk)
261
-
262
- return f"File downloaded to {filepath}. You can now process this file."
263
- except Exception as e:
264
- return f"Error downloading file: {str(e)}"
265
-
266
- @tool
267
- def extract_text_from_image(image_path: str) -> str:
268
- """
269
- Extract text from an image using pytesseract (if available).
270
-
271
- Args:
272
- image_path: Path to the image file
273
-
274
- Returns:
275
- Extracted text or error message
276
- """
277
- try:
278
- # Try to import pytesseract
279
- import pytesseract
280
- from PIL import Image
281
-
282
- # Open the image
283
- image = Image.open(image_path)
284
-
285
- # Extract text
286
- text = pytesseract.image_to_string(image)
287
-
288
- return f"Extracted text from image:\n\n{text}"
289
- except ImportError:
290
- return "Error: pytesseract is not installed. Please install it with 'pip install pytesseract' and ensure Tesseract OCR is installed on your system."
291
- except Exception as e:
292
- return f"Error extracting text from image: {str(e)}"
293
 
294
  @tool
295
  def analyze_csv_file(file_path: str, query: str) -> str:
296
- """
297
- Analyze a CSV file using pandas and answer a question about it.
298
-
299
- Args:
300
- file_path: Path to the CSV file
301
- query: Question about the data
302
-
303
- Returns:
304
- Analysis result or error message
305
- """
306
  try:
307
  import pandas as pd
308
-
309
- # Read the CSV file
310
  df = pd.read_csv(file_path)
311
 
312
- # Run various analyses based on the query
313
  result = f"CSV file loaded with {len(df)} rows and {len(df.columns)} columns.\n"
314
  result += f"Columns: {', '.join(df.columns)}\n\n"
315
-
316
- # Add summary statistics
317
  result += "Summary statistics:\n"
318
  result += str(df.describe())
319
 
320
  return result
321
  except ImportError:
322
- return "Error: pandas is not installed. Please install it with 'pip install pandas'."
323
  except Exception as e:
324
  return f"Error analyzing CSV file: {str(e)}"
325
 
326
  @tool
327
  def analyze_excel_file(file_path: str, query: str) -> str:
328
- """
329
- Analyze an Excel file using pandas and answer a question about it.
330
-
331
- Args:
332
- file_path: Path to the Excel file
333
- query: Question about the data
334
-
335
- Returns:
336
- Analysis result or error message
337
- """
338
  try:
339
  import pandas as pd
340
-
341
- # Read the Excel file
342
  df = pd.read_excel(file_path)
343
 
344
- # Run various analyses based on the query
345
  result = f"Excel file loaded with {len(df)} rows and {len(df.columns)} columns.\n"
346
  result += f"Columns: {', '.join(df.columns)}\n\n"
347
-
348
- # Add summary statistics
349
  result += "Summary statistics:\n"
350
  result += str(df.describe())
351
 
352
  return result
353
  except ImportError:
354
- return "Error: pandas and openpyxl are not installed. Please install them with 'pip install pandas openpyxl'."
355
  except Exception as e:
356
  return f"Error analyzing Excel file: {str(e)}"
357
 
358
  # --------------------------------------------------------------------------- #
359
- # GAIAAgent class
360
- # --------------------------------------------------------------------------- #
361
- class GAIAAgent:
362
- def __init__(
363
- self,
364
- api_key: Optional[str] = None,
365
- temperature: float = 0.1,
366
- verbose: bool = False,
367
- max_tokens: int = 1024,
368
- ):
369
- """
370
- Initialize a GAIAAgent with Claude model
371
-
372
- Args:
373
- api_key: Anthropic API key (fetched from environment if not provided)
374
- temperature: Temperature for text generation
375
- verbose: Enable verbose logging
376
- max_tokens: Maximum number of tokens to generate per response
377
- """
378
- # Set verbosity
379
- self.verbose = verbose
380
-
381
- # System prompt for all Claude interactions
382
- self.system_prompt = """You are a concise, highly accurate assistant specialized in solving challenges for the GAIA benchmark.
383
- Unless explicitly required, reply with ONE short sentence.
384
- Your answers should be precise, direct, and exactly match the expected format.
385
- All answers are graded by exact string match, so format carefully!"""
386
-
387
- # Get API key
388
- if api_key is None:
389
- api_key = os.getenv("ANTHROPIC_API_KEY")
390
- if not api_key:
391
- raise ValueError("No Anthropic token provided. Please set ANTHROPIC_API_KEY environment variable.")
392
-
393
- if self.verbose:
394
- print(f"Using Anthropic token: {api_key[:5]}...")
395
-
396
- # Initialize Claude model with our simplified wrapper
397
- self.model = SimpleClaudeModel(
398
- model_id="anthropic/claude-3-5-sonnet-20240620", # Use Claude 3.5 Sonnet
399
- api_key=api_key,
400
- temperature=temperature,
401
- max_tokens=max_tokens,
402
- system_prompt=self.system_prompt,
403
- )
404
-
405
- if self.verbose:
406
- print(f"Initialized model: SimpleClaudeModel - claude-3-5-sonnet-20240620")
407
-
408
- # Initialize default tools
409
- self.tools = [
410
- DuckDuckGoSearchTool(),
411
- PythonInterpreterTool(),
412
- save_and_read_file,
413
- download_file_from_url,
414
- analyze_csv_file,
415
- analyze_excel_file,
416
- gaia_file_reader
417
- ]
418
-
419
- # Add extract_text_from_image if PIL and pytesseract are available
420
- try:
421
- import pytesseract
422
- from PIL import Image
423
- self.tools.append(extract_text_from_image)
424
- if self.verbose:
425
- print("Added image processing tool")
426
- except ImportError:
427
- if self.verbose:
428
- print("Image processing libraries not available")
429
-
430
- if self.verbose:
431
- print(f"Initialized with {len(self.tools)} tools")
432
-
433
- # Setup imports allowed
434
- self.imports = ["pandas", "numpy", "datetime", "json", "re", "math", "os", "requests", "csv", "urllib"]
435
-
436
- # Initialize the CodeAgent
437
- self.agent = CodeAgent(
438
- tools=self.tools,
439
- model=self.model,
440
- additional_authorized_imports=self.imports,
441
- executor_type="local",
442
- verbosity_level=2 if self.verbose else 0
443
- )
444
-
445
- if self.verbose:
446
- print("Agent initialized and ready")
447
-
448
- def answer_question(self, question: str, task_file_path: Optional[str] = None) -> str:
449
- """
450
- Process a GAIA benchmark question and return the answer
451
-
452
- Args:
453
- question: The question to answer
454
- task_file_path: Optional path to a file associated with the question
455
-
456
- Returns:
457
- The answer to the question
458
- """
459
- try:
460
- if self.verbose:
461
- print(f"Processing question: {question}")
462
- if task_file_path:
463
- print(f"With associated file: {task_file_path}")
464
-
465
- # Create a context with file information if available
466
- context = question
467
- file_content = None
468
-
469
- # If there's a file, read it and include its content in the context
470
- if task_file_path:
471
- try:
472
- # Limit file content size to avoid token limits
473
- max_file_size = 8000 # Characters - reduced further to help with token limits
474
- with open(task_file_path, 'r', errors='ignore') as f:
475
- file_content = f.read(max_file_size)
476
- if len(file_content) >= max_file_size:
477
- file_content = file_content[:max_file_size] + "... [content truncated to prevent exceeding token limits]"
478
-
479
- # Determine file type from extension
480
- import os
481
- file_ext = os.path.splitext(task_file_path)[1].lower()
482
-
483
- context = f"""
484
- Question: {question}
485
- This question has an associated file. Here is the file content (it may be truncated):
486
- ```{file_ext}
487
- {file_content}
488
- ```
489
- Analyze the available file content to answer the question.
490
- """
491
- except Exception as file_e:
492
- try:
493
- # Try to read in binary mode
494
- with open(task_file_path, 'rb') as f:
495
- binary_content = f.read()
496
-
497
- # For image files
498
- if file_ext.lower() in ['.jpg', '.jpeg', '.png', '.gif', '.bmp']:
499
- context = f"""
500
- Question: {question}
501
- This question has an associated image file. Please use the extract_text_from_image tool to process it.
502
- File path: {task_file_path}
503
- """
504
- else:
505
- context = f"""
506
- Question: {question}
507
- This question has an associated file at path: {task_file_path}
508
- This is a binary file. Use appropriate tools to analyze it.
509
- """
510
- except Exception as binary_e:
511
- context = f"""
512
- Question: {question}
513
- This question has an associated file at path: {task_file_path}
514
- However, there was an error reading the file: {file_e}
515
- You can still try to answer the question based on the information provided.
516
- """
517
-
518
- # Check for special cases that need specific formatting
519
- # Reversed text questions
520
- if question.startswith(".") or ".rewsna eht sa" in question:
521
- context = f"""
522
- This question appears to be in reversed text. Here's the reversed version:
523
- {question[::-1]}
524
- Now answer the question above. Remember to format your answer exactly as requested.
525
- """
526
-
527
- # Add a prompt to ensure precise answers but keep it concise
528
- full_prompt = f"""{context}
529
- When answering, provide ONLY the precise answer requested.
530
- Do not include explanations, steps, reasoning, or additional text.
531
- Be direct and specific. GAIA benchmark requires exact matching answers.
532
- Example: If asked "What is the capital of France?", respond just with "Paris".
533
- """
534
-
535
- # Run the agent with the question
536
- answer = self.agent.run(full_prompt)
537
-
538
- # Clean up the answer to ensure it's in the expected format
539
- # Remove common prefixes that models often add
540
- answer = self._clean_answer(answer)
541
-
542
- if self.verbose:
543
- print(f"Generated answer: {answer}")
544
-
545
- return answer
546
- except Exception as e:
547
- error_msg = f"Error answering question: {e}"
548
- if self.verbose:
549
- print(error_msg)
550
- return error_msg
551
-
552
- def _clean_answer(self, answer: any) -> str:
553
- """
554
- Clean up the answer to remove common prefixes and formatting
555
- that models often add but that can cause exact match failures.
556
-
557
- Args:
558
- answer: The raw answer from the model
559
-
560
- Returns:
561
- The cleaned answer as a string
562
- """
563
- # Convert non-string types to strings
564
- if not isinstance(answer, str):
565
- # Handle numeric types (float, int)
566
- if isinstance(answer, float):
567
- # Format floating point numbers properly
568
- # Check if it's an integer value in float form (e.g., 12.0)
569
- if answer.is_integer():
570
- formatted_answer = str(int(answer))
571
- else:
572
- # For currency values that might need formatting
573
- if abs(answer) >= 1000:
574
- formatted_answer = f"${answer:,.2f}"
575
- else:
576
- formatted_answer = str(answer)
577
- return formatted_answer
578
- elif isinstance(answer, int):
579
- return str(answer)
580
- else:
581
- # For any other type
582
- return str(answer)
583
-
584
- # Now we know answer is a string, so we can safely use string methods
585
- # Normalize whitespace
586
- answer = answer.strip()
587
-
588
- # Remove common prefixes and formatting that models add
589
- prefixes_to_remove = [
590
- "The answer is ",
591
- "Answer: ",
592
- "Final answer: ",
593
- "The result is ",
594
- "To answer this question: ",
595
- "Based on the information provided, ",
596
- "According to the information: ",
597
- ]
598
-
599
- for prefix in prefixes_to_remove:
600
- if answer.startswith(prefix):
601
- answer = answer[len(prefix):].strip()
602
-
603
- # Remove quotes if they wrap the entire answer
604
- if (answer.startswith('"') and answer.endswith('"')) or (answer.startswith("'") and answer.endswith("'")):
605
- answer = answer[1:-1].strip()
606
-
607
- return answer
608
-
609
- # --------------------------------------------------------------------------- #
610
- # ClaudeAgent class - Wrapper around GAIAAgent
611
  # --------------------------------------------------------------------------- #
612
  class ClaudeAgent:
613
- """Claude-enhanced agent for GAIA challenge"""
614
 
615
  def __init__(self):
616
- # Try to initialize GAIAAgent with Claude
617
  try:
618
  # Get API key
619
  api_key = os.getenv("ANTHROPIC_API_KEY")
620
  if not api_key:
621
  raise ValueError("ANTHROPIC_API_KEY environment variable not found")
622
 
623
- print("βœ… Initializing GAIAAgent with Claude")
 
 
 
 
 
 
 
 
 
 
 
 
 
624
 
625
- # Create GAIAAgent instance
626
- self.agent = GAIAAgent(
627
- api_key=api_key,
628
- temperature=0.1, # Use low temperature for precise answers
629
- verbose=True, # Enable verbose logging
630
- max_tokens=1024, # Reduce max tokens to avoid hitting rate limits
 
631
  )
 
 
 
632
  except Exception as e:
633
- print(f"Error initializing GAIAAgent: {e}")
634
  raise
635
 
636
  def __call__(self, question: str) -> str:
637
- """
638
- Process a GAIA question and return the answer
639
-
640
- Args:
641
- question: The question to answer
642
-
643
- Returns:
644
- The answer to the question
645
- """
646
  try:
647
- print(f"Received question: {question[:100]}..." if len(question) > 100 else f"Received question: {question}")
648
-
649
- # Add delay between questions to respect rate limits
650
- time.sleep(random.uniform(0.5, 2.0))
651
 
652
- # Detect reversed text
653
- if question.startswith(".") or ".rewsna eht sa" in question:
654
- print("Detected reversed text question")
655
- # GAIAAgent handles reversed text internally
656
 
657
- # Detect if there's a file
658
  file_match = re.search(r"<file:([^>]+)>", question)
659
  if file_match:
660
  file_id = file_match.group(1)
661
- print(f"Detected file reference: {file_id}")
662
 
663
- # Download the file
664
  try:
665
  file_content = _download_file(file_id)
666
-
667
- # Create temporary file for the file
668
  temp_dir = tempfile.gettempdir()
669
  file_path = os.path.join(temp_dir, file_id)
670
 
671
- # Save file content
672
  with open(file_path, 'wb') as f:
673
  f.write(file_content)
674
 
675
- print(f"File downloaded to: {file_path}")
676
-
677
  # Remove file tag from question
678
  clean_question = re.sub(r"<file:[^>]+>", "", question).strip()
679
 
680
- # Process question with file path
681
- answer = self.agent.answer_question(clean_question, file_path)
682
- return self._clean_answer(answer)
 
 
 
 
683
  except Exception as e:
684
- print(f"Error processing file: {e}")
685
- # Fall back to processing without file
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
686
 
687
- # Process standard question
688
- answer = self.agent.answer_question(question)
689
- return self._clean_answer(answer)
690
  except Exception as e:
691
- print(f"Error processing question: {e}")
692
- error_msg = f"Unable to process question: {str(e)}"
693
- return error_msg
694
 
695
- def _clean_answer(self, answer: str) -> str:
696
- """
697
- Final cleanup of answer to ensure correct format
698
- Reuses GAIAAgent's cleaning method
699
- """
700
- # Already cleaned in GAIAAgent, but do additional checks
701
- if isinstance(answer, str):
702
- # Remove any trailing periods and whitespace
703
- answer = answer.rstrip(". \t\n\r")
704
-
705
- # Ensure it's not too long an answer - GAIA usually needs concise responses
706
- if len(answer) > 1000:
707
- # Try to find the first sentence or statement of the answer
708
- sentences = answer.split('. ')
709
- if len(sentences) > 1:
710
- return sentences[0].strip()
711
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
712
  return answer
 
1
  """
2
+ agent.py - Minimal Claude implementation for GAIA challenge
3
  -----------------------------------------------------------
4
+ A simplified implementation with direct litellm access to Anthropic's Claude
 
 
 
5
  """
6
 
7
  import base64
 
10
  import re
11
  import tempfile
12
  import time
 
13
  import random
14
+ from typing import List, Dict, Any, Optional
15
  import requests
16
  from urllib.parse import urlparse
17
+
18
  from smolagents import CodeAgent, DuckDuckGoSearchTool, PythonInterpreterTool, tool
19
 
20
  # --------------------------------------------------------------------------- #
 
33
  return resp.content
34
 
35
  # --------------------------------------------------------------------------- #
36
+ # Direct Claude model implementation with litellm
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
37
  # --------------------------------------------------------------------------- #
38
+ class DirectClaudeModel:
39
  """
40
+ Direct interface to Claude via litellm that works with smolagents
41
+ This avoids the message format issues by keeping things very simple
42
  """
43
 
44
  def __init__(
45
+ self,
 
46
  api_key: Optional[str] = None,
47
+ temperature: float = 0.1
 
 
48
  ):
49
+ """Initialize the Claude model"""
50
+ self.api_key = api_key or os.getenv("ANTHROPIC_API_KEY")
51
+ if not self.api_key:
52
+ raise ValueError("No Anthropic API key provided")
53
+
 
 
 
 
54
  self.temperature = temperature
55
+ self.model_name = "anthropic/claude-3-5-sonnet-20240620"
56
 
57
+ print(f"Initialized DirectClaudeModel with {self.model_name}")
58
+
59
+ # Sleep random amount to avoid race conditions with many queries
60
+ time.sleep(random.uniform(1, 3))
61
+
62
+ def __call__(self, prompt: str, **kwargs) -> str:
63
+ """
64
+ Simple call method that works with smolagents
65
+
66
+ Args:
67
+ prompt: The user prompt
68
+ **kwargs: Additional parameters (ignored)
69
+
70
+ Returns:
71
+ Claude's response as a string
72
+ """
73
+ # Import here to avoid any circular imports
74
+ from litellm import completion
75
+
76
+ # Use a simple format: system message + user message
77
+ messages = [
78
+ {
79
+ "role": "system",
80
+ "content": """You are a concise, highly accurate assistant specialized in solving challenges.
81
  Your answers should be precise, direct, and exactly match the expected format.
82
  All answers are graded by exact string match, so format carefully!"""
83
+ },
84
+ {
85
+ "role": "user",
86
+ "content": prompt
87
+ }
88
+ ]
89
 
90
+ # Add delay to avoid rate limits
91
+ time.sleep(random.uniform(0.5, 2.0))
92
+
93
+ try:
94
+ # Make API call with simple format
95
+ response = completion(
96
+ model=self.model_name,
97
+ messages=messages,
98
+ temperature=self.temperature,
99
+ max_tokens=1024,
100
+ api_key=self.api_key
101
+ )
102
+
103
+ # Extract and return the text content only
104
+ return response.choices[0].message.content
105
+
106
+ except Exception as e:
107
+ # If it's a rate limit error, wait and retry
108
+ if "rate_limit" in str(e).lower():
109
+ print(f"Rate limit hit, waiting 30 seconds: {e}")
110
+ time.sleep(30)
111
+ return self.__call__(prompt, **kwargs)
112
+ else:
113
+ print(f"Error: {str(e)}")
114
+ raise
115
 
116
  # --------------------------------------------------------------------------- #
117
+ # Custom tool: fetch GAIA attachments
118
  # --------------------------------------------------------------------------- #
119
  @tool
120
  def gaia_file_reader(file_id: str) -> str:
121
  """
122
+ Download a GAIA attachment and return its contents
 
 
 
 
 
123
  """
124
  try:
125
  raw = _download_file(file_id)
 
131
  return f"ERROR downloading {file_id}: {exc}"
132
 
133
  # --------------------------------------------------------------------------- #
134
+ # Additional tools
135
  # --------------------------------------------------------------------------- #
136
  @tool
137
  def save_and_read_file(content: str, filename: Optional[str] = None) -> str:
138
+ """Save content to a file and return the path"""
 
 
 
 
 
 
 
 
 
 
139
  temp_dir = tempfile.gettempdir()
140
  if filename is None:
141
  temp_file = tempfile.NamedTemporaryFile(delete=False)
 
143
  else:
144
  filepath = os.path.join(temp_dir, filename)
145
 
 
146
  with open(filepath, 'w') as f:
147
  f.write(content)
148
 
149
+ return f"File saved to {filepath}."
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
150
 
151
  @tool
152
  def analyze_csv_file(file_path: str, query: str) -> str:
153
+ """Analyze a CSV file with pandas"""
 
 
 
 
 
 
 
 
 
154
  try:
155
  import pandas as pd
 
 
156
  df = pd.read_csv(file_path)
157
 
 
158
  result = f"CSV file loaded with {len(df)} rows and {len(df.columns)} columns.\n"
159
  result += f"Columns: {', '.join(df.columns)}\n\n"
 
 
160
  result += "Summary statistics:\n"
161
  result += str(df.describe())
162
 
163
  return result
164
  except ImportError:
165
+ return "Error: pandas is not installed."
166
  except Exception as e:
167
  return f"Error analyzing CSV file: {str(e)}"
168
 
169
  @tool
170
  def analyze_excel_file(file_path: str, query: str) -> str:
171
+ """Analyze an Excel file with pandas"""
 
 
 
 
 
 
 
 
 
172
  try:
173
  import pandas as pd
 
 
174
  df = pd.read_excel(file_path)
175
 
 
176
  result = f"Excel file loaded with {len(df)} rows and {len(df.columns)} columns.\n"
177
  result += f"Columns: {', '.join(df.columns)}\n\n"
 
 
178
  result += "Summary statistics:\n"
179
  result += str(df.describe())
180
 
181
  return result
182
  except ImportError:
183
+ return "Error: pandas and openpyxl are not installed."
184
  except Exception as e:
185
  return f"Error analyzing Excel file: {str(e)}"
186
 
187
  # --------------------------------------------------------------------------- #
188
+ # ClaudeAgent - Main class for GAIA challenge
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
189
  # --------------------------------------------------------------------------- #
190
  class ClaudeAgent:
191
+ """A simplified Claude agent for the GAIA challenge"""
192
 
193
  def __init__(self):
194
+ """Initialize the agent with Claude"""
195
  try:
196
  # Get API key
197
  api_key = os.getenv("ANTHROPIC_API_KEY")
198
  if not api_key:
199
  raise ValueError("ANTHROPIC_API_KEY environment variable not found")
200
 
201
+ print("βœ… Initializing ClaudeAgent")
202
+
203
+ # Create the model with direct implementation
204
+ model = DirectClaudeModel(api_key=api_key, temperature=0.1)
205
+
206
+ # Set up tools
207
+ tools = [
208
+ DuckDuckGoSearchTool(),
209
+ PythonInterpreterTool(),
210
+ save_and_read_file,
211
+ analyze_csv_file,
212
+ analyze_excel_file,
213
+ gaia_file_reader
214
+ ]
215
 
216
+ # Create the CodeAgent
217
+ self.agent = CodeAgent(
218
+ tools=tools,
219
+ model=model,
220
+ additional_authorized_imports=["pandas", "numpy", "json", "re", "math"],
221
+ executor_type="local",
222
+ verbosity_level=2
223
  )
224
+
225
+ print("Agent initialized successfully")
226
+
227
  except Exception as e:
228
+ print(f"Error initializing ClaudeAgent: {e}")
229
  raise
230
 
231
  def __call__(self, question: str) -> str:
232
+ """Process a question and return the answer"""
 
 
 
 
 
 
 
 
233
  try:
234
+ print(f"Processing question: {question[:100]}..." if len(question) > 100 else question)
 
 
 
235
 
236
+ # Add a small delay between questions
237
+ time.sleep(random.uniform(1.0, 3.0))
 
 
238
 
239
+ # Handle file references
240
  file_match = re.search(r"<file:([^>]+)>", question)
241
  if file_match:
242
  file_id = file_match.group(1)
243
+ print(f"Detected file: {file_id}")
244
 
245
+ # Download file
246
  try:
247
  file_content = _download_file(file_id)
 
 
248
  temp_dir = tempfile.gettempdir()
249
  file_path = os.path.join(temp_dir, file_id)
250
 
 
251
  with open(file_path, 'wb') as f:
252
  f.write(file_content)
253
 
 
 
254
  # Remove file tag from question
255
  clean_question = re.sub(r"<file:[^>]+>", "", question).strip()
256
 
257
+ # Build prompt with file context
258
+ prompt = f"""
259
+ Question: {clean_question}
260
+ There is a file available at path: {file_path}
261
+ Use appropriate tools to analyze this file if needed.
262
+ Answer the question directly and precisely.
263
+ """
264
  except Exception as e:
265
+ print(f"Error downloading file: {e}")
266
+ prompt = question
267
+ else:
268
+ # Handle reversed text separately
269
+ if question.startswith(".") or ".rewsna eht sa" in question:
270
+ prompt = f"""
271
+ This question is in reversed text. Here's the normal version:
272
+ {question[::-1]}
273
+ Answer the question directly and precisely.
274
+ """
275
+ else:
276
+ prompt = question
277
+
278
+ # Execute agent with prompt
279
+ answer = self.agent.run(prompt)
280
+
281
+ # Clean up response
282
+ answer = self._clean_answer(answer)
283
+
284
+ print(f"Generated answer: {answer}")
285
+ return answer
286
 
 
 
 
287
  except Exception as e:
288
+ print(f"Error: {str(e)}")
289
+ return f"Error processing question: {str(e)}"
 
290
 
291
+ def _clean_answer(self, answer: any) -> str:
292
+ """Clean up the answer for exact matching"""
293
+ if not isinstance(answer, str):
294
+ return str(answer)
 
 
 
 
 
 
 
 
 
 
 
 
295
 
296
+ # Normalize spacing
297
+ answer = answer.strip()
298
+
299
+ # Remove common prefixes
300
+ prefixes = [
301
+ "The answer is ", "Answer: ", "Final answer: ",
302
+ "The result is ", "Based on the information provided, "
303
+ ]
304
+
305
+ for prefix in prefixes:
306
+ if answer.startswith(prefix):
307
+ answer = answer[len(prefix):].strip()
308
+
309
+ # Remove quotes
310
+ if (answer.startswith('"') and answer.endswith('"')) or (
311
+ answer.startswith("'") and answer.endswith("'")
312
+ ):
313
+ answer = answer[1:-1].strip()
314
+
315
  return answer