|
""" |
|
Pytest configuration for agent testing. |
|
""" |
|
|
|
import os |
|
import pytest |
|
from typing import Dict, List, Optional |
|
|
|
from gagent.agents import registry, BaseAgent, OllamaAgent, GeminiAgent, OpenAIAgent |
|
|
|
|
|
@pytest.fixture
def agent_factory():
    """
    Factory fixture to create agent instances with flexible configuration.

    Explicit arguments to the returned function take precedence. Otherwise the
    values come from the ``<AGENT_TYPE>_MODEL``, ``<AGENT_TYPE>_API_KEY`` and
    ``<AGENT_TYPE>_BASE_URL`` environment variables. For ``"ollama"`` there are
    local defaults (model ``qwen3``, base URL ``http://localhost:11434``).

    Returns:
        Function that creates and returns an agent instance
    """
    # Fallbacks used when neither an argument nor an env var supplies a value.
    # Keys are lower-cased agent types, so "Ollama" and "ollama" both match.
    default_models = {"ollama": "qwen3"}
    default_base_urls = {"ollama": "http://localhost:11434"}
    # Extra API-key env vars, consulted after <AGENT_TYPE>_API_KEY. The gemini
    # fixture in this suite gates on GOOGLE_API_KEY, so honour it here too.
    extra_key_vars = {"gemini": ("GOOGLE_API_KEY",)}

    def _create_agent(
        agent_type: str,
        model_name: Optional[str] = None,
        api_key: Optional[str] = None,
        base_url: Optional[str] = None,
        **kwargs,
    ) -> BaseAgent:
        """
        Create an agent with the specified configuration.

        Args:
            agent_type: The type of agent to create (passed to the registry unchanged)
            model_name: The model name to use; falls back to ``<AGENT_TYPE>_MODEL``
            api_key: The API key to use; falls back to ``<AGENT_TYPE>_API_KEY``
                (and, for gemini, ``GOOGLE_API_KEY``)
            base_url: The base URL to use; falls back to ``<AGENT_TYPE>_BASE_URL``
            **kwargs: Additional parameters for the agent

        Returns:
            An initialized agent instance
        """
        kind = agent_type.lower()
        prefix = agent_type.upper()

        env_model = os.environ.get(f"{prefix}_MODEL", default_models.get(kind))
        env_base_url = os.environ.get(f"{prefix}_BASE_URL", default_base_urls.get(kind))

        # First non-empty key among the primary and any extra key variables.
        key_vars = (f"{prefix}_API_KEY", *extra_key_vars.get(kind, ()))
        env_api_key = next((os.environ[var] for var in key_vars if os.environ.get(var)), None)

        return registry.get_agent(
            agent_type=agent_type,
            model_name=model_name or env_model,
            api_key=api_key or env_api_key,
            base_url=base_url or env_base_url,
            **kwargs,
        )

    return _create_agent
|
|
|
|
|
@pytest.fixture
def ollama_agent(agent_factory) -> OllamaAgent:
    """Provide an Ollama agent built by ``agent_factory`` with no explicit overrides."""
    agent = agent_factory(agent_type="ollama")
    return agent
|
|
|
|
|
@pytest.fixture
def gemini_agent(agent_factory) -> GeminiAgent:
    """
    Provide a Gemini agent, skipping the requesting test if GOOGLE_API_KEY is unset.

    The key that gates the skip is forwarded explicitly, so the agent uses the
    same credential the check just validated. This holds regardless of which
    environment variables the factory itself consults.
    """
    api_key = os.environ.get("GOOGLE_API_KEY", None)

    if not api_key:
        pytest.skip("GOOGLE_API_KEY environment variable not set")

    return agent_factory("gemini", api_key=api_key)
|
|
|
|
|
@pytest.fixture
def openai_agent(agent_factory) -> OpenAIAgent:
    """Provide an OpenAI agent; the requesting test is skipped if OPENAI_API_KEY is unset or empty."""
    openai_key = os.environ.get("OPENAI_API_KEY")
    if openai_key:
        # Same value the factory would read from OPENAI_API_KEY; passed explicitly for clarity.
        return agent_factory("openai", api_key=openai_key)
    pytest.skip("OPENAI_API_KEY environment variable not set")
|
|
|
|
|
@pytest.fixture
def gaia_questions() -> List[Dict]:
    """
    Load GAIA questions for testing.

    Reads ``exp/questions.json``, resolved against the current working directory
    (where pytest is launched), and returns the parsed JSON content.
    """
    import json

    # JSON text is UTF-8 by specification. Pin the encoding so the platform's
    # locale default (e.g. cp1252 on Windows) cannot mis-decode question text.
    with open("exp/questions.json", "r", encoding="utf-8") as f:
        return json.load(f)
|
|
|
|
|
@pytest.fixture
def gaia_validation_data() -> Dict:
    """
    Load GAIA validation data.

    Reads ``metadata.jsonl`` (one JSON object per line) from the current working
    directory. Each record's ``"task_id"`` is mapped to its ``"Final answer"``.
    Blank lines are ignored. If a task_id repeats, the last occurrence wins.

    Returns:
        Dict mapping task_id to the expected final answer.
    """
    import json

    # Explicit UTF-8 avoids locale-dependent decoding. Skipping whitespace-only
    # lines keeps a stray blank line (e.g. a doubled trailing newline) from
    # raising JSONDecodeError.
    with open("metadata.jsonl", "r", encoding="utf-8") as f:
        records = [json.loads(line) for line in f if line.strip()]

    return {record["task_id"]: record["Final answer"] for record in records}
|
|