Spaces:
Runtime error
Runtime error
from __future__ import annotations | |
import importlib | |
from typing import ( | |
Any, | |
AsyncIterator, | |
Dict, | |
Iterable, | |
List, | |
Mapping, | |
Sequence, | |
Union, | |
overload, | |
) | |
from langchain_core.chat_sessions import ChatSession | |
from langchain_core.messages import ( | |
AIMessage, | |
AIMessageChunk, | |
BaseMessage, | |
BaseMessageChunk, | |
ChatMessage, | |
FunctionMessage, | |
HumanMessage, | |
SystemMessage, | |
ToolMessage, | |
) | |
from langchain_core.pydantic_v1 import BaseModel | |
from typing_extensions import Literal | |
async def aenumerate(
    iterable: AsyncIterator[Any], start: int = 0
) -> AsyncIterator[tuple[int, Any]]:
    """Asynchronous counterpart of the built-in ``enumerate``.

    Args:
        iterable: The async iterator to consume.
        start: Index paired with the first item.

    Yields:
        ``(index, item)`` tuples; the index grows by one after each item.
    """
    index = start
    async for item in iterable:
        yield index, item
        index += 1
class IndexableBaseModel(BaseModel):
    """BaseModel whose fields can also be read with ``model["field"]`` syntax."""

    def __getitem__(self, item: str) -> Any:
        """Return the attribute named ``item`` (looked up with ``getattr``)."""
        value = getattr(self, item)
        return value
class Choice(IndexableBaseModel):
    """A single choice of a non-streaming chat completion."""

    # The message dictionary carried by this choice.
    message: dict
class ChatCompletions(IndexableBaseModel):
    """Non-streaming chat completion result: a list of choices."""

    # Each entry wraps one message dictionary.
    choices: List[Choice]
class ChoiceChunk(IndexableBaseModel):
    """A single choice of a streamed chat completion chunk."""

    # The delta dictionary carried by this chunk.
    delta: dict
class ChatCompletionChunk(IndexableBaseModel):
    """One streamed chat completion chunk: a list of choice chunks."""

    # Each entry wraps one delta dictionary.
    choices: List[ChoiceChunk]
def convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage:
    """Build the LangChain message that corresponds to a message dictionary.

    The ``"role"`` entry selects the message class: ``"user"``,
    ``"assistant"``, ``"system"``, ``"function"`` and ``"tool"`` map to
    ``HumanMessage``, ``AIMessage``, ``SystemMessage``, ``FunctionMessage``
    and ``ToolMessage``; any other role produces a ``ChatMessage`` that keeps
    the role as given.

    Args:
        _dict: The dictionary.

    Returns:
        The LangChain message.
    """
    role = _dict.get("role")

    if role == "user":
        return HumanMessage(content=_dict.get("content", ""))

    if role == "assistant":
        # Fix for azure; OpenAI also returns None for tool invocations, so
        # any falsy content is normalised to the empty string.
        text = _dict.get("content", "") or ""
        ai_kwargs: Dict = {}
        function_call = _dict.get("function_call")
        if function_call:
            # Stored as a plain dict copy.
            ai_kwargs["function_call"] = dict(function_call)
        tool_calls = _dict.get("tool_calls")
        if tool_calls:
            ai_kwargs["tool_calls"] = tool_calls
        return AIMessage(content=text, additional_kwargs=ai_kwargs)

    if role == "system":
        return SystemMessage(content=_dict.get("content", ""))

    if role == "function":
        return FunctionMessage(  # type: ignore[arg-type]
            content=_dict.get("content", ""),
            name=_dict.get("name"),
        )

    if role == "tool":
        # The optional name travels in additional_kwargs.
        tool_kwargs: Dict = {}
        if "name" in _dict:
            tool_kwargs["name"] = _dict["name"]
        return ToolMessage(
            content=_dict.get("content", ""),
            tool_call_id=_dict.get("tool_call_id"),  # type: ignore[arg-type]
            additional_kwargs=tool_kwargs,
        )

    # Unrecognised (or missing) role: keep it verbatim on a generic message.
    return ChatMessage(  # type: ignore[arg-type]
        content=_dict.get("content", ""),
        role=role,
    )
def convert_message_to_dict(message: BaseMessage) -> dict:
    """Serialise a LangChain message into a role/content dictionary.

    Args:
        message: The LangChain message.

    Returns:
        The dictionary. It always has ``"role"`` and ``"content"``; depending
        on the message it may also carry ``"function_call"``,
        ``"tool_calls"``, ``"tool_call_id"`` and ``"name"``.

    Raises:
        TypeError: If ``message`` is none of the handled message classes.
    """
    payload: Dict[str, Any]

    if isinstance(message, ChatMessage):
        payload = {"role": message.role, "content": message.content}
    elif isinstance(message, HumanMessage):
        payload = {"role": "user", "content": message.content}
    elif isinstance(message, AIMessage):
        payload = {"role": "assistant", "content": message.content}
        for call_key in ("function_call", "tool_calls"):
            if call_key not in message.additional_kwargs:
                continue
            payload[call_key] = message.additional_kwargs[call_key]
            # A call-only message has content None, not an empty string.
            if payload["content"] == "":
                payload["content"] = None
    elif isinstance(message, SystemMessage):
        payload = {"role": "system", "content": message.content}
    elif isinstance(message, FunctionMessage):
        payload = {
            "role": "function",
            "content": message.content,
            "name": message.name,
        }
    elif isinstance(message, ToolMessage):
        payload = {
            "role": "tool",
            "content": message.content,
            "tool_call_id": message.tool_call_id,
        }
    else:
        raise TypeError(f"Got unknown type {message}")

    # A name stored in additional_kwargs takes precedence over any name set above.
    if "name" in message.additional_kwargs:
        payload["name"] = message.additional_kwargs["name"]
    return payload
def convert_openai_messages(messages: Sequence[Dict[str, Any]]) -> List[BaseMessage]:
    """Turn OpenAI-style message dictionaries into LangChain messages.

    Args:
        messages: List of dictionaries representing OpenAI messages

    Returns:
        List of LangChain BaseMessage objects, in the same order as the input.
    """
    converted: List[BaseMessage] = []
    for entry in messages:
        converted.append(convert_dict_to_message(entry))
    return converted
def _convert_message_chunk(chunk: BaseMessageChunk, i: int) -> dict:
    """Convert one streamed message chunk into a delta dictionary.

    Args:
        chunk: The streamed chunk; only ``AIMessageChunk`` is accepted.
        i: Position of the chunk in the stream (0 for the first chunk).

    Returns:
        The delta dictionary, or ``{}`` when the delta would only contain an
        empty-string content.

    Raises:
        ValueError: If ``chunk`` is not an ``AIMessageChunk``.
    """
    if not isinstance(chunk, AIMessageChunk):
        raise ValueError(f"Got unexpected streaming chunk type: {type(chunk)}")

    is_first = i == 0
    delta: Dict[str, Any] = {}
    if is_first:
        # The role only shows up in the first chunk.
        delta["role"] = "assistant"

    if "function_call" in chunk.additional_kwargs:
        delta["function_call"] = chunk.additional_kwargs["function_call"]
        # When the first chunk is a function call its content is None --
        # neither an empty string nor missing.
        if is_first:
            delta["content"] = None
    else:
        delta["content"] = chunk.content

    # This only happens at the end of streams, where OpenAI returns an
    # empty dict.
    if delta == {"content": ""}:
        return {}
    return delta
def _convert_message_chunk_to_delta(chunk: BaseMessageChunk, i: int) -> Dict[str, Any]:
    """Wrap the delta for ``chunk`` as ``{"choices": [{"delta": ...}]}``.

    Args:
        chunk: The streamed chunk to convert.
        i: Position of the chunk in the stream.

    Returns:
        A dictionary with a single choice holding the converted delta.
    """
    delta = _convert_message_chunk(chunk, i)
    choice = {"delta": delta}
    return {"choices": [choice]}
class ChatCompletion:
    """Chat completion.

    Namespace of static helpers that take OpenAI-style message dictionaries,
    run them through a chat model looked up in ``langchain.chat_models`` and
    return plain dictionaries shaped as ``{"choices": [...]}``.
    """

    @overload
    @staticmethod
    def create(
        messages: Sequence[Dict[str, Any]],
        *,
        provider: str = "ChatOpenAI",
        stream: Literal[False] = False,
        **kwargs: Any,
    ) -> dict:
        ...

    @overload
    @staticmethod
    def create(
        messages: Sequence[Dict[str, Any]],
        *,
        provider: str = "ChatOpenAI",
        stream: Literal[True],
        **kwargs: Any,
    ) -> Iterable:
        ...

    @staticmethod
    def create(
        messages: Sequence[Dict[str, Any]],
        *,
        provider: str = "ChatOpenAI",
        stream: bool = False,
        **kwargs: Any,
    ) -> Union[dict, Iterable]:
        """Run a chat model synchronously on OpenAI-style messages.

        Args:
            messages: Message dictionaries; converted to LangChain messages
                before the model is called.
            provider: Name of the model class to fetch from
                ``langchain.chat_models``.
            stream: When true, stream the answer instead of returning it whole.
            **kwargs: Passed to the model class constructor.

        Returns:
            Without streaming, ``{"choices": [{"message": <message dict>}]}``.
            With streaming, a generator of
            ``{"choices": [{"delta": <delta dict>}]}`` items, one per chunk.

        Raises:
            AttributeError: If ``langchain.chat_models`` has no attribute
                named ``provider``.
        """
        # Imported lazily by name so the provider can be chosen at call time.
        models = importlib.import_module("langchain.chat_models")
        model_cls = getattr(models, provider)
        model_config = model_cls(**kwargs)
        converted_messages = convert_openai_messages(messages)
        if not stream:
            result = model_config.invoke(converted_messages)
            return {"choices": [{"message": convert_message_to_dict(result)}]}
        # Lazy generator: the model is only streamed as the caller iterates.
        return (
            _convert_message_chunk_to_delta(c, i)
            for i, c in enumerate(model_config.stream(converted_messages))
        )

    @overload
    @staticmethod
    async def acreate(
        messages: Sequence[Dict[str, Any]],
        *,
        provider: str = "ChatOpenAI",
        stream: Literal[False] = False,
        **kwargs: Any,
    ) -> dict:
        ...

    @overload
    @staticmethod
    async def acreate(
        messages: Sequence[Dict[str, Any]],
        *,
        provider: str = "ChatOpenAI",
        stream: Literal[True],
        **kwargs: Any,
    ) -> AsyncIterator:
        ...

    @staticmethod
    async def acreate(
        messages: Sequence[Dict[str, Any]],
        *,
        provider: str = "ChatOpenAI",
        stream: bool = False,
        **kwargs: Any,
    ) -> Union[dict, AsyncIterator]:
        """Run a chat model asynchronously on OpenAI-style messages.

        Args:
            messages: Message dictionaries; converted to LangChain messages
                before the model is called.
            provider: Name of the model class to fetch from
                ``langchain.chat_models``.
            stream: When true, stream the answer instead of returning it whole.
            **kwargs: Passed to the model class constructor.

        Returns:
            Without streaming, ``{"choices": [{"message": <message dict>}]}``.
            With streaming, an async generator of
            ``{"choices": [{"delta": <delta dict>}]}`` items, one per chunk.

        Raises:
            AttributeError: If ``langchain.chat_models`` has no attribute
                named ``provider``.
        """
        # Imported lazily by name so the provider can be chosen at call time.
        models = importlib.import_module("langchain.chat_models")
        model_cls = getattr(models, provider)
        model_config = model_cls(**kwargs)
        converted_messages = convert_openai_messages(messages)
        if not stream:
            result = await model_config.ainvoke(converted_messages)
            return {"choices": [{"message": convert_message_to_dict(result)}]}
        # Async generator expression: chunks are pulled as the caller iterates.
        return (
            _convert_message_chunk_to_delta(c, i)
            async for i, c in aenumerate(model_config.astream(converted_messages))
        )
def _has_assistant_message(session: ChatSession) -> bool: | |
"""Check if chat session has an assistant message.""" | |
return any([isinstance(m, AIMessage) for m in session["messages"]]) | |
def convert_messages_for_finetuning(
    sessions: Iterable[ChatSession],
) -> List[List[dict]]:
    """Convert chat sessions into lists of message dictionaries for fine-tuning.

    Sessions that contain no assistant message are left out of the result.

    Args:
        sessions: The chat sessions.

    Returns:
        One list of message dictionaries per retained session.
    """
    converted_sessions: List[List[dict]] = []
    for session in sessions:
        if not _has_assistant_message(session):
            continue
        session_dicts = [convert_message_to_dict(m) for m in session["messages"]]
        converted_sessions.append(session_dicts)
    return converted_sessions
class Completions:
    """Completions.

    Static helpers that take OpenAI-style message dictionaries, run them
    through a chat model looked up in ``langchain.chat_models`` and return
    ``ChatCompletions`` / ``ChatCompletionChunk`` objects. The methods are
    static so they work both on the class and on an instance.
    """

    @overload
    @staticmethod
    def create(
        messages: Sequence[Dict[str, Any]],
        *,
        provider: str = "ChatOpenAI",
        stream: Literal[False] = False,
        **kwargs: Any,
    ) -> ChatCompletions:
        ...

    @overload
    @staticmethod
    def create(
        messages: Sequence[Dict[str, Any]],
        *,
        provider: str = "ChatOpenAI",
        stream: Literal[True],
        **kwargs: Any,
    ) -> Iterable:
        ...

    @staticmethod
    def create(
        messages: Sequence[Dict[str, Any]],
        *,
        provider: str = "ChatOpenAI",
        stream: bool = False,
        **kwargs: Any,
    ) -> Union[ChatCompletions, Iterable]:
        """Run a chat model synchronously on OpenAI-style messages.

        Args:
            messages: Message dictionaries; converted to LangChain messages
                before the model is called.
            provider: Name of the model class to fetch from
                ``langchain.chat_models``.
            stream: When true, stream the answer instead of returning it whole.
            **kwargs: Passed to the model class constructor.

        Returns:
            Without streaming, a ``ChatCompletions`` with a single ``Choice``.
            With streaming, a generator of ``ChatCompletionChunk`` objects,
            one per chunk, each with a single ``ChoiceChunk``.

        Raises:
            AttributeError: If ``langchain.chat_models`` has no attribute
                named ``provider``.
        """
        # Imported lazily by name so the provider can be chosen at call time.
        models = importlib.import_module("langchain.chat_models")
        model_cls = getattr(models, provider)
        model_config = model_cls(**kwargs)
        converted_messages = convert_openai_messages(messages)
        if not stream:
            result = model_config.invoke(converted_messages)
            return ChatCompletions(
                choices=[Choice(message=convert_message_to_dict(result))]
            )
        # Lazy generator: the model is only streamed as the caller iterates.
        return (
            ChatCompletionChunk(
                choices=[ChoiceChunk(delta=_convert_message_chunk(c, i))]
            )
            for i, c in enumerate(model_config.stream(converted_messages))
        )

    @overload
    @staticmethod
    async def acreate(
        messages: Sequence[Dict[str, Any]],
        *,
        provider: str = "ChatOpenAI",
        stream: Literal[False] = False,
        **kwargs: Any,
    ) -> ChatCompletions:
        ...

    @overload
    @staticmethod
    async def acreate(
        messages: Sequence[Dict[str, Any]],
        *,
        provider: str = "ChatOpenAI",
        stream: Literal[True],
        **kwargs: Any,
    ) -> AsyncIterator:
        ...

    @staticmethod
    async def acreate(
        messages: Sequence[Dict[str, Any]],
        *,
        provider: str = "ChatOpenAI",
        stream: bool = False,
        **kwargs: Any,
    ) -> Union[ChatCompletions, AsyncIterator]:
        """Run a chat model asynchronously on OpenAI-style messages.

        Args:
            messages: Message dictionaries; converted to LangChain messages
                before the model is called.
            provider: Name of the model class to fetch from
                ``langchain.chat_models``.
            stream: When true, stream the answer instead of returning it whole.
            **kwargs: Passed to the model class constructor.

        Returns:
            Without streaming, a ``ChatCompletions`` with a single ``Choice``.
            With streaming, an async generator of ``ChatCompletionChunk``
            objects, one per chunk, each with a single ``ChoiceChunk``.

        Raises:
            AttributeError: If ``langchain.chat_models`` has no attribute
                named ``provider``.
        """
        # Imported lazily by name so the provider can be chosen at call time.
        models = importlib.import_module("langchain.chat_models")
        model_cls = getattr(models, provider)
        model_config = model_cls(**kwargs)
        converted_messages = convert_openai_messages(messages)
        if not stream:
            result = await model_config.ainvoke(converted_messages)
            return ChatCompletions(
                choices=[Choice(message=convert_message_to_dict(result))]
            )
        # Async generator expression: chunks are pulled as the caller iterates.
        return (
            ChatCompletionChunk(
                choices=[ChoiceChunk(delta=_convert_message_chunk(c, i))]
            )
            async for i, c in aenumerate(model_config.astream(converted_messages))
        )
class Chat:
    """Chat namespace exposing a ``completions`` attribute."""

    def __init__(self) -> None:
        """Create the ``Completions`` helper stored on ``self.completions``."""
        completions = Completions()
        self.completions = completions
# Module-level Chat instance, created at import time.
chat = Chat()