Spaces:
Runtime error
Runtime error
"""OpenAI chat wrapper.""" | |
from __future__ import annotations | |
import logging | |
import os | |
import sys | |
from typing import ( | |
TYPE_CHECKING, | |
Any, | |
AsyncIterator, | |
Callable, | |
Dict, | |
Iterator, | |
List, | |
Mapping, | |
Optional, | |
Sequence, | |
Tuple, | |
Type, | |
Union, | |
) | |
from langchain_core._api.deprecation import deprecated | |
from langchain_core.callbacks import ( | |
AsyncCallbackManagerForLLMRun, | |
CallbackManagerForLLMRun, | |
) | |
from langchain_core.language_models import LanguageModelInput | |
from langchain_core.language_models.chat_models import ( | |
BaseChatModel, | |
agenerate_from_stream, | |
generate_from_stream, | |
) | |
from langchain_core.language_models.llms import create_base_retry_decorator | |
from langchain_core.messages import ( | |
AIMessageChunk, | |
BaseMessage, | |
BaseMessageChunk, | |
ChatMessageChunk, | |
FunctionMessageChunk, | |
HumanMessageChunk, | |
SystemMessageChunk, | |
ToolMessageChunk, | |
) | |
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult | |
from langchain_core.pydantic_v1 import BaseModel, Field, root_validator | |
from langchain_core.runnables import Runnable | |
from langchain_core.utils import ( | |
get_from_dict_or_env, | |
get_pydantic_field_names, | |
) | |
from langchain_community.adapters.openai import ( | |
convert_dict_to_message, | |
convert_message_to_dict, | |
) | |
from langchain_community.utils.openai import is_openai_v1 | |
if TYPE_CHECKING: | |
import tiktoken | |
# Module-level logger, named after this module.
logger: logging.Logger = logging.getLogger(name=__name__)
def _import_tiktoken() -> Any: | |
try: | |
import tiktoken | |
except ImportError: | |
raise ImportError( | |
"Could not import tiktoken python package. " | |
"This is needed in order to calculate get_token_ids. " | |
"Please install it with `pip install tiktoken`." | |
) | |
return tiktoken | |
def _create_retry_decorator(
    llm: ChatOpenAI,
    run_manager: Optional[
        Union[AsyncCallbackManagerForLLMRun, CallbackManagerForLLMRun]
    ] = None,
) -> Callable[[Any], Any]:
    """Build a retry decorator for calls made through the ``openai.error`` SDK API.

    In this module it is only used on the non-v1 code paths. The retryable
    error types come from ``openai.error``, and the retry budget comes from
    ``llm.max_retries``.

    Args:
        llm: The chat model whose ``max_retries`` bounds the retry count.
        run_manager: Optional callback manager forwarded to the decorator
            factory.

    Returns:
        A decorator produced by ``create_base_retry_decorator``.
    """
    import openai

    legacy = openai.error
    retryable_errors = [
        legacy.Timeout,
        legacy.APIError,
        legacy.APIConnectionError,
        legacy.RateLimitError,
        legacy.ServiceUnavailableError,
    ]
    return create_base_retry_decorator(
        error_types=retryable_errors,
        max_retries=llm.max_retries,
        run_manager=run_manager,
    )
async def acompletion_with_retry(
    llm: ChatOpenAI,
    run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
    **kwargs: Any,
) -> Any:
    """Use tenacity to retry the async completion call.

    With the openai v1 SDK, ``llm.async_client.create`` is awaited directly.
    That client is constructed with ``max_retries`` in ``validate_environment``.
    Otherwise, ``llm.client.acreate`` is wrapped in the decorator returned by
    ``_create_retry_decorator``.

    Args:
        llm: The chat model providing the client and retry settings.
        run_manager: Optional async callback manager passed to the retry
            decorator.
        **kwargs: Request parameters forwarded to the completion call.

    Returns:
        The raw completion response (or stream) returned by the client.
    """
    if is_openai_v1():
        return await llm.async_client.create(**kwargs)

    retry_decorator = _create_retry_decorator(llm, run_manager=run_manager)

    # The decorator must actually be applied; otherwise the legacy path never
    # retries even though a retry policy was built above.
    @retry_decorator
    async def _completion_with_retry(**kwargs: Any) -> Any:
        # Use OpenAI's async api https://github.com/openai/openai-python#async-api
        return await llm.client.acreate(**kwargs)

    return await _completion_with_retry(**kwargs)
def _convert_delta_to_message_chunk(
    _dict: Mapping[str, Any], default_class: Type[BaseMessageChunk]
) -> BaseMessageChunk:
    """Convert one streamed ``delta`` mapping into a message chunk.

    The chunk type is picked from the delta's ``role`` when it matches, or
    else from ``default_class``. Any ``function_call`` / ``tool_calls`` payload
    is collected into ``additional_kwargs``, which only ``AIMessageChunk``
    receives.

    Args:
        _dict: The ``delta`` portion of a streamed choice.
        default_class: Chunk class to fall back to when no role matches.

    Returns:
        A message chunk instance carrying the delta's content.
    """
    role = _dict.get("role")
    text = _dict.get("content") or ""

    extras: Dict = {}
    raw_function_call = _dict.get("function_call")
    if raw_function_call:
        call = dict(raw_function_call)
        # Replace an explicit ``None`` name with an empty string.
        if call.get("name", "") is None:
            call["name"] = ""
        extras["function_call"] = call
    if _dict.get("tool_calls"):
        extras["tool_calls"] = _dict["tool_calls"]

    if role == "user" or default_class == HumanMessageChunk:
        return HumanMessageChunk(content=text)
    if role == "assistant" or default_class == AIMessageChunk:
        return AIMessageChunk(content=text, additional_kwargs=extras)
    if role == "system" or default_class == SystemMessageChunk:
        return SystemMessageChunk(content=text)
    if role == "function" or default_class == FunctionMessageChunk:
        return FunctionMessageChunk(content=text, name=_dict["name"])
    if role == "tool" or default_class == ToolMessageChunk:
        return ToolMessageChunk(content=text, tool_call_id=_dict["tool_call_id"])
    if role or default_class == ChatMessageChunk:
        return ChatMessageChunk(content=text, role=role)  # type: ignore[arg-type]
    return default_class(content=text)  # type: ignore[call-arg]
class ChatOpenAI(BaseChatModel):
    """`OpenAI` Chat large language models API.

    To use, you should have the ``openai`` python package installed, and the
    environment variable ``OPENAI_API_KEY`` set with your API key.

    Any parameters that are valid to be passed to the openai.create call can be passed
    in, even if not explicitly saved on this class.

    Example:
        .. code-block:: python

            from langchain_community.chat_models import ChatOpenAI
            openai = ChatOpenAI(model="gpt-3.5-turbo")
    """

    @property
    def lc_secrets(self) -> Dict[str, str]:
        """Map the secret ``openai_api_key`` field to its environment variable."""
        return {"openai_api_key": "OPENAI_API_KEY"}

    @classmethod
    def get_lc_namespace(cls) -> List[str]:
        """Get the namespace of the langchain object."""
        return ["langchain", "chat_models", "openai"]

    @property
    def lc_attributes(self) -> Dict[str, Any]:
        """Return the non-secret connection settings that are set (truthy)."""
        attributes: Dict[str, Any] = {}

        if self.openai_organization:
            attributes["openai_organization"] = self.openai_organization

        if self.openai_api_base:
            attributes["openai_api_base"] = self.openai_api_base

        if self.openai_proxy:
            attributes["openai_proxy"] = self.openai_proxy

        return attributes

    @classmethod
    def is_lc_serializable(cls) -> bool:
        """Return whether this model can be serialized by Langchain."""
        return True

    client: Any = Field(default=None, exclude=True)  #: :meta private:
    async_client: Any = Field(default=None, exclude=True)  #: :meta private:
    model_name: str = Field(default="gpt-3.5-turbo", alias="model")
    """Model name to use."""
    temperature: float = 0.7
    """What sampling temperature to use."""
    model_kwargs: Dict[str, Any] = Field(default_factory=dict)
    """Holds any model parameters valid for `create` call not explicitly specified."""
    # When updating this to use a SecretStr
    # Check for classes that derive from this class (as some of them
    # may assume openai_api_key is a str)
    openai_api_key: Optional[str] = Field(default=None, alias="api_key")
    """Automatically inferred from env var `OPENAI_API_KEY` if not provided."""
    openai_api_base: Optional[str] = Field(default=None, alias="base_url")
    """Base URL path for API requests, leave blank if not using a proxy or service
        emulator. Falls back to env var `OPENAI_API_BASE` if not provided."""
    openai_organization: Optional[str] = Field(default=None, alias="organization")
    """Automatically inferred from env var `OPENAI_ORG_ID` (or, for backwards
        compatibility, `OPENAI_ORGANIZATION`) if not provided."""
    # to support explicit proxy for OpenAI; falls back to env var OPENAI_PROXY
    openai_proxy: Optional[str] = None
    request_timeout: Union[float, Tuple[float, float], Any, None] = Field(
        default=None, alias="timeout"
    )
    """Timeout for requests to OpenAI completion API. Can be float, httpx.Timeout or
        None."""
    max_retries: int = Field(default=2)
    """Maximum number of retries to make when generating."""
    streaming: bool = False
    """Whether to stream the results or not."""
    n: int = 1
    """Number of chat completions to generate for each prompt."""
    max_tokens: Optional[int] = None
    """Maximum number of tokens to generate."""
    tiktoken_model_name: Optional[str] = None
    """The model name to pass to tiktoken when using this class.
    Tiktoken is used to count the number of tokens in documents to constrain
    them to be under a certain limit. By default, when set to None, this will
    be the same as ``model_name``. However, there are some cases
    where you may want to use this class with a model name not
    supported by tiktoken. This can include when using Azure deployments or
    when using one of the many model providers that expose an OpenAI-like
    API but with different models. In those cases, in order to avoid erroring
    when tiktoken is called, you can specify a model name to use here."""
    default_headers: Union[Mapping[str, str], None] = None
    default_query: Union[Mapping[str, object], None] = None
    # Configure a custom httpx client. See the
    # [httpx documentation](https://www.python-httpx.org/api/#client) for more details.
    http_client: Union[Any, None] = None
    """Optional httpx.Client."""

    class Config:
        """Configuration for this pydantic object."""

        allow_population_by_field_name = True

    @root_validator(pre=True)
    def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        """Build extra kwargs from additional params that were passed in.

        Any constructor argument that is not a declared field (or alias) is
        moved into ``model_kwargs`` with a warning.

        Raises:
            ValueError: if a parameter is supplied both directly and inside
                ``model_kwargs``, or if ``model_kwargs`` contains a declared
                field name.
        """
        all_required_field_names = get_pydantic_field_names(cls)
        # Work on a copy so the caller's ``model_kwargs`` dict is not mutated.
        extra = dict(values.get("model_kwargs", {}))
        for field_name in list(values):
            if field_name in extra:
                raise ValueError(f"Found {field_name} supplied twice.")
            if field_name not in all_required_field_names:
                logger.warning(
                    f"""WARNING! {field_name} is not default parameter.
                    {field_name} was transferred to model_kwargs.
                    Please confirm that {field_name} is what you intended."""
                )
                extra[field_name] = values.pop(field_name)

        invalid_model_kwargs = all_required_field_names.intersection(extra.keys())
        if invalid_model_kwargs:
            raise ValueError(
                f"Parameters {invalid_model_kwargs} should be specified explicitly. "
                f"Instead they were passed in as part of `model_kwargs` parameter."
            )

        values["model_kwargs"] = extra
        return values

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate that api key and python package exists in environment.

        Also resolves organization, base URL and proxy from environment
        variables, and builds the sync/async clients if none were supplied.

        Raises:
            ValueError: if ``n < 1`` or if ``n > 1`` while streaming.
            ImportError: if the ``openai`` package is not installed.
        """
        if values["n"] < 1:
            raise ValueError("n must be at least 1.")
        if values["n"] > 1 and values["streaming"]:
            raise ValueError("n must be 1 when streaming.")

        values["openai_api_key"] = get_from_dict_or_env(
            values, "openai_api_key", "OPENAI_API_KEY"
        )
        # Check OPENAI_ORGANIZATION for backwards compatibility.
        values["openai_organization"] = (
            values["openai_organization"]
            or os.getenv("OPENAI_ORG_ID")
            or os.getenv("OPENAI_ORGANIZATION")
        )
        values["openai_api_base"] = values["openai_api_base"] or os.getenv(
            "OPENAI_API_BASE"
        )
        values["openai_proxy"] = get_from_dict_or_env(
            values,
            "openai_proxy",
            "OPENAI_PROXY",
            default="",
        )
        try:
            import openai

        except ImportError:
            raise ImportError(
                "Could not import openai python package. "
                "Please install it with `pip install openai`."
            )

        if is_openai_v1():
            client_params = {
                "api_key": values["openai_api_key"],
                "organization": values["openai_organization"],
                "base_url": values["openai_api_base"],
                "timeout": values["request_timeout"],
                "max_retries": values["max_retries"],
                "default_headers": values["default_headers"],
                "default_query": values["default_query"],
                "http_client": values["http_client"],
            }

            if not values.get("client"):
                values["client"] = openai.OpenAI(**client_params).chat.completions
            if not values.get("async_client"):
                values["async_client"] = openai.AsyncOpenAI(
                    **client_params
                ).chat.completions
        elif not values.get("client"):
            values["client"] = openai.ChatCompletion
        return values

    @property
    def _default_params(self) -> Dict[str, Any]:
        """Get the default parameters for calling OpenAI API."""
        params = {
            "model": self.model_name,
            "stream": self.streaming,
            "n": self.n,
            "temperature": self.temperature,
            **self.model_kwargs,
        }
        if self.max_tokens is not None:
            params["max_tokens"] = self.max_tokens
        # Only the non-v1 SDK takes the timeout per request; v1 clients get it
        # at construction time in ``validate_environment``.
        if self.request_timeout is not None and not is_openai_v1():
            params["request_timeout"] = self.request_timeout
        return params

    def completion_with_retry(
        self, run_manager: Optional[CallbackManagerForLLMRun] = None, **kwargs: Any
    ) -> Any:
        """Use tenacity to retry the completion call.

        With the openai v1 SDK the client (built with ``max_retries`` in
        ``validate_environment``) is called directly; otherwise the call is
        wrapped in the decorator from ``_create_retry_decorator``.
        """
        if is_openai_v1():
            return self.client.create(**kwargs)

        retry_decorator = _create_retry_decorator(self, run_manager=run_manager)

        @retry_decorator
        def _completion_with_retry(**kwargs: Any) -> Any:
            return self.client.create(**kwargs)

        return _completion_with_retry(**kwargs)

    def _combine_llm_outputs(self, llm_outputs: List[Optional[dict]]) -> dict:
        """Sum token usage across outputs and keep the first system fingerprint."""
        overall_token_usage: dict = {}
        system_fingerprint = None
        for output in llm_outputs:
            if output is None:
                # Happens in streaming
                continue
            token_usage = output["token_usage"]
            if token_usage is not None:
                for k, v in token_usage.items():
                    if k in overall_token_usage:
                        overall_token_usage[k] += v
                    else:
                        overall_token_usage[k] = v
            if system_fingerprint is None:
                system_fingerprint = output.get("system_fingerprint")
        combined = {"token_usage": overall_token_usage, "model_name": self.model_name}
        if system_fingerprint:
            combined["system_fingerprint"] = system_fingerprint
        return combined

    def _stream(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[ChatGenerationChunk]:
        """Yield generation chunks from a streamed completion (``stream=True``).

        Only the first choice of each raw chunk is used. The class of the
        previous message chunk becomes the default for the next delta, so
        deltas without a ``role`` keep the type established earlier.
        """
        message_dicts, params = self._create_message_dicts(messages, stop)
        params = {**params, **kwargs, "stream": True}

        default_chunk_class = AIMessageChunk
        for raw_chunk in self.completion_with_retry(
            messages=message_dicts, run_manager=run_manager, **params
        ):
            if not isinstance(raw_chunk, dict):
                raw_chunk = raw_chunk.dict()
            if not raw_chunk["choices"]:
                continue
            choice = raw_chunk["choices"][0]
            message_chunk = _convert_delta_to_message_chunk(
                choice["delta"], default_chunk_class
            )
            finish_reason = choice.get("finish_reason")
            generation_info = (
                dict(finish_reason=finish_reason) if finish_reason is not None else None
            )
            default_chunk_class = message_chunk.__class__
            cg_chunk = ChatGenerationChunk(
                message=message_chunk, generation_info=generation_info
            )
            if run_manager:
                run_manager.on_llm_new_token(cg_chunk.text, chunk=cg_chunk)
            yield cg_chunk

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        stream: Optional[bool] = None,
        **kwargs: Any,
    ) -> ChatResult:
        """Generate a chat result; an explicit ``stream`` overrides ``self.streaming``."""
        should_stream = stream if stream is not None else self.streaming
        if should_stream:
            stream_iter = self._stream(
                messages, stop=stop, run_manager=run_manager, **kwargs
            )
            return generate_from_stream(stream_iter)
        message_dicts, params = self._create_message_dicts(messages, stop)
        params = {
            **params,
            **({"stream": stream} if stream is not None else {}),
            **kwargs,
        }
        response = self.completion_with_retry(
            messages=message_dicts, run_manager=run_manager, **params
        )
        return self._create_chat_result(response)

    def _create_message_dicts(
        self, messages: List[BaseMessage], stop: Optional[List[str]]
    ) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]:
        """Convert messages to API dicts and assemble request params.

        Raises:
            ValueError: if ``stop`` is given here and is also in the default params.
        """
        params = self._client_params
        if stop is not None:
            if "stop" in params:
                raise ValueError("`stop` found in both the input and default params.")
            params["stop"] = stop
        message_dicts = [convert_message_to_dict(m) for m in messages]
        return message_dicts, params

    def _create_chat_result(self, response: Union[dict, BaseModel]) -> ChatResult:
        """Convert a raw completion response (dict or SDK model) into a ChatResult."""
        generations = []
        if not isinstance(response, dict):
            response = response.dict()
        for res in response["choices"]:
            message = convert_dict_to_message(res["message"])
            generation_info = dict(finish_reason=res.get("finish_reason"))
            if "logprobs" in res:
                generation_info["logprobs"] = res["logprobs"]
            gen = ChatGeneration(
                message=message,
                generation_info=generation_info,
            )
            generations.append(gen)
        token_usage = response.get("usage", {})
        llm_output = {
            "token_usage": token_usage,
            "model_name": self.model_name,
            "system_fingerprint": response.get("system_fingerprint", ""),
        }
        return ChatResult(generations=generations, llm_output=llm_output)

    async def _astream(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> AsyncIterator[ChatGenerationChunk]:
        """Async counterpart of ``_stream``."""
        message_dicts, params = self._create_message_dicts(messages, stop)
        params = {**params, **kwargs, "stream": True}

        default_chunk_class = AIMessageChunk
        async for raw_chunk in await acompletion_with_retry(
            self, messages=message_dicts, run_manager=run_manager, **params
        ):
            if not isinstance(raw_chunk, dict):
                raw_chunk = raw_chunk.dict()
            if not raw_chunk["choices"]:
                continue
            choice = raw_chunk["choices"][0]
            message_chunk = _convert_delta_to_message_chunk(
                choice["delta"], default_chunk_class
            )
            finish_reason = choice.get("finish_reason")
            generation_info = (
                dict(finish_reason=finish_reason) if finish_reason is not None else None
            )
            default_chunk_class = message_chunk.__class__
            cg_chunk = ChatGenerationChunk(
                message=message_chunk, generation_info=generation_info
            )
            if run_manager:
                await run_manager.on_llm_new_token(token=cg_chunk.text, chunk=cg_chunk)
            yield cg_chunk

    async def _agenerate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        stream: Optional[bool] = None,
        **kwargs: Any,
    ) -> ChatResult:
        """Async counterpart of ``_generate``."""
        should_stream = stream if stream is not None else self.streaming
        if should_stream:
            stream_iter = self._astream(
                messages, stop=stop, run_manager=run_manager, **kwargs
            )
            return await agenerate_from_stream(stream_iter)

        message_dicts, params = self._create_message_dicts(messages, stop)
        params = {
            **params,
            **({"stream": stream} if stream is not None else {}),
            **kwargs,
        }
        response = await acompletion_with_retry(
            self, messages=message_dicts, run_manager=run_manager, **params
        )
        return self._create_chat_result(response)

    @property
    def _identifying_params(self) -> Dict[str, Any]:
        """Get the identifying parameters."""
        return {**{"model_name": self.model_name}, **self._default_params}

    @property
    def _client_params(self) -> Dict[str, Any]:
        """Get the parameters used for the openai client.

        For the non-v1 SDK the credentials are sent per request. When a proxy
        is configured, the module-global ``openai.proxy`` is set as a side effect.
        """
        openai_creds: Dict[str, Any] = {
            "model": self.model_name,
        }
        if not is_openai_v1():
            openai_creds.update(
                {
                    "api_key": self.openai_api_key,
                    "api_base": self.openai_api_base,
                    "organization": self.openai_organization,
                }
            )
        if self.openai_proxy:
            import openai

            openai.proxy = {"http": self.openai_proxy, "https": self.openai_proxy}
        return {**self._default_params, **openai_creds}

    def _get_invocation_params(
        self, stop: Optional[List[str]] = None, **kwargs: Any
    ) -> Dict[str, Any]:
        """Get the parameters used to invoke the model."""
        return {
            "model": self.model_name,
            **super()._get_invocation_params(stop=stop),
            **self._default_params,
            **kwargs,
        }

    @property
    def _llm_type(self) -> str:
        """Return type of chat model."""
        return "openai-chat"

    def _get_encoding_model(self) -> Tuple[str, tiktoken.Encoding]:
        """Return the model name used for tokenization and its tiktoken encoding.

        Falls back to the ``cl100k_base`` encoding (and reports that name as the
        model) when tiktoken does not recognize the model.
        """
        tiktoken_ = _import_tiktoken()
        if self.tiktoken_model_name is not None:
            model = self.tiktoken_model_name
        else:
            model = self.model_name
            if model == "gpt-3.5-turbo":
                # gpt-3.5-turbo may change over time.
                # Returning num tokens assuming gpt-3.5-turbo-0301.
                model = "gpt-3.5-turbo-0301"
            elif model == "gpt-4":
                # gpt-4 may change over time.
                # Returning num tokens assuming gpt-4-0314.
                model = "gpt-4-0314"
        # Look up the encoding for the resolved model name.
        try:
            encoding = tiktoken_.encoding_for_model(model)
        except KeyError:
            logger.warning("Warning: model not found. Using cl100k_base encoding.")
            model = "cl100k_base"
            encoding = tiktoken_.get_encoding(model)
        return model, encoding

    def get_token_ids(self, text: str) -> List[int]:
        """Get the tokens present in the text with tiktoken package."""
        # tiktoken NOT supported for Python 3.7 or below
        if sys.version_info < (3, 8):
            return super().get_token_ids(text)
        _, encoding_model = self._get_encoding_model()
        return encoding_model.encode(text)

    def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
        """Calculate num tokens for gpt-3.5-turbo and gpt-4 with tiktoken package.

        Official documentation: https://github.com/openai/openai-cookbook/blob/
        main/examples/How_to_format_inputs_to_ChatGPT_models.ipynb

        Raises:
            NotImplementedError: for models other than gpt-3.5-turbo*/gpt-4*.
        """
        if sys.version_info < (3, 8):
            return super().get_num_tokens_from_messages(messages)
        model, encoding = self._get_encoding_model()
        if model.startswith("gpt-3.5-turbo-0301"):
            # every message follows <im_start>{role/name}\n{content}<im_end>\n
            tokens_per_message = 4
            # if there's a name, the role is omitted
            tokens_per_name = -1
        elif model.startswith("gpt-3.5-turbo") or model.startswith("gpt-4"):
            tokens_per_message = 3
            tokens_per_name = 1
        else:
            raise NotImplementedError(
                f"get_num_tokens_from_messages() is not presently implemented "
                f"for model {model}."
                "See https://github.com/openai/openai-python/blob/main/chatml.md for "
                "information on how messages are converted to tokens."
            )
        num_tokens = 0
        messages_dict = [convert_message_to_dict(m) for m in messages]
        for message in messages_dict:
            num_tokens += tokens_per_message
            for key, value in message.items():
                # Cast str(value) in case the message value is not a string
                # This occurs with function messages
                num_tokens += len(encoding.encode(str(value)))
                if key == "name":
                    num_tokens += tokens_per_name
        # every reply is primed with <im_start>assistant
        num_tokens += 3
        return num_tokens

    def bind_functions(
        self,
        functions: Sequence[Union[Dict[str, Any], Type[BaseModel], Callable]],
        function_call: Optional[str] = None,
        **kwargs: Any,
    ) -> Runnable[LanguageModelInput, BaseMessage]:
        """Bind functions (and other objects) to this chat model.

        Args:
            functions: A list of function definitions to bind to this chat model.
                Can be a dictionary, pydantic model, or callable. Pydantic
                models and callables will be automatically converted to
                their schema dictionary representation.
            function_call: Which function to require the model to call.
                Must be the name of the single provided function or
                "auto" to automatically determine which function to call
                (if any).
            kwargs: Any additional parameters to pass to the
                :class:`~langchain.runnable.Runnable` constructor.

        Raises:
            ValueError: if ``function_call`` is given but not exactly one
                function is provided, or its name does not match.
        """
        from langchain.chains.openai_functions.base import convert_to_openai_function

        formatted_functions = [convert_to_openai_function(fn) for fn in functions]
        if function_call is not None:
            if len(formatted_functions) != 1:
                raise ValueError(
                    "When specifying `function_call`, you must provide exactly one "
                    "function."
                )
            if formatted_functions[0]["name"] != function_call:
                raise ValueError(
                    f"Function call {function_call} was specified, but the only "
                    f"provided function was {formatted_functions[0]['name']}."
                )
            function_call_ = {"name": function_call}
            kwargs = {**kwargs, "function_call": function_call_}
        return super().bind(
            functions=formatted_functions,
            **kwargs,
        )