import json
import os
import pytest
from pathlib import Path
import functools
from typing import Callable, Type, Any, Dict, Optional

from gagent.agents import BaseAgent, GeminiAgent
from tests.agents.fixtures import (
    agent_factory,
    ollama_agent,
    gemini_agent,
    openai_agent,
)


class TestAgents:
    """Test suite for agents with GAIA data."""

    @staticmethod
    def load_questions():
        """Load questions from the questions.json file.

        Returns:
            The parsed JSON content of ``exp/questions.json`` (a list of
            question records, each expected to carry ``task_id`` and
            ``question`` keys).
        """
        # Explicit encoding so parsing does not depend on the platform locale.
        with open("exp/questions.json", "r", encoding="utf-8") as f:
            return json.load(f)

    @staticmethod
    def load_validation_data() -> Dict[str, str]:
        """Load ground-truth answers from the GAIA dataset metadata.

        Returns:
            Mapping of ``task_id`` -> expected final answer, built from the
            JSON-lines file ``metadata.jsonl``.
        """
        validation_data: Dict[str, str] = {}
        with open("metadata.jsonl", "r", encoding="utf-8") as f:
            for line in f:
                data = json.loads(line)
                validation_data[data["task_id"]] = data["Final answer"]
        return validation_data

    def _run_agent_test(self, agent: BaseAgent, num_questions: int = 2):
        """Common test implementation for all agent types.

        Args:
            agent: The agent to test.
            num_questions: Number of questions to test (default: 2).

        Returns:
            Tuple of (correct_count, total_tested).
        """
        questions = self.load_questions()
        validation_data = self.load_validation_data()

        # Limit number of questions so the test stays fast.
        questions = questions[:num_questions]

        # Keep track of correct answers.
        correct_count = 0
        total_tested = 0
        total_questions = len(questions)

        for i, question_data in enumerate(questions):
            task_id = question_data["task_id"]
            # Skip questions with no known ground-truth answer.
            if task_id not in validation_data:
                continue

            question = question_data["question"]
            expected_answer = validation_data[task_id]

            print(f"Testing question {i + 1}: {question[:50]}...")

            # Call the agent with the question.
            response = agent.run(
                question, question_number=i + 1, total_questions=total_questions
            )

            # Extract the final answer from the response, assuming the agent
            # follows the "FINAL ANSWER: [answer]" format; fall back to the
            # whole response when the marker is absent.
            if "FINAL ANSWER:" in response:
                answer = response.split("FINAL ANSWER:")[1].strip()
            else:
                answer = response.strip()

            # Check if the answer is correct (exact string match).
            is_correct = answer == expected_answer
            if is_correct:
                correct_count += 1
            total_tested += 1

            print(f"Expected: {expected_answer}")
            print(f"Got: {answer}")
            print(f"Result: {'✓' if is_correct else '✗'}")
            print("-" * 80)

        # Compute accuracy (guard against zero tested questions).
        accuracy = correct_count / total_tested if total_tested > 0 else 0
        print(f"Accuracy: {accuracy:.2%} ({correct_count}/{total_tested})")

        return correct_count, total_tested

    # NOTE(review): the commented-out tests below originally passed an
    # undefined name `agent` to _run_agent_test; fixed to use their fixture
    # arguments so they work if un-commented.

    # def test_ollama_agent_with_gaia_data(self, ollama_agent: BaseAgent):
    #     """Test the Ollama agent with GAIA dataset questions and validate against ground truth."""
    #     correct_count, total_tested = self._run_agent_test(ollama_agent)
    #     # At least one correct answer required to pass the test
    #     assert correct_count > 0, "Agent should get at least one answer correct"

    # def test_gemini_agent_with_gaia_data(self, gemini_agent: GeminiAgent):
    #     """Test the Gemini agent with the same GAIA test approach."""
    #     correct_count, total_tested = self._run_agent_test(gemini_agent, num_questions=2)
    #     # At least one correct answer required to pass the test
    #     assert correct_count > 0, "Agent should get at least one answer correct"

    @pytest.mark.parametrize("agent_type,model_name", [("ollama", "phi4-mini")])
    def test_ollama_with_different_model(self, agent_factory, agent_type, model_name):
        """Test Ollama agent with a different model (smoke test, not accuracy)."""
        agent = agent_factory(agent_type=agent_type, model_name=model_name)
        correct_count, total_tested = self._run_agent_test(agent, num_questions=3)
        # Just verify it runs, not accuracy. The original asserted
        # `correct_count > 0`, which contradicted both this comment and the
        # assertion message — a flaky dependency on model accuracy.
        assert total_tested > 0, "Should test at least one question"

    # def test_ollama_with_different_model(self, ollama_agent: BaseAgent):
    #     """Test Ollama agent with a different model."""
    #     correct_count, total_tested = self._run_agent_test(ollama_agent, num_questions=3)
    #     # Just verify it runs, not accuracy
    #     assert total_tested > 0, "Should test at least one question"

    # Can be uncommented when OpenAI API key is available
    # def test_openai_agent_with_gaia_data(self, openai_agent: BaseAgent):
    #     """Test the OpenAI agent with the same GAIA test approach."""
    #     correct_count, total_tested = self._run_agent_test(openai_agent)
    #     assert correct_count > 0, "Agent should get at least one answer correct"