# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
import logging
from itertools import cycle
from random import choice
from typing import (
Any,
Callable,
Dict,
List,
Union,
)

from openai import Stream

from camel.messages import OpenAIMessage
from camel.models.base_model import BaseModelBackend
from camel.types import (
ChatCompletion,
ChatCompletionChunk,
UnifiedModelType,
)
from camel.utils import BaseTokenCounter

logger = logging.getLogger(__name__)


class ModelProcessingError(Exception):
r"""Raised when an error occurs during model processing."""
    pass


class ModelManager:
r"""ModelManager choosing a model from provided list.
Models are picked according to defined strategy.
Args:
models(Union[BaseModelBackend, List[BaseModelBackend]]):
model backend or list of model backends
(e.g., model instances, APIs)
scheduling_strategy (str): name of function that defines how
to select the next model. (default: :str:`round_robin`)
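
    Example:
        A minimal usage sketch; ``my_models`` and ``messages`` are
        placeholders for a list of already constructed
        :obj:`BaseModelBackend` instances and an OpenAI-format message
        list, respectively:

        .. code-block:: python

            manager = ModelManager(
                my_models, scheduling_strategy="round_robin"
            )
            response = manager.run(messages)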
"""
def __init__(
self,
models: Union[BaseModelBackend, List[BaseModelBackend]],
scheduling_strategy: str = "round_robin",
):
if isinstance(models, list):
self.models = models
else:
self.models = [models]
self.models_cycle = cycle(self.models)
self.current_model = self.models[0]
# Set the scheduling strategy; default is round-robin
try:
self.scheduling_strategy = getattr(self, scheduling_strategy)
except AttributeError:
            logger.warning(
                f"Provided strategy: {scheduling_strategy} is not "
                "implemented. Using default 'round_robin'."
            )
            self.scheduling_strategy = self.round_robin

    @property
def model_type(self) -> UnifiedModelType:
r"""Return type of the current model.
Returns:
Union[ModelType, str]: Current model type.
"""
return self.current_model.model_type
@property
def model_config_dict(self) -> Dict[str, Any]:
r"""Return model_config_dict of the current model.
Returns:
Dict[str, Any]: Config dictionary of the current model.
"""
return self.current_model.model_config_dict
@model_config_dict.setter
def model_config_dict(self, model_config_dict: Dict[str, Any]):
r"""Set model_config_dict to the current model.
Args:
model_config_dict (Dict[str, Any]): Config dictionary to be set at
current model.
"""
self.current_model.model_config_dict = model_config_dict
@property
def current_model_index(self) -> int:
r"""Return the index of current model in self.models list.
Returns:
int: index of current model in given list of models.
"""
return self.models.index(self.current_model)
@property
    def token_limit(self) -> int:
        r"""Return the maximum token limit for the current model.

        This method retrieves the maximum token limit either from the
        `model_config_dict` or from the model's default token limit.

        Returns:
            int: The maximum token limit for the given model.
        """
        return self.current_model.token_limit

    @property
def token_counter(self) -> BaseTokenCounter:
r"""Return token_counter of the current model.
Returns:
BaseTokenCounter: The token counter following the model's
tokenization style.
"""
return self.current_model.token_counter
def add_strategy(self, name: str, strategy_fn: Callable):
r"""Add a scheduling strategy method provided by user in case when none
of existent strategies fits.
When custom strategy is provided, it will be set as
"self.scheduling_strategy" attribute.
Args:
name (str): The name of the strategy.
strategy_fn (Callable): The scheduling strategy function.
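
        Example:
            A minimal sketch of a custom strategy; the strategy function
            receives the manager instance as ``self``, and the name
            ``always_last`` is illustrative only:

            .. code-block:: python

                def always_last(self) -> BaseModelBackend:
                    # Hypothetical strategy: always pick the last model.
                    return self.models[-1]

                manager.add_strategy("always_last", always_last)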
"""
if not callable(strategy_fn):
raise ValueError("strategy_fn must be a callable function.")
        # Bind the function to this instance via the descriptor protocol so
        # that it behaves like a regular method and receives `self`.
        setattr(self, name, strategy_fn.__get__(self))
        self.scheduling_strategy = getattr(self, name)
        logger.info(f"Custom strategy '{name}' added.")

    # Strategies
def round_robin(self) -> BaseModelBackend:
r"""Return models one by one in simple round-robin fashion.
Returns:
BaseModelBackend for processing incoming messages.
"""
return next(self.models_cycle)
def always_first(self) -> BaseModelBackend:
r"""Always return the first model from self.models.
Returns:
BaseModelBackend for processing incoming messages.
"""
return self.models[0]
def random_model(self) -> BaseModelBackend:
r"""Return random model from self.models list.
Returns:
BaseModelBackend for processing incoming messages.
"""
return choice(self.models)
def run(
self, messages: List[OpenAIMessage]
) -> Union[ChatCompletion, Stream[ChatCompletionChunk]]:
r"""Process a list of messages by selecting a model based on
the scheduling strategy.
Sends the entire list of messages to the selected model,
and returns a single response.
Args:
messages (List[OpenAIMessage]): Message list with the chat
history in OpenAI API format.
Returns:
Union[ChatCompletion, Stream[ChatCompletionChunk]]:
`ChatCompletion` in the non-stream mode, or
`Stream[ChatCompletionChunk]` in the stream mode.
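
        Example:
            A minimal sketch; ``manager`` stands for an already constructed
            :obj:`ModelManager` instance:

            .. code-block:: python

                messages = [
                    {"role": "user", "content": "Hello!"},
                ]
                response = manager.run(messages)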
"""
self.current_model = self.scheduling_strategy()
# Pass all messages to the selected model and get the response
try:
response = self.current_model.run(messages)
        except Exception as exc:
            logger.error(
                f"Error processing with model: {self.current_model}",
                exc_info=exc,
            )
            if self.scheduling_strategy == self.always_first:
                self.scheduling_strategy = self.round_robin
                logger.warning(
                    "The scheduling strategy has been changed to "
                    "'round_robin'"
                )
                # Skip the model that just failed.
                self.current_model = self.scheduling_strategy()
            raise exc
return response