import asyncio
import json
import time
import traceback
import uuid
from typing import Any, Iterable

from ten.async_ten_env import AsyncTenEnv
from ten_ai_base.const import (
    CMD_PROPERTY_RESULT,
    CMD_TOOL_CALL,
    CONTENT_DATA_OUT_NAME,
    DATA_OUT_PROPERTY_END_OF_SEGMENT,
    DATA_OUT_PROPERTY_TEXT,
)
from ten_ai_base.helper import (
    AsyncEventEmitter,
    get_property_bool,
    get_property_string,
)
from ten_ai_base.types import (
    LLMCallCompletionArgs,
    LLMChatCompletionContentPartParam,
    LLMChatCompletionUserMessageParam,
    LLMChatCompletionMessageParam,
    LLMDataCompletionArgs,
    LLMToolMetadata,
    LLMToolResult,
)
from ten_ai_base.llm import AsyncLLMBaseExtension

from .helper import parse_sentences
from .openai import OpenAIChatGPT, OpenAIChatGPTConfig
from ten import (
    Cmd,
    StatusCode,
    CmdResult,
    Data,
)

CMD_IN_FLUSH = "flush"
CMD_IN_ON_USER_JOINED = "on_user_joined"
CMD_IN_ON_USER_LEFT = "on_user_left"
CMD_OUT_FLUSH = "flush"
DATA_IN_TEXT_DATA_PROPERTY_TEXT = "text"
DATA_IN_TEXT_DATA_PROPERTY_IS_FINAL = "is_final"
DATA_OUT_TEXT_DATA_PROPERTY_TEXT = "text"
DATA_OUT_TEXT_DATA_PROPERTY_TEXT_END_OF_SEGMENT = "end_of_segment"


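# Bridges the TEN graph and the OpenAI chat client: final ASR transcripts
# arrive as data frames, completions stream back out as sentence-sized text
# frames, and registered tools are invoked through tool-call commands.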
class OpenAIChatGPTExtension(AsyncLLMBaseExtension):
    def __init__(self, name: str):
        super().__init__(name)
        self.memory = []
        self.memory_cache = []
        self.config = None
        self.client = None
        self.sentence_fragment = ""
        self.tool_task_future: asyncio.Future | None = None
        self.users_count = 0
        self.last_reasoning_ts = 0

    async def on_init(self, async_ten_env: AsyncTenEnv) -> None:
        async_ten_env.log_info("on_init")
        await super().on_init(async_ten_env)

    async def on_start(self, async_ten_env: AsyncTenEnv) -> None:
        async_ten_env.log_info("on_start")
        await super().on_start(async_ten_env)

        self.config = await OpenAIChatGPTConfig.create_async(ten_env=async_ten_env)

        if not self.config.api_key:
            async_ten_env.log_warn("API key is missing, exiting on_start")
            return

        try:
            self.client = OpenAIChatGPT(async_ten_env, self.config)
            async_ten_env.log_info(
                f"initialized with max_tokens: {self.config.max_tokens}, model: {self.config.model}, vendor: {self.config.vendor}"
            )
        except Exception as err:
            async_ten_env.log_error(f"Failed to initialize OpenAIChatGPT: {err}")

    async def on_stop(self, async_ten_env: AsyncTenEnv) -> None:
        async_ten_env.log_info("on_stop")
        await super().on_stop(async_ten_env)

    async def on_deinit(self, async_ten_env: AsyncTenEnv) -> None:
        async_ten_env.log_info("on_deinit")
        await super().on_deinit(async_ten_env)

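    # Control-plane commands: "flush" cancels any in-flight generation and is
    # propagated downstream; user join/leave events maintain users_count so the
    # greeting is sent only when the first user joins.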
    async def on_cmd(self, async_ten_env: AsyncTenEnv, cmd: Cmd) -> None:
        cmd_name = cmd.get_name()
        async_ten_env.log_info(f"on_cmd name: {cmd_name}")

        if cmd_name == CMD_IN_FLUSH:
            await self.flush_input_items(async_ten_env)
            await async_ten_env.send_cmd(Cmd.create(CMD_OUT_FLUSH))
            async_ten_env.log_info("on_cmd sent flush")
        elif cmd_name == CMD_IN_ON_USER_JOINED:
            self.users_count += 1
            # Greet only the first user to join.
            if self.config.greeting and self.users_count == 1:
                self.send_text_output(async_ten_env, self.config.greeting, True)
        elif cmd_name == CMD_IN_ON_USER_LEFT:
            self.users_count -= 1
        else:
            await super().on_cmd(async_ten_env, cmd)
            return

        cmd_result = CmdResult.create(StatusCode.OK)
        cmd_result.set_property_string("detail", "success")
        await async_ten_env.return_result(cmd_result, cmd)

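    # Only final ASR transcripts are forwarded to the LLM input queue; interim
    # (non-final) and empty results are dropped.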
    async def on_data(self, async_ten_env: AsyncTenEnv, data: Data) -> None:
        data_name = data.get_name()
        async_ten_env.log_debug(f"on_data name {data_name}")

        is_final = get_property_bool(data, DATA_IN_TEXT_DATA_PROPERTY_IS_FINAL)
        input_text = get_property_string(data, DATA_IN_TEXT_DATA_PROPERTY_TEXT)

        if not is_final:
            async_ten_env.log_debug("ignore non-final input")
            return
        if not input_text:
            async_ten_env.log_warn("ignore empty text")
            return

        async_ten_env.log_info(f"OnData input text: [{input_text}]")

        message = LLMChatCompletionUserMessageParam(role="user", content=input_text)
        await self.queue_input_item(False, messages=[message])

    async def on_tools_update(
        self, async_ten_env: AsyncTenEnv, tool: LLMToolMetadata
    ) -> None:
        return await super().on_tools_update(async_ten_env, tool)

    async def on_call_chat_completion(
        self, async_ten_env: AsyncTenEnv, **kargs: LLMCallCompletionArgs
    ) -> Any:
        kmessages: Iterable[LLMChatCompletionMessageParam] = kargs.get("messages", [])

        async_ten_env.log_info(f"on_call_chat_completion: {kmessages}")
        response = await self.client.get_chat_completions(kmessages, None)
        return response.to_json()

    async def on_data_chat_completion(
        self, async_ten_env: AsyncTenEnv, **kargs: LLMDataCompletionArgs
    ) -> None:
        """Run the chat flow asynchronously."""
        kmessages: Iterable[LLMChatCompletionUserMessageParam] = kargs.get(
            "messages", []
        )

        if not kmessages:
            async_ten_env.log_error("No message in data")
            return

        messages = [self.message_to_dict(m) for m in kmessages]

        self.memory_cache = []
        memory = self.memory
        try:
            async_ten_env.log_info(f"for input text: [{messages}] memory: {memory}")
            no_tool = kargs.get("no_tool", False)

            # Cache this turn locally; for user messages with structured
            # content, keep only the text parts so non-text artifacts (e.g.
            # image data) are not replayed into memory.
            for message in messages:
                if (
                    not isinstance(message.get("content"), str)
                    and message.get("role") == "user"
                ):
                    non_artifact_content = [
                        item
                        for item in message.get("content", [])
                        if item.get("type") == "text"
                    ]
                    self.memory_cache.append(
                        {"role": message.get("role"), "content": non_artifact_content}
                    )
                else:
                    self.memory_cache.append(message)
            # Placeholder for the assistant reply; filled in as the stream arrives.
            self.memory_cache.append({"role": "assistant", "content": ""})

            tools = None
            if not no_tool and len(self.available_tools) > 0:
                tools = []
                for tool in self.available_tools:
                    tools.append(self._convert_tools_to_dict(tool))
                    async_ten_env.log_info(f"tool: {tool}")

            self.sentence_fragment = ""
            content_finished_event = asyncio.Event()
            self.tool_task_future = None

            # Short per-turn id used to correlate reasoning output frames.
            message_id = str(uuid.uuid4())[:8]
            self.last_reasoning_ts = int(time.time() * 1000)

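            # The streaming client fires these nested callbacks via the
            # AsyncEventEmitter below: "tool_call" for function-call requests,
            # "content_update"/"content_finished" for assistant text, and the
            # "reasoning_update*" events for reasoning deltas (throttled to at
            # most one update per 200 ms).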
            async def handle_tool_call(tool_call):
                self.tool_task_future = asyncio.get_event_loop().create_future()
                async_ten_env.log_info(f"tool_call: {tool_call}")
                for tool in self.available_tools:
                    if tool_call["function"]["name"] == tool.name:
                        cmd: Cmd = Cmd.create(CMD_TOOL_CALL)
                        cmd.set_property_string("name", tool.name)
                        cmd.set_property_from_json(
                            "arguments", tool_call["function"]["arguments"]
                        )

                        # Dispatch the tool call to the extension that
                        # registered it and wait for its result.
                        result, _ = await async_ten_env.send_cmd(cmd)
                        if result.get_status_code() == StatusCode.OK:
                            tool_result: LLMToolResult = json.loads(
                                result.get_property_to_json(CMD_PROPERTY_RESULT)
                            )

                            async_ten_env.log_info(f"tool_result: {tool_result}")

                            if tool_result["type"] == "llmresult":
                                result_content = tool_result["content"]
                                if isinstance(result_content, str):
                                    tool_message = {
                                        "role": "assistant",
                                        "tool_calls": [tool_call],
                                    }
                                    new_message = {
                                        "role": "tool",
                                        "content": result_content,
                                        "tool_call_id": tool_call["id"],
                                    }
                                    await self.queue_input_item(
                                        True,
                                        messages=[tool_message, new_message],
                                        no_tool=True,
                                    )
                                else:
                                    async_ten_env.log_error(
                                        f"Unknown tool result content: {result_content}"
                                    )
                            elif tool_result["type"] == "requery":
                                # Drop the assistant placeholder and requery
                                # with the tool output appended to the user
                                # content.
                                self.memory_cache.pop()
                                result_content = tool_result["content"]
                                nonlocal message
                                new_message = {
                                    "role": "user",
                                    "content": self._convert_to_content_parts(
                                        message["content"]
                                    ),
                                }
                                new_message["content"] = new_message[
                                    "content"
                                ] + self._convert_to_content_parts(result_content)
                                await self.queue_input_item(
                                    True, messages=[new_message], no_tool=True
                                )
                            else:
                                async_ten_env.log_error(
                                    f"Unknown tool result type: {tool_result}"
                                )
                        else:
                            async_ten_env.log_error("Tool call failed")
                # Resolve the future so handle_content_finished can proceed.
                self.tool_task_future.set_result(None)

            async def handle_content_update(content: str):
                # Append the delta to the cached assistant reply.
                for item in reversed(self.memory_cache):
                    if item.get("role") == "assistant":
                        item["content"] = item["content"] + content
                        break
                # Emit complete sentences downstream; keep the trailing fragment.
                sentences, self.sentence_fragment = parse_sentences(
                    self.sentence_fragment, content
                )
                for s in sentences:
                    self.send_text_output(async_ten_env, s, False)

            async def handle_reasoning_update(think: str):
                ts = int(time.time() * 1000)
                if ts - self.last_reasoning_ts >= 200:
                    self.last_reasoning_ts = ts
                    self.send_reasoning_text_output(
                        async_ten_env, message_id, think, False
                    )

            async def handle_reasoning_update_finish(think: str):
                self.last_reasoning_ts = int(time.time() * 1000)
                self.send_reasoning_text_output(async_ten_env, message_id, think, True)

            async def handle_content_finished(_: str):
                # Wait for any in-flight tool task before closing the turn.
                if self.tool_task_future:
                    await self.tool_task_future
                content_finished_event.set()

            listener = AsyncEventEmitter()
            listener.on("tool_call", handle_tool_call)
            listener.on("content_update", handle_content_update)
            listener.on("reasoning_update", handle_reasoning_update)
            listener.on("reasoning_update_finish", handle_reasoning_update_finish)
            listener.on("content_finished", handle_content_finished)

            await self.client.get_chat_completions_stream(
                memory + messages, tools, listener
            )

            # Block until the stream (and any tool round-trip) has completed.
            await content_finished_event.wait()

            async_ten_env.log_info(
                f"Chat completion finished for input text: {messages}"
            )
        except asyncio.CancelledError:
            async_ten_env.log_info(f"Task cancelled: {messages}")
        except Exception:
            async_ten_env.log_error(
                f"Error in chat_completion: {traceback.format_exc()} for input text: {messages}"
            )
        finally:
            # Always emit an end-of-segment marker so downstream consumers can
            # close out the current turn.
            self.send_text_output(async_ten_env, "", True)

        # Commit the cached turn to long-term memory.
        for m in self.memory_cache:
            self._append_memory(m)

    def _convert_to_content_parts(
        self, content: Iterable[LLMChatCompletionContentPartParam]
    ):
        content_parts = []

        if isinstance(content, str):
            content_parts.append({"type": "text", "text": content})
        else:
            for part in content:
                content_parts.append(part)
        return content_parts

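    # _convert_tools_to_dict maps LLMToolMetadata onto the OpenAI
    # function-calling tool schema. For a hypothetical "get_weather" tool with
    # a single required string parameter, the result would look like:
    #
    #   {
    #       "type": "function",
    #       "function": {
    #           "name": "get_weather",
    #           "description": "Look up current weather",
    #           "parameters": {
    #               "type": "object",
    #               "properties": {
    #                   "location": {"type": "string", "description": "City name"},
    #               },
    #               "required": ["location"],
    #               "additionalProperties": False,
    #           },
    #       },
    #       "strict": True,
    #   }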
    def _convert_tools_to_dict(self, tool: LLMToolMetadata):
        json_dict = {
            "type": "function",
            "function": {
                "name": tool.name,
                "description": tool.description,
                "parameters": {
                    "type": "object",
                    "properties": {},
                    "required": [],
                    "additionalProperties": False,
                },
            },
            "strict": True,
        }

        for param in tool.parameters:
            json_dict["function"]["parameters"]["properties"][param.name] = {
                "type": param.type,
                "description": param.description,
            }
            if param.required:
                json_dict["function"]["parameters"]["required"].append(param.name)

        return json_dict

    def message_to_dict(self, message: LLMChatCompletionMessageParam):
        # Normalize content: keep strings as-is, materialize any iterable of
        # content parts into a plain list.
        if message.get("content") is not None and not isinstance(
            message["content"], str
        ):
            message["content"] = list(message["content"])
        return message

    def _append_memory(self, message: dict):
        if len(self.memory) > self.config.max_memory_length:
            removed_item = self.memory.pop(0)

            # Don't leave an orphaned tool response at the head of memory: if
            # the evicted message carried tool_calls, evict its paired "tool"
            # message as well.
            if removed_item.get("tool_calls") and self.memory[0].get("role") == "tool":
                self.memory.pop(0)
        self.memory.append(message)

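    # Reasoning output goes over the generic content-data channel as a JSON
    # envelope: {"id": <message id>, "data": {"text": <delta>}, "type": "reasoning"}.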
    def send_reasoning_text_output(
        self,
        async_ten_env: AsyncTenEnv,
        msg_id: str,
        sentence: str,
        end_of_segment: bool,
    ):
        try:
            output_data = Data.create(CONTENT_DATA_OUT_NAME)
            output_data.set_property_string(
                DATA_OUT_PROPERTY_TEXT,
                json.dumps(
                    {
                        "id": msg_id,
                        "data": {"text": sentence},
                        "type": "reasoning",
                    }
                ),
            )
            output_data.set_property_bool(
                DATA_OUT_PROPERTY_END_OF_SEGMENT, end_of_segment
            )
            asyncio.create_task(async_ten_env.send_data(output_data))
        except Exception:
            async_ten_env.log_warn(
                f"send sentence [{sentence}] failed, err: {traceback.format_exc()}"
            )