# tests/agents/test_agents.py (GAIA agent project)
import json

import pytest

from gagent.agents import BaseAgent, GeminiAgent

# Fixture imports are required so pytest can discover them in this module
from tests.agents.fixtures import (
    agent_factory,
    ollama_agent,
    gemini_agent,
    openai_agent,
)
class TestAgents:
"""Test suite for agents with GAIA data."""
@staticmethod
def load_questions():
"""Load questions from questions.json file."""
with open("exp/questions.json", "r") as f:
return json.load(f)
@staticmethod
def load_validation_data():
"""Load validation data from GAIA dataset metadata."""
validation_data = {}
with open("metadata.jsonl", "r") as f:
for line in f:
data = json.loads(line)
validation_data[data["task_id"]] = data["Final answer"]
return validation_data
    def _run_agent_test(self, agent: BaseAgent, num_questions: int = 2):
        """
        Common test implementation shared by all agent types.

        Args:
            agent: The agent to test.
            num_questions: Number of questions to test (default: 2).

        Returns:
            Tuple of (correct_count, total_tested).
        """
questions = self.load_questions()
validation_data = self.load_validation_data()
# Limit number of questions for testing
questions = questions[:num_questions]
# Keep track of correct answers
correct_count = 0
total_tested = 0
total_questions = len(questions)
for i, question_data in enumerate(questions):
task_id = question_data["task_id"]
if task_id not in validation_data:
continue
question = question_data["question"]
expected_answer = validation_data[task_id]
print(f"Testing question {i + 1}: {question[:50]}...")
# Call the agent with the question
response = agent.run(question, question_number=i + 1, total_questions=total_questions)
            # Extract the final answer from the response, assuming the agent
            # follows the "FINAL ANSWER: [answer]" convention; take the text
            # after the last occurrence in case the marker appears more than once
            if "FINAL ANSWER:" in response:
                answer = response.rsplit("FINAL ANSWER:", 1)[-1].strip()
            else:
                answer = response.strip()
            # Check correctness with a strict string comparison (case and
            # whitespace differences count as wrong)
            is_correct = answer == expected_answer
if is_correct:
correct_count += 1
total_tested += 1
print(f"Expected: {expected_answer}")
print(f"Got: {answer}")
print(f"Result: {'✓' if is_correct else '✗'}")
print("-" * 80)
# Compute accuracy
accuracy = correct_count / total_tested if total_tested > 0 else 0
print(f"Accuracy: {accuracy:.2%} ({correct_count}/{total_tested})")
return correct_count, total_tested
    # def test_ollama_agent_with_gaia_data(self, ollama_agent: BaseAgent):
    #     """Test the Ollama agent with GAIA dataset questions and validate against ground truth."""
    #     correct_count, total_tested = self._run_agent_test(ollama_agent)
    #     # At least one correct answer required to pass the test
    #     assert correct_count > 0, "Agent should get at least one answer correct"
# def test_gemini_agent_with_gaia_data(self, gemini_agent: GeminiAgent):
# """Test the Gemini agent with the same GAIA test approach."""
# correct_count, total_tested = self._run_agent_test(gemini_agent, num_questions=2)
# # At least one correct answer required to pass the test
# assert correct_count > 0, "Agent should get at least one answer correct"
    @pytest.mark.parametrize("agent_type,model_name", [("ollama", "phi4-mini")])
    def test_ollama_with_different_model(self, agent_factory, agent_type, model_name):
        """Test the Ollama agent with a different model."""
        agent = agent_factory(agent_type=agent_type, model_name=model_name)
        correct_count, total_tested = self._run_agent_test(agent, num_questions=3)
        # Just verify that questions were run; accuracy is not asserted here
        assert total_tested > 0, "Should test at least one question"
    # Can be uncommented when an OpenAI API key is available
    # def test_openai_agent_with_gaia_data(self, openai_agent: BaseAgent):
    #     """Test the OpenAI agent with the same GAIA test approach."""
    #     correct_count, total_tested = self._run_agent_test(openai_agent)
    #     # At least one correct answer required to pass the test
    #     assert correct_count > 0, "Agent should get at least one answer correct"
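
# Example invocation (assumes a running Ollama server with the phi4-mini model
# pulled, and exp/questions.json plus metadata.jsonl in the working directory):
#   pytest tests/agents/test_agents.py -k test_ollama_with_different_model -s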