import re
from collections import defaultdict
from typing import Any, Dict, Iterator, List, Optional, Tuple, Union

from langchain_core.callbacks import (
    CallbackManagerForLLMRun,
)
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.messages import (
    AIMessage,
    AIMessageChunk,
    BaseMessage,
    ChatMessage,
    HumanMessage,
    SystemMessage,
)
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from langchain_core.pydantic_v1 import Extra

from langchain_community.chat_models.anthropic import (
    convert_messages_to_prompt_anthropic,
)
from langchain_community.chat_models.meta import convert_messages_to_prompt_llama
from langchain_community.llms.bedrock import BedrockBase
from langchain_community.utilities.anthropic import (
    get_num_tokens_anthropic,
    get_token_ids_anthropic,
)

def _convert_one_message_to_text_mistral(message: BaseMessage) -> str:
    if isinstance(message, ChatMessage):
        message_text = f"\n\n{message.role.capitalize()}: {message.content}"
    elif isinstance(message, HumanMessage):
        message_text = f"[INST] {message.content} [/INST]"
    elif isinstance(message, AIMessage):
        message_text = f"{message.content}"
    elif isinstance(message, SystemMessage):
        message_text = f"<<SYS>> {message.content} <</SYS>>"
    else:
        raise ValueError(f"Got unknown type {message}")
    return message_text


def convert_messages_to_prompt_mistral(messages: List[BaseMessage]) -> str:
    """Convert a list of messages to a prompt for mistral."""
    return "\n".join(
        [_convert_one_message_to_text_mistral(message) for message in messages]
    )
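
# Doctest-style sketch of the resulting Mistral prompt (the messages below are
# illustrative inputs, not part of the library):
#
# >>> convert_messages_to_prompt_mistral(
# ...     [HumanMessage(content="Hello"), AIMessage(content="Hi!")]
# ... )
# '[INST] Hello [/INST]\nHi!'
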
def _format_image(image_url: str) -> Dict:
    """
    Format an image URL of the form data:image/jpeg;base64,{b64_string}
    into the dict the Anthropic API expects:

    {
        "type": "base64",
        "media_type": "image/jpeg",
        "data": "/9j/4AAQSkZJRg...",
    }

    Raises a ValueError if the input is not a base64-encoded data URL.
    """
    regex = r"^data:(?P<media_type>image/.+);base64,(?P<data>.+)$"
    match = re.match(regex, image_url)
    if match is None:
        raise ValueError(
            "Anthropic only supports base64-encoded images currently."
            " Example: data:image/png;base64,'/9j/4AAQSk'..."
        )
    return {
        "type": "base64",
        "media_type": match.group("media_type"),
        "data": match.group("data"),
    }
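
# Example output (illustrative, truncated base64 payload):
#
# >>> _format_image("data:image/jpeg;base64,/9j/4AAQSkZJRg")
# {'type': 'base64', 'media_type': 'image/jpeg', 'data': '/9j/4AAQSkZJRg'}
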
def _format_anthropic_messages(
    messages: List[BaseMessage],
) -> Tuple[Optional[str], List[Dict]]:
    """Format messages for anthropic.

    Roughly equivalent to:

    [
        {
            "role": _message_type_lookups[m.type],
            "content": [_AnthropicMessageContent(text=m.content).dict()],
        }
        for m in messages
    ]
    """
    system: Optional[str] = None
    formatted_messages: List[Dict] = []
    for i, message in enumerate(messages):
        if message.type == "system":
            if i != 0:
                raise ValueError("System message must be at beginning of message list.")
            if not isinstance(message.content, str):
                raise ValueError(
                    "System message must be a string, "
                    f"instead was: {type(message.content)}"
                )
            system = message.content
            continue

        role = _message_type_lookups[message.type]
        content: Union[str, List[Dict]]

        if not isinstance(message.content, str):
            # parse as dict
            assert isinstance(
                message.content, list
            ), "Anthropic message content must be str or list of dicts"

            # populate content
            content = []
            for item in message.content:
                if isinstance(item, str):
                    content.append(
                        {
                            "type": "text",
                            "text": item,
                        }
                    )
                elif isinstance(item, dict):
                    if "type" not in item:
                        raise ValueError("Dict content item must have a type key")
                    if item["type"] == "image_url":
                        # convert format
                        source = _format_image(item["image_url"]["url"])
                        content.append(
                            {
                                "type": "image",
                                "source": source,
                            }
                        )
                    else:
                        content.append(item)
                else:
                    raise ValueError(
                        f"Content items must be str or dict, instead was: {type(item)}"
                    )
        else:
            content = message.content

        formatted_messages.append(
            {
                "role": role,
                "content": content,
            }
        )
    return system, formatted_messages
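
# Expected shape (doctest-style sketch; a leading system message is split out
# from the conversational turns):
#
# >>> _format_anthropic_messages(
# ...     [SystemMessage(content="Be brief."), HumanMessage(content="Hi")]
# ... )
# ('Be brief.', [{'role': 'user', 'content': 'Hi'}])
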
class ChatPromptAdapter:
    """Adapter class to prepare the inputs from Langchain to prompt format
    that Chat model expects.
    """

    @classmethod
    def convert_messages_to_prompt(
        cls, provider: str, messages: List[BaseMessage]
    ) -> str:
        if provider == "anthropic":
            prompt = convert_messages_to_prompt_anthropic(messages=messages)
        elif provider == "meta":
            prompt = convert_messages_to_prompt_llama(messages=messages)
        elif provider == "mistral":
            prompt = convert_messages_to_prompt_mistral(messages=messages)
        elif provider == "amazon":
            prompt = convert_messages_to_prompt_anthropic(
                messages=messages,
                human_prompt="\n\nUser:",
                ai_prompt="\n\nBot:",
            )
        else:
            raise NotImplementedError(
                f"Provider {provider} model does not support chat."
            )
        return prompt

    @classmethod
    def format_messages(
        cls, provider: str, messages: List[BaseMessage]
    ) -> Tuple[Optional[str], List[Dict]]:
        if provider == "anthropic":
            return _format_anthropic_messages(messages)

        raise NotImplementedError(
            f"Provider {provider} not supported for format_messages"
        )
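
# Example (sketch): each provider maps to its own prompt convention, e.g.
#
# >>> ChatPromptAdapter.convert_messages_to_prompt(
# ...     "mistral", [HumanMessage(content="Hi")]
# ... )
# '[INST] Hi [/INST]'
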
_message_type_lookups = {
    "human": "user",
    "ai": "assistant",
    "AIMessageChunk": "assistant",
    "HumanMessageChunk": "user",
    "function": "user",
}
class BedrockChat(BaseChatModel, BedrockBase):
    """Chat model that uses the Bedrock API."""

    @property
    def _llm_type(self) -> str:
        """Return type of chat model."""
        return "amazon_bedrock_chat"

    @classmethod
    def is_lc_serializable(cls) -> bool:
        """Return whether this model can be serialized by Langchain."""
        return True

    @classmethod
    def get_lc_namespace(cls) -> List[str]:
        """Get the namespace of the langchain object."""
        return ["langchain", "chat_models", "bedrock"]

    @property
    def lc_attributes(self) -> Dict[str, Any]:
        attributes: Dict[str, Any] = {}

        if self.region_name:
            attributes["region_name"] = self.region_name

        return attributes

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid
    def _stream(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[ChatGenerationChunk]:
        provider = self._get_provider()
        prompt, system, formatted_messages = None, None, None

        if provider == "anthropic":
            system, formatted_messages = ChatPromptAdapter.format_messages(
                provider, messages
            )
        else:
            prompt = ChatPromptAdapter.convert_messages_to_prompt(
                provider=provider, messages=messages
            )

        for chunk in self._prepare_input_and_invoke_stream(
            prompt=prompt,
            system=system,
            messages=formatted_messages,
            stop=stop,
            run_manager=run_manager,
            **kwargs,
        ):
            delta = chunk.text
            yield ChatGenerationChunk(message=AIMessageChunk(content=delta))
    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        completion = ""
        llm_output: Dict[str, Any] = {"model_id": self.model_id}

        if self.streaming:
            for chunk in self._stream(messages, stop, run_manager, **kwargs):
                completion += chunk.text
        else:
            provider = self._get_provider()
            prompt, system, formatted_messages = None, None, None
            params: Dict[str, Any] = {**kwargs}

            if provider == "anthropic":
                system, formatted_messages = ChatPromptAdapter.format_messages(
                    provider, messages
                )
            else:
                prompt = ChatPromptAdapter.convert_messages_to_prompt(
                    provider=provider, messages=messages
                )

            if stop:
                params["stop_sequences"] = stop

            completion, usage_info = self._prepare_input_and_invoke(
                prompt=prompt,
                stop=stop,
                run_manager=run_manager,
                system=system,
                messages=formatted_messages,
                **params,
            )
            llm_output["usage"] = usage_info

        return ChatResult(
            generations=[ChatGeneration(message=AIMessage(content=completion))],
            llm_output=llm_output,
        )
    def _combine_llm_outputs(self, llm_outputs: List[Optional[dict]]) -> dict:
        final_usage: Dict[str, int] = defaultdict(int)
        final_output = {}
        for output in llm_outputs:
            output = output or {}
            usage = output.get("usage", {})
            for token_type, token_count in usage.items():
                final_usage[token_type] += token_count
            final_output.update(output)
        final_output["usage"] = final_usage
        return final_output
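
    # Merging sketch (illustrative numbers): per-generation usage dicts are
    # summed key-wise, so
    #   [{"model_id": "m1", "usage": {"prompt_tokens": 3}},
    #    {"usage": {"prompt_tokens": 2, "completion_tokens": 5}}]
    # combines into
    #   {"model_id": "m1", "usage": {"prompt_tokens": 5, "completion_tokens": 5}}
    # (the merged usage is a defaultdict, shown here as a plain dict).
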
    def get_num_tokens(self, text: str) -> int:
        if self._model_is_anthropic:
            return get_num_tokens_anthropic(text)
        else:
            return super().get_num_tokens(text)

    def get_token_ids(self, text: str) -> List[int]:
        if self._model_is_anthropic:
            return get_token_ids_anthropic(text)
        else:
            return super().get_token_ids(text)
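
# Minimal usage sketch, guarded so it never runs on import. The model id and
# region below are illustrative assumptions; running this requires AWS
# credentials with Bedrock access.
if __name__ == "__main__":
    chat = BedrockChat(model_id="anthropic.claude-v2", region_name="us-east-1")
    message = HumanMessage(content="Tell me a joke")

    # Single-shot call: returns an AIMessage.
    print(chat.invoke([message]).content)

    # Streaming variant: chunks arrive as AIMessageChunk objects.
    for chunk in chat.stream([message]):
        print(chunk.content, end="")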