import functools
import json
import os
from pathlib import Path
from typing import Callable, Type, Any, Dict, Optional

import pytest

from gagent.agents import BaseAgent, GeminiAgent
from tests.agents.fixtures import (
    agent_factory,
    ollama_agent,
    gemini_agent,
    openai_agent,
)


class TestAgents:
    """Test suite for agents with GAIA data."""

    @staticmethod
    def load_questions():
        """Load questions from the questions.json file."""
        with open("exp/questions.json", "r") as f:
            return json.load(f)

    @staticmethod
    def load_validation_data():
        """Load validation data from the GAIA dataset metadata."""
        validation_data = {}
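        # metadata.jsonl holds one JSON object per line; each record carries a
        # "task_id" and the reference "Final answer" for that task.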
with open("metadata.jsonl", "r") as f: |
|
for line in f: |
|
data = json.loads(line) |
|
validation_data[data["task_id"]] = data["Final answer"] |
|
|
|
return validation_data |
|
|
|
def _run_agent_test(self, agent: BaseAgent, num_questions: int = 2): |
|
""" |
|
Common test implementation for all agent types |
|
|
|
Args: |
|
agent: The agent to test |
|
num_questions: Number of questions to test (default: 2) |
|
|
|
Returns: |
|
Tuple of (correct_count, total_tested) |
|
""" |
        questions = self.load_questions()
        validation_data = self.load_validation_data()

        questions = questions[:num_questions]

        correct_count = 0
        total_tested = 0
        total_questions = len(questions)
        for i, question_data in enumerate(questions):
            task_id = question_data["task_id"]
            if task_id not in validation_data:
                continue

            question = question_data["question"]
            expected_answer = validation_data[task_id]

            print(f"Testing question {i + 1}: {question[:50]}...")

            response = agent.run(question, question_number=i + 1, total_questions=total_questions)

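            # Agents are expected to end their reply with "FINAL ANSWER: <answer>";
            # fall back to the raw response if the marker is missing.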
if "FINAL ANSWER:" in response: |
|
answer = response.split("FINAL ANSWER:")[1].strip() |
|
else: |
|
answer = response.strip() |
|
|
|
|
|
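            # Score by exact string match against the GAIA reference answer.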
            is_correct = answer == expected_answer
            if is_correct:
                correct_count += 1

            total_tested += 1

            print(f"Expected: {expected_answer}")
            print(f"Got: {answer}")
            print(f"Result: {'✓' if is_correct else '✗'}")
            print("-" * 80)

        accuracy = correct_count / total_tested if total_tested > 0 else 0
        print(f"Accuracy: {accuracy:.2%} ({correct_count}/{total_tested})")

        return correct_count, total_tested

    @pytest.mark.parametrize("agent_type,model_name", [("ollama", "phi4-mini")])
    def test_ollama_with_different_model(self, agent_factory, agent_type, model_name):
        """Test Ollama agent with a different model."""
        agent = agent_factory(agent_type=agent_type, model_name=model_name)
        correct_count, total_tested = self._run_agent_test(agent, num_questions=3)

        assert correct_count > 0, "Should answer at least one question correctly"