# Spaces: Runtime error
from typing import Any, AsyncIterator, Dict, Iterator, List, Optional, cast | |
from langchain_core._api.deprecation import deprecated | |
from langchain_core.callbacks import ( | |
AsyncCallbackManagerForLLMRun, | |
CallbackManagerForLLMRun, | |
) | |
from langchain_core.language_models.chat_models import ( | |
BaseChatModel, | |
agenerate_from_stream, | |
generate_from_stream, | |
) | |
from langchain_core.messages import ( | |
AIMessage, | |
AIMessageChunk, | |
BaseMessage, | |
ChatMessage, | |
HumanMessage, | |
SystemMessage, | |
) | |
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult | |
from langchain_core.prompt_values import PromptValue | |
from langchain_community.llms.anthropic import _AnthropicCommon | |
def _convert_one_message_to_text(
    message: BaseMessage,
    human_prompt: str,
    ai_prompt: str,
) -> str:
    """Render a single message as its fragment of the text prompt.

    Args:
        message: The message to render. Its ``content`` is treated as a
            string (it is only ``cast``, not converted).
        human_prompt: Tag placed before the content of a ``HumanMessage``.
        ai_prompt: Tag placed before the content of an ``AIMessage``.

    Returns:
        The prompt fragment for ``message``:
        a ``ChatMessage`` becomes ``"\\n\\n<Role>: <content>"`` with the role
        capitalized, human / AI messages become ``"<tag> <content>"``, and a
        ``SystemMessage`` contributes its content verbatim.

    Raises:
        ValueError: If ``message`` is none of the supported message types.
    """
    # `cast` only informs the type checker; no runtime conversion happens.
    body = cast(str, message.content)

    if isinstance(message, ChatMessage):
        return f"\n\n{message.role.capitalize()}: {body}"
    if isinstance(message, HumanMessage):
        return f"{human_prompt} {body}"
    if isinstance(message, AIMessage):
        return f"{ai_prompt} {body}"
    if isinstance(message, SystemMessage):
        return body
    raise ValueError(f"Got unknown type {message}")
def convert_messages_to_prompt_anthropic(
    messages: List[BaseMessage],
    *,
    human_prompt: str = "\n\nHuman:",
    ai_prompt: str = "\n\nAssistant:",
) -> str:
    """Build the full Anthropic text prompt from a list of messages.

    If the conversation does not already end with an ``AIMessage``, an empty
    one is appended so the prompt finishes with the AI tag, ready for the
    model to continue.

    Args:
        messages: Messages to combine, in conversation order. The list
            itself is not modified.
        human_prompt: Human prompt tag. Defaults to ``"\\n\\nHuman:"``.
        ai_prompt: AI prompt tag. Defaults to ``"\\n\\nAssistant:"``.

    Returns:
        The concatenated prompt with trailing whitespace removed.

    Raises:
        IndexError: If ``messages`` is empty (the last element is inspected).
        ValueError: If a message has an unsupported type.
    """
    # Shallow copy so the caller's list is left untouched by the append below.
    turns = messages.copy()
    last_is_ai = isinstance(turns[-1], AIMessage)
    if not last_is_ai:
        turns.append(AIMessage(content=""))

    rendered = [
        _convert_one_message_to_text(turn, human_prompt, ai_prompt)
        for turn in turns
    ]
    # An empty trailing AI turn leaves "<ai_prompt> " behind; drop that space.
    return "".join(rendered).rstrip()
class ChatAnthropic(BaseChatModel, _AnthropicCommon):
    """`Anthropic` chat large language models.

    To use, you should have the ``anthropic`` python package installed, and the
    environment variable ``ANTHROPIC_API_KEY`` set with your API key, or pass
    it as a named parameter to the constructor.

    Example:
        .. code-block:: python

            import anthropic
            from langchain_community.chat_models import ChatAnthropic
            model = ChatAnthropic(model="<model_name>", anthropic_api_key="my-api-key")
    """

    # NOTE(review): `deprecated` is imported at module level but is not applied
    # anywhere in this file — confirm whether a class-level `@deprecated(...)`
    # decorator was intended here.

    class Config:
        """Configuration for this pydantic object."""

        allow_population_by_field_name = True
        arbitrary_types_allowed = True

    # Exposed as a property: langchain_core's serialization layer reads
    # `lc_secrets` as an attribute, not by calling it.
    @property
    def lc_secrets(self) -> Dict[str, str]:
        """Map constructor secret names to their environment-variable names."""
        return {"anthropic_api_key": "ANTHROPIC_API_KEY"}

    # `_llm_type` is likewise read as an attribute by the base model classes.
    @property
    def _llm_type(self) -> str:
        """Return type of chat model."""
        return "anthropic-chat"

    @classmethod
    def is_lc_serializable(cls) -> bool:
        """Return whether this model can be serialized by Langchain."""
        return True

    @classmethod
    def get_lc_namespace(cls) -> List[str]:
        """Get the namespace of the langchain object."""
        return ["langchain", "chat_models", "anthropic"]

    def _convert_messages_to_prompt(self, messages: List[BaseMessage]) -> str:
        """Format a list of messages into a full prompt for the Anthropic model.

        Args:
            messages (List[BaseMessage]): List of BaseMessage to combine.

        Returns:
            str: Combined string with necessary HUMAN_PROMPT and AI_PROMPT tags.
        """
        # Only forward tags that are set (truthy); otherwise the defaults of
        # `convert_messages_to_prompt_anthropic` apply.
        # `HUMAN_PROMPT` / `AI_PROMPT` are not defined in this class —
        # presumably inherited from `_AnthropicCommon`.
        prompt_params: Dict[str, Any] = {}
        if self.HUMAN_PROMPT:
            prompt_params["human_prompt"] = self.HUMAN_PROMPT
        if self.AI_PROMPT:
            prompt_params["ai_prompt"] = self.AI_PROMPT
        return convert_messages_to_prompt_anthropic(messages=messages, **prompt_params)

    def convert_prompt(self, prompt: PromptValue) -> str:
        """Convert a ``PromptValue`` into the Anthropic text prompt."""
        return self._convert_messages_to_prompt(prompt.to_messages())

    def _build_completion_params(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]],
        kwargs: Dict[str, Any],
    ) -> Dict[str, Any]:
        """Assemble the keyword arguments for a ``completions.create`` call.

        Shared by the sync/async and streaming/non-streaming code paths so they
        cannot drift apart.

        Args:
            messages: Conversation to render into the ``prompt`` entry.
            stop: Optional stop sequences; set as ``stop_sequences`` only when
                non-empty.
            kwargs: Per-call overrides. They are merged last, so they win over
                both ``prompt`` and ``self._default_params``.

        Returns:
            A fresh dict of request parameters.
        """
        params: Dict[str, Any] = {
            "prompt": self._convert_messages_to_prompt(messages),
            **self._default_params,
            **kwargs,
        }
        if stop:
            params["stop_sequences"] = stop
        return params

    def _stream(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[ChatGenerationChunk]:
        """Stream the completion synchronously, one chunk per response event.

        Each event's ``completion`` text is wrapped in a ``ChatGenerationChunk``
        and reported to ``run_manager`` (when given) before being yielded.
        """
        params = self._build_completion_params(messages, stop, kwargs)

        stream_resp = self.client.completions.create(**params, stream=True)
        for data in stream_resp:
            delta = data.completion
            chunk = ChatGenerationChunk(message=AIMessageChunk(content=delta))
            if run_manager:
                run_manager.on_llm_new_token(delta, chunk=chunk)
            yield chunk

    async def _astream(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> AsyncIterator[ChatGenerationChunk]:
        """Stream the completion asynchronously, one chunk per response event.

        Async counterpart of ``_stream``; uses ``self.async_client`` and awaits
        the callback manager.
        """
        params = self._build_completion_params(messages, stop, kwargs)

        stream_resp = await self.async_client.completions.create(**params, stream=True)
        async for data in stream_resp:
            delta = data.completion
            chunk = ChatGenerationChunk(message=AIMessageChunk(content=delta))
            if run_manager:
                await run_manager.on_llm_new_token(delta, chunk=chunk)
            yield chunk

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        """Generate a single chat completion synchronously.

        When ``self.streaming`` is truthy the result is assembled from
        ``_stream``; otherwise a single non-streaming request is made.
        """
        if self.streaming:
            stream_iter = self._stream(
                messages, stop=stop, run_manager=run_manager, **kwargs
            )
            return generate_from_stream(stream_iter)

        params = self._build_completion_params(messages, stop, kwargs)
        response = self.client.completions.create(**params)
        message = AIMessage(content=response.completion)
        return ChatResult(generations=[ChatGeneration(message=message)])

    async def _agenerate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        """Generate a single chat completion asynchronously.

        Async counterpart of ``_generate``: delegates to ``_astream`` when
        ``self.streaming`` is truthy, otherwise awaits one non-streaming
        request on ``self.async_client``.
        """
        if self.streaming:
            stream_iter = self._astream(
                messages, stop=stop, run_manager=run_manager, **kwargs
            )
            return await agenerate_from_stream(stream_iter)

        params = self._build_completion_params(messages, stop, kwargs)
        response = await self.async_client.completions.create(**params)
        message = AIMessage(content=response.completion)
        return ChatResult(generations=[ChatGeneration(message=message)])

    def get_num_tokens(self, text: str) -> int:
        """Calculate number of tokens.

        Raises:
            NameError: If ``self.count_tokens`` is falsy (no token counter is
                available).
        """
        if not self.count_tokens:
            raise NameError("Please ensure the anthropic package is loaded")
        return self.count_tokens(text)