File size: 4,594 Bytes
62da328
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========

import logging
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Any, Dict, List, Literal, Optional

from camel.agents import ChatAgent

logger = logging.getLogger(__name__)


class BaseBenchmark(ABC):
    r"""Base class for benchmarks.

    Attributes:
        name (str): Name of the benchmark.
        data_dir (str): Path to the data directory.
        save_to (str): Path to save the results.
        processes (int): Number of processes to use for parallel
            processing. :(default: :obj:`1`)
    """

    def __init__(
        self, name: str, data_dir: str, save_to: str, processes: int = 1
    ):
        r"""Initialize the benchmark.

        Args:
            name (str): Name of the benchmark.
            data_dir (str): Path to the data directory.
            save_to (str): Path to save the results.
            processes (int): Number of processes to use for parallel
                processing. :(default: :obj:`1`)

        """
        self.name = name
        self.data_dir = Path(data_dir)
        self.processes = processes
        self.save_to = save_to
        if not self.data_dir.exists():
            logger.info(
                f"Data directory {data_dir} does not exist. Creating it."
            )
            self.data_dir.mkdir(parents=True, exist_ok=True)
        if not self.data_dir.is_dir():
            raise NotADirectoryError(
                f"Data directory {data_dir} is not a directory"
            )
        self._data: Dict[str, List[Dict[str, Any]]] = dict()
        self._results: List[Dict[str, Any]] = []

    @abstractmethod
    def download(self) -> "BaseBenchmark":
        r"""Download the benchmark data.

        Returns:
            BaseBenchmark: The benchmark instance.
        """
        pass

    @abstractmethod
    def load(self, force_download: bool = False) -> "BaseBenchmark":
        r"""Load the benchmark data.

        Args:
            force_download (bool): Whether to force download the data.

        Returns:
            BaseBenchmark: The benchmark instance.
        """
        pass

    @property
    def train(self) -> List[Dict[str, Any]]:
        r"""Get the training data.

        Returns:
            List[Dict[str, Any]]: The training data.
        """
        if not self._data:
            logger.info("Data not loaded. Loading data.")
            self.load()
        return self._data["train"]

    @property
    def valid(self) -> List[Dict[str, Any]]:
        r"""Get the validation data.

        Returns:
            List[Dict[str, Any]]: The validation data.
        """
        if not self._data:
            logger.info("Data not loaded. Loading data.")
            self.load()
        return self._data["valid"]

    @property
    def test(self) -> List[Dict[str, Any]]:
        r"""Get the test data.

        Returns:
            List[Dict[str, Any]]: The test data.
        """
        if not self._data:
            logger.info("Data not loaded. Loading data.")
            self.load()
        return self._data["test"]

    @abstractmethod
    def run(
        self,
        agent: ChatAgent,
        on: Literal["train", "valid", "test"],
        randomize: bool = False,
        subset: Optional[int] = None,
        *args,
        **kwargs,
    ) -> "BaseBenchmark":
        r"""Run the benchmark.

        Args:
            agent (ChatAgent): The chat agent.
            on (str): The data split to run the benchmark on.
            randomize (bool): Whether to randomize the data.
            subset (int): The subset of the data to run the benchmark on.

        Returns:
            BaseBenchmark: The benchmark instance.
        """
        pass

    @property
    def results(self) -> List[Dict[str, Any]]:
        r"""Get the results.

        Returns:
            List[Dict[str, Any]]: The results.
        """
        return self._results