File size: 5,587 Bytes
ed4d993
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
"""EverlyAI Endpoints chat wrapper. Relies heavily on ChatOpenAI."""
from __future__ import annotations

import logging
import sys
from typing import TYPE_CHECKING, Dict, Optional, Set

from langchain_core.messages import BaseMessage
from langchain_core.pydantic_v1 import Field, root_validator
from langchain_core.utils import get_from_dict_or_env

from langchain_community.adapters.openai import convert_message_to_dict
from langchain_community.chat_models.openai import (
    ChatOpenAI,
    _import_tiktoken,
)

if TYPE_CHECKING:
    import tiktoken

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)


# Default base URL for EverlyAI API requests.
DEFAULT_API_BASE = "https://everlyai.xyz/hosted"
# Model name used when the caller does not specify one.
DEFAULT_MODEL = "meta-llama/Llama-2-7b-chat-hf"


class ChatEverlyAI(ChatOpenAI):
    """`EverlyAI` Chat large language models.

    To use, you should have the ``openai`` python package installed, and the
    environment variable ``EVERLYAI_API_KEY`` set with your API key.
    Alternatively, you can use the everlyai_api_key keyword argument.

    Any parameters that are valid to be passed to the `openai.create` call can be passed
    in, even if not explicitly saved on this class.

    Example:
        .. code-block:: python

            from langchain_community.chat_models import ChatEverlyAI
            chat = ChatEverlyAI(model_name="meta-llama/Llama-2-7b-chat-hf")
    """

    @property
    def _llm_type(self) -> str:
        """Return type of chat model."""
        return "everlyai-chat"

    @property
    def lc_secrets(self) -> Dict[str, str]:
        """Map the secret constructor argument to its environment variable."""
        return {"everlyai_api_key": "EVERLYAI_API_KEY"}

    @classmethod
    def is_lc_serializable(cls) -> bool:
        """Report that this class is not serializable."""
        return False

    everlyai_api_key: Optional[str] = None
    """EverlyAI Endpoints API keys."""
    model_name: str = Field(default=DEFAULT_MODEL, alias="model")
    """Model name to use."""
    everlyai_api_base: str = DEFAULT_API_BASE
    """Base URL path for API requests."""
    available_models: Optional[Set[str]] = None
    """Available models from EverlyAI API."""

    @staticmethod
    def get_available_models() -> Set[str]:
        """Get available models from EverlyAI API.

        Returns:
            The set of model names accepted by this wrapper.
        """
        # EverlyAI doesn't yet support dynamically querying for available
        # models, so the list is hard-coded.
        return {
            "meta-llama/Llama-2-7b-chat-hf",
            "meta-llama/Llama-2-13b-chat-hf-quantized",
        }

    @root_validator(pre=True)
    def validate_environment_override(cls, values: dict) -> dict:
        """Validate that api key and python package exists in environment.

        Runs on the raw constructor input and fills in the ``openai_*``
        settings, the ``client``, the model name and ``available_models``.

        Args:
            values: Raw keyword arguments passed to the constructor.

        Returns:
            The same ``values`` dict, updated in place.

        Raises:
            ImportError: If the ``openai`` package cannot be imported.
            ValueError: If ``openai`` has no ``ChatCompletion`` attribute, or
                if a requested model is not in ``get_available_models()``.
        """
        values["openai_api_key"] = get_from_dict_or_env(
            values,
            "everlyai_api_key",
            "EVERLYAI_API_KEY",
        )
        # Honour a caller-supplied ``everlyai_api_base``; fall back to the
        # default endpoint when it is absent or empty.
        values["openai_api_base"] = values.get("everlyai_api_base") or DEFAULT_API_BASE

        try:
            import openai

        except ImportError as e:
            raise ImportError(
                "Could not import openai python package. "
                "Please install it with `pip install openai`.",
            ) from e
        try:
            values["client"] = openai.ChatCompletion
        except AttributeError as exc:
            raise ValueError(
                "`openai` has no `ChatCompletion` attribute, this is likely "
                "due to an old version of the openai package. Try upgrading it "
                "with `pip install --upgrade openai`.",
            ) from exc

        # ``model_name`` is declared with the alias ``model``, and this
        # validator sees the raw input, so the caller may have used either
        # spelling. Validate every spelling that was supplied so that the
        # alias cannot bypass the availability check.
        requested = [values[key] for key in ("model", "model_name") if key in values]
        if not requested:
            values["model_name"] = DEFAULT_MODEL
            requested = [DEFAULT_MODEL]

        available_models = cls.get_available_models()

        for model_name in requested:
            if model_name not in available_models:
                raise ValueError(
                    f"Model name {model_name} not found in available models: "
                    f"{available_models}.",
                )

        values["available_models"] = available_models

        return values

    def _get_encoding_model(self) -> tuple[str, tiktoken.Encoding]:
        """Pick the tiktoken encoding used for token counting.

        Returns:
            A ``(model, encoding)`` pair. ``model`` is ``tiktoken_model_name``
            when set, otherwise ``model_name``; it becomes ``"cl100k_base"``
            when the lookup fails and the fallback encoding is used.
        """
        tiktoken_ = _import_tiktoken()
        if self.tiktoken_model_name is not None:
            # An explicit tokenizer override is used for the lookup itself.
            model = self.tiktoken_model_name
            lookup_name = model
        else:
            model = self.model_name
            # The EverlyAI model name is not used for the lookup; a fixed
            # OpenAI model name stands in for it (presumably because tiktoken
            # has no entry for the hosted models -- TODO confirm).
            lookup_name = "gpt-3.5-turbo-0301"
        try:
            encoding = tiktoken_.encoding_for_model(lookup_name)
        except KeyError:
            logger.warning("Warning: model not found. Using cl100k_base encoding.")
            model = "cl100k_base"
            encoding = tiktoken_.get_encoding(model)
        return model, encoding

    def get_num_tokens_from_messages(self, messages: list[BaseMessage]) -> int:
        """Calculate num tokens with tiktoken package.

        Official documentation: https://github.com/openai/openai-cookbook/blob/
        main/examples/How_to_format_inputs_to_ChatGPT_models.ipynb

        Args:
            messages: The chat messages to count tokens for.

        Returns:
            The token count: 3 per message, plus the encoded length of every
            value in each message dict, plus 1 per ``name`` key, plus 3 for
            the reply primer.
        """
        # On Python 3.7 and older, defer to the parent implementation.
        if sys.version_info[:2] <= (3, 7):
            return super().get_num_tokens_from_messages(messages)
        _, encoding = self._get_encoding_model()
        tokens_per_message = 3
        tokens_per_name = 1
        num_tokens = 0
        for message in (convert_message_to_dict(m) for m in messages):
            num_tokens += tokens_per_message
            for key, value in message.items():
                # Cast str(value) in case the message value is not a string
                # This occurs with function messages
                num_tokens += len(encoding.encode(str(value)))
                if key == "name":
                    num_tokens += tokens_per_name
        # every reply is primed with <im_start>assistant
        num_tokens += 3
        return num_tokens