Spaces:
Runtime error
Runtime error
import json | |
import logging | |
from typing import Any, Dict, Iterator, List, Mapping, Optional, Union | |
import requests | |
from langchain_core.callbacks import CallbackManagerForLLMRun | |
from langchain_core.language_models.chat_models import ( | |
BaseChatModel, | |
generate_from_stream, | |
) | |
from langchain_core.messages import ( | |
AIMessage, | |
AIMessageChunk, | |
BaseMessage, | |
BaseMessageChunk, | |
ChatMessage, | |
ChatMessageChunk, | |
HumanMessage, | |
HumanMessageChunk, | |
) | |
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult | |
from langchain_core.pydantic_v1 import Field, SecretStr, root_validator | |
from langchain_core.utils import ( | |
convert_to_secret_str, | |
get_from_dict_or_env, | |
) | |
# Logger for this module; named after the module so log filtering works.
logger: logging.Logger = logging.getLogger(__name__)

# Default root URL of the Coze open API; used as the default value of
# ``ChatCoze.coze_api_base``.
DEFAULT_API_BASE: str = "https://api.coze.com"
def _convert_message_to_dict(message: BaseMessage) -> dict:
    """Render a LangChain message as a Coze ``chat_history`` entry.

    A ``HumanMessage`` becomes a ``"user"`` turn. Every other message class
    becomes an ``"assistant"`` turn. The content is passed through unchanged
    and always tagged with ``content_type`` ``"text"``.

    Args:
        message: The LangChain message to convert.

    Returns:
        A dict with the keys ``role``, ``content`` and ``content_type``.
    """
    # NOTE(review): system/tool messages are also sent as "assistant" here;
    # confirm this is the intended mapping for the Coze API.
    role = "user" if isinstance(message, HumanMessage) else "assistant"
    return {"role": role, "content": message.content, "content_type": "text"}
def _convert_dict_to_message(_dict: Mapping[str, Any]) -> Union[BaseMessage, None]:
    """Convert one message dict from a non-streaming Coze reply.

    Only entries whose ``type`` is ``"answer"`` are converted. Any other type
    yields ``None`` so that the caller can drop it.

    Args:
        _dict: A message mapping with at least ``type`` and ``role`` keys.

    Returns:
        A ``HumanMessage`` for role ``"user"``, an ``AIMessage`` for role
        ``"assistant"`` (a missing or empty content becomes ``""``), a
        ``ChatMessage`` carrying the raw role otherwise, or ``None`` for
        non-answer entries.
    """
    if _dict["type"] != "answer":
        return None

    role = _dict["role"]
    if role == "assistant":
        return AIMessage(content=_dict.get("content", "") or "")
    if role == "user":
        return HumanMessage(content=_dict["content"])
    return ChatMessage(content=_dict["content"], role=role)
def _convert_delta_to_message_chunk(_dict: Mapping[str, Any]) -> BaseMessageChunk:
    """Convert a streamed Coze message delta into a LangChain message chunk.

    Args:
        _dict: The ``message`` object of a streaming event. Its optional
            ``role`` and ``content`` keys are read; a missing or empty
            content becomes ``""``.

    Returns:
        A ``HumanMessageChunk`` for role ``"user"``, an ``AIMessageChunk``
        for role ``"assistant"``, or a ``ChatMessageChunk`` with the raw role
        otherwise.
    """
    role = _dict.get("role")
    text = _dict.get("content") or ""

    if role == "assistant":
        return AIMessageChunk(content=text)
    if role == "user":
        return HumanMessageChunk(content=text)
    return ChatMessageChunk(content=text, role=role)  # type: ignore[arg-type]
class ChatCoze(BaseChatModel):
    """Chat model backed by the coze.com open API (``/open_api/v2/chat``).

    For more information, see https://www.coze.com/open/docs/chat
    """

    @property
    def lc_secrets(self) -> Dict[str, str]:
        """Map secret constructor arguments to the env vars that hold them."""
        return {
            "coze_api_key": "COZE_API_KEY",
        }

    @property
    def lc_serializable(self) -> bool:
        """Mark this model as serializable by LangChain."""
        return True

    coze_api_base: str = Field(default=DEFAULT_API_BASE)
    """Coze custom endpoints"""
    coze_api_key: Optional[SecretStr] = None
    """Coze API Key"""
    request_timeout: int = Field(default=60, alias="timeout")
    """Request timeout (in seconds) for chat HTTP requests."""
    bot_id: str = Field(default="")
    """The ID of the bot that the API interacts with."""
    conversation_id: str = Field(default="")
    """Indicates which conversation the dialog takes place in. If there is no
    need to distinguish conversation context (a single question and answer),
    skip this parameter and the system will generate one."""
    user: str = Field(default="")
    """The user who calls the API to chat with the bot."""
    streaming: bool = False
    """Whether to stream the response to the client.

    false: if no value is specified or the value is false, a non-streaming
    response is returned. A non-streaming response returns everything at once,
    after it is ready, and the client does not need to concatenate the content.

    true: partial message deltas are sent as they are produced. A streaming
    response delivers the model's output in real time, and the client must
    assemble the final reply based on the type of each message."""

    class Config:
        """Configuration for this pydantic object."""

        allow_population_by_field_name = True

    # pre=True: runs on the raw constructor kwargs, before field defaults are
    # applied, so COZE_API_BASE is honoured when no base URL is passed.
    @root_validator(pre=True)
    def validate_environment(cls, values: Dict) -> Dict:
        """Resolve the API base URL and API key from kwargs or the environment.

        ``coze_api_base`` falls back to ``COZE_API_BASE`` and then to
        ``DEFAULT_API_BASE``. ``coze_api_key`` falls back to ``COZE_API_KEY``
        and is wrapped in a ``SecretStr``. ``get_from_dict_or_env`` raises if
        no key is found in either place.
        """
        values["coze_api_base"] = get_from_dict_or_env(
            values,
            "coze_api_base",
            "COZE_API_BASE",
            DEFAULT_API_BASE,
        )
        values["coze_api_key"] = convert_to_secret_str(
            get_from_dict_or_env(
                values,
                "coze_api_key",
                "COZE_API_KEY",
            )
        )
        return values

    @property
    def _default_params(self) -> Dict[str, Any]:
        """Get the default parameters for calling Coze API."""
        return {
            "bot_id": self.bot_id,
            "conversation_id": self.conversation_id,
            "user": self.user,
            "streaming": self.streaming,
        }

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        """Send ``messages`` to Coze and return the bot's answer(s).

        When ``self.streaming`` is set, the streamed chunks are collected into
        a single result. Otherwise one blocking request is made.

        Raises:
            ValueError: if the HTTP status is not 200, or the JSON body
                reports a non-zero ``code``.
        """
        if self.streaming:
            stream_iter = self._stream(
                messages=messages, stop=stop, run_manager=run_manager, **kwargs
            )
            return generate_from_stream(stream_iter)

        r = self._chat(messages, **kwargs)
        res = r.json()
        if res["code"] != 0:
            raise ValueError(
                f"Error from Coze api response: {res['code']}: {res['msg']}, "
                f"logid: {r.headers.get('X-Tt-Logid')}"
            )
        return self._create_chat_result(res.get("messages") or [])

    def _stream(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[ChatGenerationChunk]:
        """Stream the bot's answer as ``ChatGenerationChunk`` objects.

        The response is read as ``data:``-prefixed server-sent-event lines.
        Streaming stops at the ``done`` event. Only ``message`` events whose
        message type is ``answer`` are yielded.

        Raises:
            ValueError: on a non-200 HTTP status or an ``error`` stream event.
        """
        # Always request a streamed response here: with streaming=False the
        # server replies with one JSON document, which the SSE parsing below
        # cannot handle (this happens when .stream() is used on a
        # non-streaming model).
        res = self._chat(messages, **{**kwargs, "streaming": True})
        # Close the connection even if the consumer abandons the generator.
        with res:
            for raw_line in res.iter_lines():
                line = raw_line.decode("utf-8").strip("\r\n")
                _, sep, data = line.partition("data:")
                if not sep or not data.strip():
                    # Blank keep-alive lines, non-data SSE fields, or empty data.
                    continue
                response = json.loads(data)
                event = response.get("event")
                if event == "done":
                    break
                if event == "error":
                    raise ValueError(
                        f"Error from Coze api stream: {response}, "
                        f"logid: {res.headers.get('X-Tt-Logid')}"
                    )
                if event != "message" or response["message"]["type"] != "answer":
                    continue
                chunk = _convert_delta_to_message_chunk(response["message"])
                cg_chunk = ChatGenerationChunk(message=chunk)
                if run_manager:
                    run_manager.on_llm_new_token(chunk.content, chunk=cg_chunk)
                yield cg_chunk

    def _chat(self, messages: List[BaseMessage], **kwargs: Any) -> requests.Response:
        """POST a chat request to ``{coze_api_base}/open_api/v2/chat``.

        ``kwargs`` override the ``bot_id``, ``conversation_id``, ``user`` and
        ``streaming`` defaults. Other keys are accepted but not sent.

        The content of the last ``HumanMessage`` becomes ``query``. Every
        message is converted and sent as ``chat_history``.

        Raises:
            ValueError: if the HTTP status code is not 200.
        """
        parameters = {**self._default_params, **kwargs}

        query = ""
        chat_history = []
        for msg in messages:
            if isinstance(msg, HumanMessage):
                query = f"{msg.content}"  # overwrite, to get last user message as query
            # NOTE(review): the message used as `query` is also appended to
            # chat_history, so the API receives it twice. Confirm against the
            # Coze v2 chat_history semantics before changing.
            chat_history.append(_convert_message_to_dict(msg))

        streaming = parameters.pop("streaming")
        payload = {
            "conversation_id": parameters.pop("conversation_id"),
            "bot_id": parameters.pop("bot_id"),
            "user": parameters.pop("user"),
            "query": query,
            "stream": streaming,
        }
        if chat_history:
            payload["chat_history"] = chat_history

        url = self.coze_api_base + "/open_api/v2/chat"
        api_key = self.coze_api_key.get_secret_value() if self.coze_api_key else ""

        res = requests.post(
            url=url,
            timeout=self.request_timeout,
            headers={
                "Content-Type": "application/json",
                "Authorization": f"Bearer {api_key}",
            },
            json=payload,
            stream=streaming,
        )
        if res.status_code != 200:
            logid = res.headers.get("X-Tt-Logid")
            # Include status and body; str(res) alone is just "<Response [N]>".
            raise ValueError(
                f"Error from Coze api response: {res.status_code} {res.text}, "
                f"logid: {logid}"
            )
        return res

    def _create_chat_result(self, messages: List[Mapping[str, Any]]) -> ChatResult:
        """Build a ``ChatResult`` from the ``messages`` of a non-streaming reply.

        Entries that ``_convert_dict_to_message`` maps to ``None`` (anything
        other than type ``answer``) are dropped.
        """
        generations = [
            ChatGeneration(message=msg)
            for msg in map(_convert_dict_to_message, messages)
            if msg
        ]
        llm_output = {"token_usage": "", "model": ""}
        return ChatResult(generations=generations, llm_output=llm_output)

    @property
    def _llm_type(self) -> str:
        """Identifier of this chat model type."""
        return "coze-chat"