|
|
|
|
|
|
|
|
|
|
|
import asyncio |
|
import traceback |
|
import aiohttp |
|
import json |
|
import copy |
|
|
|
from typing import List, Any, AsyncGenerator |
|
from dataclasses import dataclass |
|
|
|
from cozepy import ChatEventType, Message, TokenAuth, AsyncCoze, ChatEvent, Chat |
|
|
|
from ten import ( |
|
AudioFrame, |
|
VideoFrame, |
|
AsyncTenEnv, |
|
Cmd, |
|
StatusCode, |
|
CmdResult, |
|
Data, |
|
) |
|
|
|
from ten_ai_base.config import BaseConfig |
|
from ten_ai_base.chat_memory import ChatMemory |
|
from ten_ai_base.types import ( |
|
LLMChatCompletionUserMessageParam, |
|
LLMCallCompletionArgs, |
|
LLMDataCompletionArgs, |
|
LLMToolMetadata, |
|
) |
|
from ten_ai_base.llm import ( |
|
AsyncLLMBaseExtension, |
|
) |
|
|
|
# Command names exchanged with the TEN runtime.
CMD_IN_FLUSH = "flush"  # inbound: abort any in-flight generation
CMD_IN_ON_USER_JOINED = "on_user_joined"  # inbound: a user entered the channel
CMD_IN_ON_USER_LEFT = "on_user_left"  # inbound: a user left the channel
CMD_OUT_FLUSH = "flush"  # outbound: propagate the flush downstream
CMD_OUT_TOOL_CALL = "tool_call"  # outbound tool invocation (unused in this file)

# Property names on inbound "text_data" packets (ASR results).
DATA_IN_TEXT_DATA_PROPERTY_IS_FINAL = "is_final"
DATA_IN_TEXT_DATA_PROPERTY_TEXT = "text"

# Property names on outbound "text_data" packets (LLM output segments).
DATA_OUT_TEXT_DATA_PROPERTY_TEXT = "text"
DATA_OUT_TEXT_DATA_PROPERTY_END_OF_SEGMENT = "end_of_segment"

# Property carrying a tool result on a cmd (unused in this file).
CMD_PROPERTY_RESULT = "tool_result"
|
|
|
|
|
def is_punctuation(char):
    """Return True if *char* is a sentence-terminating punctuation mark.

    Both the ASCII and the full-width (CJK) variants of comma, period,
    question mark and exclamation mark are recognized.
    """
    return char in (",", ",", ".", "。", "?", "?", "!", "!")
|
|
|
|
|
def parse_sentences(sentence_fragment, content):
    """Split streamed text into complete sentences at punctuation marks.

    *sentence_fragment* is the unfinished tail carried over from the
    previous chunk; *content* is the newly received text. Returns a tuple
    ``(sentences, remainder)`` where *sentences* is the list of completed
    sentences (each ending in punctuation and containing at least one
    alphanumeric character) and *remainder* is the leftover fragment to
    carry into the next call.
    """
    # ASCII and full-width sentence terminators (same set as is_punctuation).
    terminators = (",", ",", ".", "。", "?", "?", "!", "!")

    finished = []
    buffer = sentence_fragment
    for ch in content:
        buffer += ch
        if ch not in terminators:
            continue
        # Emit only sentences with real content; punctuation-only runs keep
        # accumulating in the buffer instead of producing empty sentences.
        if any(c.isalnum() for c in buffer):
            finished.append(buffer)
            buffer = ""

    return finished, buffer
|
|
|
|
|
@dataclass
class CozeConfig(BaseConfig):
    """Configuration for the Coze extension, loaded from TEN properties."""

    # Coze API endpoint. NOTE(review): the public Coze endpoints are
    # api.coze.com / api.coze.cn — confirm "acoze" is intentional.
    base_url: str = "https://api.acoze.com"
    bot_id: str = ""  # required: ID of the Coze bot to chat with
    token: str = ""  # required: access token sent as Bearer auth
    user_id: str = "TenAgent"  # user identifier reported to Coze
    greeting: str = ""  # optional text spoken when the first user joins
    max_history: int = 32  # max messages retained in ChatMemory
|
|
|
|
|
class AsyncCozeExtension(AsyncLLMBaseExtension):
    """TEN LLM extension that relays ASR text to the Coze chat API.

    Final ASR transcripts arrive via on_data, are queued through the
    AsyncLLMBaseExtension input queue, and on_data_chat_completion streams
    the bot's reply back out sentence-by-sentence as "text_data" packets.
    """

    config: CozeConfig = None  # populated in on_start from TEN properties
    # NOTE(review): this class attribute appears unused — on_data_chat_completion
    # keeps its own local `sentence_fragment` instead.
    sentence_fragment: str = ""
    ten_env: AsyncTenEnv = None  # saved env, used outside callback arguments
    loop: asyncio.AbstractEventLoop = None  # event loop captured in on_start
    stopped: bool = False  # set by on_stop; NOTE(review): never read in this file
    users_count = 0  # number of currently joined users
    memory: ChatMemory = None  # rolling chat history (config.max_history entries)

    # SDK client created in on_start. NOTE(review): _stream_chat calls the HTTP
    # endpoint directly with aiohttp, so this client is only used as a
    # configuration sanity check here.
    acoze: AsyncCoze = None

    async def on_init(self, ten_env: AsyncTenEnv) -> None:
        """Framework init hook; only logs."""
        await super().on_init(ten_env)
        ten_env.log_debug("on_init")

    async def on_start(self, ten_env: AsyncTenEnv) -> None:
        """Load config, build chat memory and the Coze SDK client.

        If bot_id or token is missing, the extension stays inert (acoze
        remains None and on_data_chat_completion reports the problem).
        """
        await super().on_start(ten_env)
        ten_env.log_debug("on_start")

        self.loop = asyncio.get_event_loop()

        self.config = await CozeConfig.create_async(ten_env=ten_env)
        ten_env.log_info(f"config: {self.config}")

        if not self.config.bot_id or not self.config.token:
            ten_env.log_error("Missing required configuration")
            return

        self.memory = ChatMemory(self.config.max_history)
        try:
            self.acoze = AsyncCoze(
                auth=TokenAuth(token=self.config.token), base_url=self.config.base_url
            )
        except Exception as e:
            ten_env.log_error(f"Failed to create conversation {e}")

        self.ten_env = ten_env

    async def on_stop(self, ten_env: AsyncTenEnv) -> None:
        """Framework stop hook; marks the extension stopped."""
        await super().on_stop(ten_env)
        ten_env.log_debug("on_stop")

        self.stopped = True

    async def on_deinit(self, ten_env: AsyncTenEnv) -> None:
        """Framework deinit hook; only logs."""
        await super().on_deinit(ten_env)
        ten_env.log_debug("on_deinit")

    async def on_cmd(self, ten_env: AsyncTenEnv, cmd: Cmd) -> None:
        """Handle flush and user join/leave commands.

        flush: drop queued input and forward the flush downstream.
        on_user_joined: greet the first user if a greeting is configured.
        Unknown commands are delegated to the base class (which returns its
        own result), so this method returns early in that branch.
        """
        cmd_name = cmd.get_name()
        ten_env.log_debug("on_cmd name {}".format(cmd_name))

        status = StatusCode.OK
        detail = "success"

        if cmd_name == CMD_IN_FLUSH:
            await self.flush_input_items(ten_env)
            await ten_env.send_cmd(Cmd.create(CMD_OUT_FLUSH))
            ten_env.log_info("on flush")
        elif cmd_name == CMD_IN_ON_USER_JOINED:
            self.users_count += 1

            # Greet only the very first user to join.
            if self.config.greeting and self.users_count == 1:
                self.send_text_output(ten_env, self.config.greeting, True)
        elif cmd_name == CMD_IN_ON_USER_LEFT:
            self.users_count -= 1
        else:
            await super().on_cmd(ten_env, cmd)
            return

        cmd_result = CmdResult.create(status)
        cmd_result.set_property_string("detail", detail)
        await ten_env.return_result(cmd_result, cmd)

    async def on_call_chat_completion(
        self, ten_env: AsyncTenEnv, **kargs: LLMCallCompletionArgs
    ) -> Any:
        """Synchronous-style completion is not supported by this extension."""
        raise RuntimeError("Not implemented")

    async def on_data_chat_completion(
        self, ten_env: AsyncTenEnv, **kargs: LLMDataCompletionArgs
    ) -> None:
        """Stream a chat completion for queued user message(s).

        Prepends chat memory, streams events from _stream_chat, forwards
        sentence-sized chunks via _send_text, and records the full reply in
        memory afterwards. If the SDK client was never created (bad config),
        a diagnostic message is sent instead.
        """
        if not self.acoze:
            await self._send_text(
                "Coze is not connected. Please check your configuration.", True
            )
            return

        input_messages: List[LLMChatCompletionUserMessageParam] = kargs.get("messages", [])
        # Copy so extending with the new input does not mutate memory's list.
        messages = copy.copy(self.memory.get())
        if not input_messages:
            ten_env.log_warn("No message in data")
        else:
            messages.extend(input_messages)
            for i in input_messages:
                self.memory.put(i)

        total_output = ""
        sentence_fragment = ""
        calls = {}  # NOTE(review): never populated; only appears in the final log

        sentences = []
        self.ten_env.log_info(f"messages: {messages}")
        response = self._stream_chat(messages=messages)
        async for message in response:
            self.ten_env.log_info(f"content: {message}")
            try:
                if message.event == ChatEventType.CONVERSATION_MESSAGE_DELTA:
                    total_output += message.message.content
                    sentences, sentence_fragment = parse_sentences(
                        sentence_fragment, message.message.content
                    )
                    for s in sentences:
                        await self._send_text(s, False)
                elif message.event == ChatEventType.CONVERSATION_MESSAGE_COMPLETED:
                    # Flush the trailing fragment and mark end of segment;
                    # an empty send still signals segment completion downstream.
                    if sentence_fragment:
                        await self._send_text(sentence_fragment, True)
                    else:
                        await self._send_text("", True)
                elif message.event == ChatEventType.CONVERSATION_CHAT_FAILED:
                    last_error = message.chat.last_error
                    # 4011 is reported by the "token depleted" message below;
                    # presumably a Coze quota error code — TODO confirm.
                    if last_error and last_error.code == 4011:
                        await self._send_text(
                            "The Coze token has been depleted. Please check your token usage.",
                            True,
                        )
                    else:
                        await self._send_text(last_error.msg, True)
            except Exception as e:
                self.ten_env.log_error(f"Failed to parse response: {message} {e}")
                traceback.print_exc()

        # Persist the assistant's full reply for future turns.
        self.memory.put({"role": "assistant", "content": total_output})
        self.ten_env.log_info(f"total_output: {total_output} {calls}")

    async def on_tools_update(
        self, ten_env: AsyncTenEnv, tool: LLMToolMetadata
    ) -> None:
        """Tool registration is delegated unchanged to the base class."""

        return await super().on_tools_update(ten_env, tool)

    async def on_data(self, ten_env: AsyncTenEnv, data: Data) -> None:
        """Receive ASR "text_data" packets and queue final transcripts.

        Both properties are optional reads: a missing property is logged and
        treated as absent rather than raising. Non-final or empty text is
        ignored; final text is wrapped as a user message and queued for
        on_data_chat_completion.
        """
        data_name = data.get_name()
        ten_env.log_info("on_data name {}".format(data_name))

        is_final = False
        input_text = ""
        try:
            is_final = data.get_property_bool(DATA_IN_TEXT_DATA_PROPERTY_IS_FINAL)
        except Exception as err:
            ten_env.log_info(
                f"GetProperty optional {DATA_IN_TEXT_DATA_PROPERTY_IS_FINAL} failed, err: {err}"
            )

        try:
            input_text = data.get_property_string(DATA_IN_TEXT_DATA_PROPERTY_TEXT)
        except Exception as err:
            ten_env.log_info(
                f"GetProperty optional {DATA_IN_TEXT_DATA_PROPERTY_TEXT} failed, err: {err}"
            )

        if not is_final:
            ten_env.log_info("ignore non-final input")
            return
        if not input_text:
            ten_env.log_info("ignore empty text")
            return

        ten_env.log_info(f"OnData input text: [{input_text}]")

        message = LLMChatCompletionUserMessageParam(role="user", content=input_text)
        await self.queue_input_item(False, messages=[message])

    async def on_audio_frame(
        self, ten_env: AsyncTenEnv, audio_frame: AudioFrame
    ) -> None:
        """Audio frames are not consumed by this extension."""
        pass

    async def on_video_frame(
        self, ten_env: AsyncTenEnv, video_frame: VideoFrame
    ) -> None:
        """Video frames are not consumed by this extension."""
        pass

    async def _send_text(self, text: str, end_of_segment: bool) -> None:
        """Emit *text* downstream as a "text_data" packet.

        The send is fire-and-forget via create_task, so ordering relative to
        subsequent sends is not strictly guaranteed — NOTE(review): confirm
        downstream tolerates out-of-order segments.
        """
        data = Data.create("text_data")
        data.set_property_string(DATA_OUT_TEXT_DATA_PROPERTY_TEXT, text)
        data.set_property_bool(
            DATA_OUT_TEXT_DATA_PROPERTY_END_OF_SEGMENT, end_of_segment
        )
        asyncio.create_task(self.ten_env.send_data(data))

    async def _stream_chat(
        self, messages: List[Any]
    ) -> AsyncGenerator[ChatEvent, None]:
        """Stream ChatEvent objects from Coze's /v3/chat SSE endpoint.

        *messages* are role/content dicts; user and assistant entries are
        converted to Coze Message payloads and sent as additional_messages.
        The response body is parsed line-by-line assuming "event:"/"data:"
        prefixed SSE lines; other lines are treated as a JSON error object.
        """
        additionals = []
        for m in messages:
            if m["role"] == "user":
                additionals.append(
                    Message.build_user_question_text(m["content"]).model_dump()
                )
            elif m["role"] == "assistant":
                additionals.append(
                    Message.build_assistant_answer(m["content"]).model_dump()
                )

        def chat_stream_handler(event: str, event_data: Any) -> ChatEvent:
            # Map one SSE (event, data) pair to a typed ChatEvent.
            # NOTE(review): raising StopAsyncIteration inside an async
            # generator is converted to RuntimeError (PEP 525), which the
            # outer except below logs; the "done" event is normally caught
            # earlier by the `event == "done"` break, so this path is rare.
            if event == ChatEventType.DONE:
                raise StopAsyncIteration
            elif event == ChatEventType.ERROR:
                raise RuntimeError(f"error event: {event_data}")
            elif event in [
                ChatEventType.CONVERSATION_MESSAGE_DELTA,
                ChatEventType.CONVERSATION_MESSAGE_COMPLETED,
            ]:
                return ChatEvent(
                    event=event, message=Message.model_validate_json(event_data)
                )
            elif event in [
                ChatEventType.CONVERSATION_CHAT_CREATED,
                ChatEventType.CONVERSATION_CHAT_IN_PROGRESS,
                ChatEventType.CONVERSATION_CHAT_COMPLETED,
                ChatEventType.CONVERSATION_CHAT_FAILED,
                ChatEventType.CONVERSATION_CHAT_REQUIRES_ACTION,
            ]:
                return ChatEvent(event=event, chat=Chat.model_validate_json(event_data))
            else:
                raise ValueError(f"invalid chat.event: {event}, {event_data}")

        async with aiohttp.ClientSession() as session:
            try:
                url = f"{self.config.base_url}/v3/chat"
                headers = {
                    "Authorization": f"Bearer {self.config.token}",
                }
                params = {
                    "bot_id": self.config.bot_id,
                    "user_id": self.config.user_id,
                    "additional_messages": additionals,
                    "stream": True,
                    "auto_save_history": True,

                }
                event = ""
                async with session.post(url, json=params, headers=headers) as response:
                    async for line in response.content:
                        if line:
                            try:
                                self.ten_env.log_info(f"line: {line}")
                                decoded_line = line.decode("utf-8").strip()
                                if decoded_line:
                                    if decoded_line.startswith("data:"):
                                        data = decoded_line[5:].strip()
                                        yield chat_stream_handler(
                                            event=event, event_data=data.strip()
                                        )
                                    elif decoded_line.startswith("event:"):
                                        # NOTE(review): the prefix remainder is
                                        # not stripped — a space after "event:"
                                        # would break the comparisons below.
                                        event = decoded_line[6:]
                                        self.ten_env.log_info(f"event: {event}")
                                        if event == "done":
                                            break
                                    else:
                                        # Non-SSE line: assume a JSON error body.
                                        result = json.loads(decoded_line)
                                        code = result.get("code", 0)
                                        # 4000 is treated as "bot not published"
                                        # per the message below — TODO confirm.
                                        if code == 4000:
                                            await self._send_text(
                                                "Coze bot is not published.", True
                                            )
                                        else:
                                            self.ten_env.log_error(
                                                f"Failed to stream chat: {result['code']}"
                                            )
                                            await self._send_text(
                                                "Coze bot is not connected. Please check your configuration.",
                                                True,
                                            )
                            except Exception as e:
                                self.ten_env.log_error(f"Failed to stream chat: {e}")
            except Exception as e:
                traceback.print_exc()
                self.ten_env.log_error(f"Failed to stream chat: {e}")
            finally:
                # Redundant with `async with`, but harmless.
                await session.close()
|
|