|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import logging |
|
from abc import ABC, abstractmethod |
|
from pathlib import Path |
|
from typing import Any, Dict, List, Literal, Optional |
|
|
|
from camel.agents import ChatAgent |
|
|
|
logger = logging.getLogger(__name__) |
|
|
|
|
|
class BaseBenchmark(ABC): |
|
r"""Base class for benchmarks. |
|
|
|
Attributes: |
|
name (str): Name of the benchmark. |
|
data_dir (str): Path to the data directory. |
|
save_to (str): Path to save the results. |
|
processes (int): Number of processes to use for parallel |
|
processing. :(default: :obj:`1`) |
|
""" |
|
|
|
def __init__( |
|
self, name: str, data_dir: str, save_to: str, processes: int = 1 |
|
): |
|
r"""Initialize the benchmark. |
|
|
|
Args: |
|
name (str): Name of the benchmark. |
|
data_dir (str): Path to the data directory. |
|
save_to (str): Path to save the results. |
|
processes (int): Number of processes to use for parallel |
|
processing. :(default: :obj:`1`) |
|
|
|
""" |
|
self.name = name |
|
self.data_dir = Path(data_dir) |
|
self.processes = processes |
|
self.save_to = save_to |
|
if not self.data_dir.exists(): |
|
logger.info( |
|
f"Data directory {data_dir} does not exist. Creating it." |
|
) |
|
self.data_dir.mkdir(parents=True, exist_ok=True) |
|
if not self.data_dir.is_dir(): |
|
raise NotADirectoryError( |
|
f"Data directory {data_dir} is not a directory" |
|
) |
|
self._data: Dict[str, List[Dict[str, Any]]] = dict() |
|
self._results: List[Dict[str, Any]] = [] |
|
|
|
@abstractmethod |
|
def download(self) -> "BaseBenchmark": |
|
r"""Download the benchmark data. |
|
|
|
Returns: |
|
BaseBenchmark: The benchmark instance. |
|
""" |
|
pass |
|
|
|
@abstractmethod |
|
def load(self, force_download: bool = False) -> "BaseBenchmark": |
|
r"""Load the benchmark data. |
|
|
|
Args: |
|
force_download (bool): Whether to force download the data. |
|
|
|
Returns: |
|
BaseBenchmark: The benchmark instance. |
|
""" |
|
pass |
|
|
|
@property |
|
def train(self) -> List[Dict[str, Any]]: |
|
r"""Get the training data. |
|
|
|
Returns: |
|
List[Dict[str, Any]]: The training data. |
|
""" |
|
if not self._data: |
|
logger.info("Data not loaded. Loading data.") |
|
self.load() |
|
return self._data["train"] |
|
|
|
@property |
|
def valid(self) -> List[Dict[str, Any]]: |
|
r"""Get the validation data. |
|
|
|
Returns: |
|
List[Dict[str, Any]]: The validation data. |
|
""" |
|
if not self._data: |
|
logger.info("Data not loaded. Loading data.") |
|
self.load() |
|
return self._data["valid"] |
|
|
|
@property |
|
def test(self) -> List[Dict[str, Any]]: |
|
r"""Get the test data. |
|
|
|
Returns: |
|
List[Dict[str, Any]]: The test data. |
|
""" |
|
if not self._data: |
|
logger.info("Data not loaded. Loading data.") |
|
self.load() |
|
return self._data["test"] |
|
|
|
@abstractmethod |
|
def run( |
|
self, |
|
agent: ChatAgent, |
|
on: Literal["train", "valid", "test"], |
|
randomize: bool = False, |
|
subset: Optional[int] = None, |
|
*args, |
|
**kwargs, |
|
) -> "BaseBenchmark": |
|
r"""Run the benchmark. |
|
|
|
Args: |
|
agent (ChatAgent): The chat agent. |
|
on (str): The data split to run the benchmark on. |
|
randomize (bool): Whether to randomize the data. |
|
subset (int): The subset of the data to run the benchmark on. |
|
|
|
Returns: |
|
BaseBenchmark: The benchmark instance. |
|
""" |
|
pass |
|
|
|
@property |
|
def results(self) -> List[Dict[str, Any]]: |
|
r"""Get the results. |
|
|
|
Returns: |
|
List[Dict[str, Any]]: The results. |
|
""" |
|
return self._results |
|
|