import json
import logging
from typing import (
    Any,
    Dict,
    List,
    Mapping,
    Optional,
    Tuple,
)

from langchain_core.callbacks.manager import CallbackManagerForLLMRun
from langchain_core.language_models import BaseChatModel
from langchain_core.messages import (
    AIMessage,
    BaseMessage,
    ChatMessage,
    FunctionMessage,
    HumanMessage,
    SystemMessage,
)
from langchain_core.outputs import (
    ChatGeneration,
    ChatResult,
)

logger = logging.getLogger(__name__)


def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage:
    """Convert a Llama API response dict to a LangChain message."""
    role = _dict["role"]
    if role == "user":
        return HumanMessage(content=_dict["content"])
    elif role == "assistant":
        # Fix for Azure: OpenAI-compatible APIs return None as the content
        # of tool invocations, so fall back to an empty string.
        content = _dict.get("content") or ""
        if _dict.get("function_call"):
            # The API returns the arguments as a dict; LangChain stores them
            # as a JSON-encoded string.
            _dict["function_call"]["arguments"] = json.dumps(
                _dict["function_call"]["arguments"]
            )
            additional_kwargs = {"function_call": dict(_dict["function_call"])}
        else:
            additional_kwargs = {}
        return AIMessage(content=content, additional_kwargs=additional_kwargs)
    elif role == "system":
        return SystemMessage(content=_dict["content"])
    elif role == "function":
        return FunctionMessage(content=_dict["content"], name=_dict["name"])
    else:
        return ChatMessage(content=_dict["content"], role=role)
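

# For illustration (not part of the original module), an assistant dict with
# a ``function_call`` is converted to an ``AIMessage`` whose arguments are
# JSON-encoded:
#
#     _convert_dict_to_message(
#         {"role": "assistant", "content": None,
#          "function_call": {"name": "f", "arguments": {"x": 1}}}
#     )
#     # -> AIMessage(content="", additional_kwargs={"function_call":
#     #    {"name": "f", "arguments": '{"x": 1}'}})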


def _convert_message_to_dict(message: BaseMessage) -> dict:
    """Convert a LangChain message to a Llama API request dict."""
    if isinstance(message, ChatMessage):
        message_dict = {"role": message.role, "content": message.content}
    elif isinstance(message, HumanMessage):
        message_dict = {"role": "user", "content": message.content}
    elif isinstance(message, AIMessage):
        message_dict = {"role": "assistant", "content": message.content}
        if "function_call" in message.additional_kwargs:
            message_dict["function_call"] = message.additional_kwargs["function_call"]
    elif isinstance(message, SystemMessage):
        message_dict = {"role": "system", "content": message.content}
    elif isinstance(message, FunctionMessage):
        message_dict = {
            "role": "function",
            "content": message.content,
            "name": message.name,
        }
    else:
        raise ValueError(f"Got unknown message type {message}")
    if "name" in message.additional_kwargs:
        message_dict["name"] = message.additional_kwargs["name"]
    return message_dict


class ChatLlamaAPI(BaseChatModel):
    """Chat model using the Llama API."""

    client: Any  #: :meta private:

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        message_dicts, params = self._create_message_dicts(messages, stop)
        _params = {"messages": message_dicts}
        # Caller kwargs override the client defaults; the message list
        # always takes precedence.
        final_params = {**params, **kwargs, **_params}
        response = self.client.run(final_params).json()
        return self._create_chat_result(response)

    def _create_message_dicts(
        self, messages: List[BaseMessage], stop: Optional[List[str]]
    ) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]:
        params = dict(self._client_params)
        if stop is not None:
            if "stop" in params:
                raise ValueError("`stop` found in both the input and default params.")
            params["stop"] = stop
        message_dicts = [_convert_message_to_dict(m) for m in messages]
        return message_dicts, params

    def _create_chat_result(self, response: Mapping[str, Any]) -> ChatResult:
        generations = []
        for res in response["choices"]:
            message = _convert_dict_to_message(res["message"])
            gen = ChatGeneration(
                message=message,
                generation_info=dict(finish_reason=res.get("finish_reason")),
            )
            generations.append(gen)
        return ChatResult(generations=generations)
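
    # For reference, ``_create_chat_result`` assumes an OpenAI-style response
    # shape (an assumption inferred from the parsing above, not a documented
    # schema), e.g.:
    #
    #     {"choices": [{"message": {"role": "assistant", "content": "Hi"},
    #                   "finish_reason": "stop"}]}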

    @property
    def _client_params(self) -> Mapping[str, Any]:
        """Get the parameters used for the client."""
        return {}

    @property
    def _llm_type(self) -> str:
        """Return the type of chat model."""
        return "llama-api"