|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import base64 |
|
import json |
|
import threading |
|
import time |
|
import uuid |
|
from ten import ( |
|
AudioFrame, |
|
VideoFrame, |
|
Extension, |
|
TenEnv, |
|
Cmd, |
|
StatusCode, |
|
CmdResult, |
|
Data, |
|
) |
|
import asyncio |
|
|
|
# NOTE(review): MAX_SIZE and OVERHEAD_ESTIMATE are not referenced anywhere in
# the visible part of this file; presumably leftovers from an earlier chunk
# sizing scheme - confirm before removing.
MAX_SIZE = 800
OVERHEAD_ESTIMATE = 200

# NOTE(review): not referenced in the visible code; presumably the name of a
# "flush" command - confirm against the graph definition.
CMD_NAME_FLUSH = "flush"

# Property names read from incoming Data messages in on_data.
TEXT_DATA_TEXT_FIELD = "text"
TEXT_DATA_FINAL_FIELD = "is_final"
TEXT_DATA_STREAM_ID_FIELD = "stream_id"
TEXT_DATA_END_OF_SEGMENT_FIELD = "end_of_segment"
|
|
|
# Upper bound, in UTF-8 bytes, for one formatted chunk (header + payload).
MAX_CHUNK_SIZE_BYTES = 1024


def _text_to_base64_chunks(_: TenEnv, text: str, msg_id: str) -> list:
    """Base64-encode *text* and split it into self-describing chunks.

    Each returned string has the form
    ``<msg_id>|<part_index>|<total_parts>|<base64 payload>`` where
    ``part_index`` starts at 1. Concatenating the payloads in order yields
    the base64 encoding of the UTF-8 bytes of *text*.

    Args:
        _: Unused; kept so existing callers do not have to change.
        text: The text to encode and split.
        msg_id: Identifier placed at the start of every chunk (max 36 chars).

    Returns:
        The list of formatted chunks; empty when *text* is empty.

    Raises:
        ValueError: If *msg_id* is longer than 36 characters, or if the chunk
            header alone leaves no room for any payload.
    """
    if len(msg_id) > 36:
        raise ValueError("msg_id cannot exceed 36 characters.")

    base64_encoded = base64.b64encode(text.encode("utf-8")).decode("utf-8")
    total_length = len(base64_encoded)

    # Payloads are collected first; headers are written once the total number
    # of parts is known. This avoids patching a placeholder afterwards, which
    # would also rewrite a "???" that happens to be part of msg_id.
    payloads = []
    current_position = 0

    while current_position < total_length:
        part_index = len(payloads) + 1

        # The header is sized with a three-character stand-in for the total
        # part count, so a final count of 1000 or more may push a chunk
        # slightly past MAX_CHUNK_SIZE_BYTES.
        header_size = len(f"{msg_id}|{part_index}|???|".encode("utf-8"))

        # Start from the full budget and shrink in steps of 100 characters
        # until header + payload fit. Base64 text is ASCII, so the payload's
        # character count equals its byte count.
        estimated_chunk_size = MAX_CHUNK_SIZE_BYTES
        while True:
            if estimated_chunk_size <= 0:
                raise ValueError("chunk header leaves no room for payload.")
            content_chunk = base64_encoded[
                current_position : current_position + estimated_chunk_size
            ]
            if header_size + len(content_chunk) <= MAX_CHUNK_SIZE_BYTES:
                break
            estimated_chunk_size -= 100

        payloads.append(content_chunk)
        current_position += estimated_chunk_size

    total_parts = len(payloads)
    return [
        f"{msg_id}|{index}|{total_parts}|{payload}"
        for index, payload in enumerate(payloads, start=1)
    ]
|
|
|
|
|
class MessageCollectorExtension(Extension):
    """Forward incoming text messages as chunked, base64-encoded "data" messages.

    Incoming "text_data" and "content_data" messages are wrapped in a JSON
    envelope, split into chunks by ``_text_to_base64_chunks`` and queued.
    A worker coroutine running on a private event loop (in its own thread)
    drains the queue and sends each chunk through ``ten_env.send_data``.
    """

    # Pause between two outgoing chunks, in seconds.
    _SEND_INTERVAL_S = 0.04
    # Longest time on_stop waits for the worker thread to exit, in seconds.
    _STOP_JOIN_TIMEOUT_S = 1.0

    def __init__(self, name: str):
        super().__init__(name)
        self.queue = asyncio.Queue()
        self.loop = None
        # stream_id -> text accumulated from final results of an open segment.
        self.cached_text_map = {}
        self._loop_thread = None
        self._stopping = False

    def on_init(self, ten_env: TenEnv) -> None:
        ten_env.log_info("on_init")
        ten_env.on_init_done()

    def on_start(self, ten_env: TenEnv) -> None:
        """Start the private event loop thread and the queue worker."""
        ten_env.log_info("on_start")

        loop = asyncio.new_event_loop()
        self.loop = loop
        self._stopping = False

        def start_loop():
            asyncio.set_event_loop(loop)
            loop.run_forever()

        # Daemon thread: a missed on_stop must not keep the process alive.
        self._loop_thread = threading.Thread(target=start_loop, daemon=True)
        self._loop_thread.start()

        # The loop runs in another thread, so the worker has to be scheduled
        # through the thread-safe API rather than loop.create_task().
        worker = asyncio.run_coroutine_threadsafe(
            self._process_queue(ten_env), loop
        )
        # Once the worker exits, stop the loop so its thread can finish.
        worker.add_done_callback(
            lambda _future: loop.call_soon_threadsafe(loop.stop)
        )

        ten_env.on_start_done()

    def on_stop(self, ten_env: TenEnv) -> None:
        """Stop the queue worker and its event loop thread.

        Chunks still waiting in the queue are dropped. The wait for the
        worker thread is bounded by ``_STOP_JOIN_TIMEOUT_S``.
        """
        ten_env.log_info("on_stop")

        loop, thread = self.loop, self._loop_thread
        self._stopping = True
        self.loop = None
        self._loop_thread = None

        if loop is not None and thread is not None and thread.is_alive():
            try:
                # Wake the worker with the sentinel it already understands.
                loop.call_soon_threadsafe(self.queue.put_nowait, None)
            except RuntimeError as e:
                ten_env.log_warn(f"on_stop failed to signal worker: {e}")

            thread.join(timeout=self._STOP_JOIN_TIMEOUT_S)
            if thread.is_alive():
                ten_env.log_warn("on_stop worker thread did not exit in time")
            else:
                loop.close()

        ten_env.on_stop_done()

    def on_deinit(self, ten_env: TenEnv) -> None:
        ten_env.log_info("on_deinit")
        ten_env.on_deinit_done()

    def on_cmd(self, ten_env: TenEnv, cmd: Cmd) -> None:
        """Acknowledge every command with an OK result."""
        cmd_name = cmd.get_name()
        ten_env.log_info("on_cmd name {}".format(cmd_name))

        cmd_result = CmdResult.create(StatusCode.OK)
        ten_env.return_result(cmd_result, cmd)

    def on_data(self, ten_env: TenEnv, data: Data) -> None:
        """
        on_data receives data from ten graph.
        current supported data:
          - name: text_data
            example:
                {"name": "text_data", "properties": {"text": "hello", "is_final": true, "stream_id": 123, "end_of_segment": true}}
          - name: content_data
            reads the "text" and "end_of_segment" properties
        Data with any other name is ignored.
        """
        name = data.get_name()
        if name == "text_data":
            self._on_text_data(ten_env, data)
        elif name == "content_data":
            self._on_content_data(ten_env, data)

    def _on_text_data(self, ten_env: TenEnv, data: Data) -> None:
        """Handle a "text_data" message: merge cached text and enqueue it."""
        text = ""
        final = True
        stream_id = 0
        end_of_segment = False

        try:
            text = data.get_property_string(TEXT_DATA_TEXT_FIELD)
        except Exception as e:
            ten_env.log_error(
                f"on_data get_property_string {TEXT_DATA_TEXT_FIELD} error: {e}"
            )

        # "is_final" and "stream_id" are optional: keep the defaults silently.
        try:
            final = data.get_property_bool(TEXT_DATA_FINAL_FIELD)
        except Exception:
            pass

        try:
            stream_id = data.get_property_int(TEXT_DATA_STREAM_ID_FIELD)
        except Exception:
            pass

        try:
            end_of_segment = data.get_property_bool(TEXT_DATA_END_OF_SEGMENT_FIELD)
        except Exception as e:
            ten_env.log_warn(
                f"on_data get_property_bool {TEXT_DATA_END_OF_SEGMENT_FIELD} error: {e}"
            )

        ten_env.log_info(
            f"on_data {TEXT_DATA_TEXT_FIELD}: {text} {TEXT_DATA_FINAL_FIELD}: {final} {TEXT_DATA_STREAM_ID_FIELD}: {stream_id} {TEXT_DATA_END_OF_SEGMENT_FIELD}: {end_of_segment}"
        )

        if end_of_segment:
            # Segment closed: prepend what was cached and forget it.
            text = self.cached_text_map.pop(stream_id, "") + text
        elif final:
            # Final result inside an open segment: accumulate it.
            text = self.cached_text_map.get(stream_id, "") + text
            self.cached_text_map[stream_id] = text

        self._enqueue_message(ten_env, "transcribe", text, stream_id, end_of_segment)

    def _on_content_data(self, ten_env: TenEnv, data: Data) -> None:
        """Handle a "content_data" message: enqueue its text as raw content."""
        text = ""
        end_of_segment = False

        try:
            text = data.get_property_string(TEXT_DATA_TEXT_FIELD)
        except Exception as e:
            ten_env.log_error(
                f"on_data get_property_string {TEXT_DATA_TEXT_FIELD} error: {e}"
            )

        try:
            end_of_segment = data.get_property_bool(TEXT_DATA_END_OF_SEGMENT_FIELD)
        except Exception as e:
            ten_env.log_warn(
                f"on_data get_property_bool {TEXT_DATA_END_OF_SEGMENT_FIELD} error: {e}"
            )

        ten_env.log_info(f"on_data {TEXT_DATA_TEXT_FIELD}: {text}")

        # content_data does not read a stream id; 0 is always reported.
        self._enqueue_message(ten_env, "raw", text, 0, end_of_segment)

    def _enqueue_message(
        self,
        ten_env: TenEnv,
        data_type: str,
        text: str,
        stream_id: int,
        end_of_segment: bool,
    ) -> None:
        """Wrap *text* in the JSON envelope, chunk it and queue every chunk."""
        message_id = str(uuid.uuid4())[:8]

        base_msg_data = {
            "is_final": end_of_segment,
            "stream_id": stream_id,
            "message_id": message_id,
            "data_type": data_type,
            "text_ts": int(time.time() * 1000),
            "text": text,
        }

        try:
            loop = self.loop
            if loop is None:
                # Checked before any coroutine object is created, so nothing
                # is left un-awaited when the extension is not running.
                raise RuntimeError("event loop is not running")

            chunks = _text_to_base64_chunks(
                ten_env, json.dumps(base_msg_data), message_id
            )
            for chunk in chunks:
                asyncio.run_coroutine_threadsafe(self._queue_message(chunk), loop)
        except Exception as e:
            ten_env.log_warn(f"on_data new_data error: {e}")

    def on_audio_frame(self, ten_env: TenEnv, audio_frame: AudioFrame) -> None:
        pass

    def on_video_frame(self, ten_env: TenEnv, video_frame: VideoFrame) -> None:
        pass

    async def _queue_message(self, data: str):
        """Put one chunk on the outgoing queue (runs on the private loop)."""
        await self.queue.put(data)

    async def _process_queue(self, ten_env: TenEnv):
        """Send queued chunks one by one until stopped.

        The worker exits when it receives the ``None`` sentinel or when
        on_stop has set the stopping flag.
        """
        while True:
            data = await self.queue.get()
            if data is None or self._stopping:
                break

            try:
                ten_data = Data.create("data")
                ten_data.set_property_buf("data", data.encode())
                ten_env.send_data(ten_data)
            except Exception as e:
                # One failed send must not end the worker for good.
                ten_env.log_error(f"_process_queue send_data error: {e}")
            finally:
                self.queue.task_done()

            await asyncio.sleep(self._SEND_INTERVAL_S)
|
|