#
#
# Agora Real Time Engagement
# Created by Wei Hu in 2024-08.
# Copyright (c) 2024 Agora IO. All rights reserved.
#
#
import asyncio
import base64
import json
import threading
import time
import uuid

from ten import (
    AudioFrame,
    VideoFrame,
    Extension,
    TenEnv,
    Cmd,
    StatusCode,
    CmdResult,
    Data,
)

MAX_SIZE = 800  # Soft limit for the content portion, leaving headroom under the 1 KB cap
OVERHEAD_ESTIMATE = 200  # Estimated overhead of the metadata in the JSON envelope

CMD_NAME_FLUSH = "flush"

TEXT_DATA_TEXT_FIELD = "text"
TEXT_DATA_FINAL_FIELD = "is_final"
TEXT_DATA_STREAM_ID_FIELD = "stream_id"
TEXT_DATA_END_OF_SEGMENT_FIELD = "end_of_segment"

MAX_CHUNK_SIZE_BYTES = 1024
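

# Each chunk sent downstream is a UTF-8 string of the form
#   <msg_id>|<part_index>|<total_parts>|<base64_content>
# and is kept within MAX_CHUNK_SIZE_BYTES so that a client can reassemble
# the original JSON payload from its parts.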
def _text_to_base64_chunks(_: TenEnv, text: str, msg_id: str) -> list:
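    """
    Split text into base64 chunks whose formatted size, including the
    "msg_id|part_index|total_parts|content" header, stays within
    MAX_CHUNK_SIZE_BYTES.
    """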
    # Ensure msg_id does not exceed 36 characters (the length of a full UUID).
    if len(msg_id) > 36:
        raise ValueError("msg_id cannot exceed 36 characters.")

    # Encode the UTF-8 text as base64.
    byte_array = bytearray(text, "utf-8")
    base64_encoded = base64.b64encode(byte_array).decode("utf-8")

    # Split the base64 string dynamically based on the final byte size of
    # each formatted chunk. total_parts is unknown until all chunks exist,
    # so a "???" placeholder is used and patched afterwards.
    chunks = []
    part_index = 0
    total_parts = None

    current_position = 0
    total_length = len(base64_encoded)

    while current_position < total_length:
        part_index += 1
        # Start with the maximum content size and shrink until the
        # formatted chunk fits.
        estimated_chunk_size = MAX_CHUNK_SIZE_BYTES
        content_chunk = ""
        count = 0
        while True:
            # Take a slice of the base64 content and format the chunk.
            content_chunk = base64_encoded[
                current_position : current_position + estimated_chunk_size
            ]
            formatted_chunk = f"{msg_id}|{part_index}|{total_parts if total_parts else '???'}|{content_chunk}"

            # Stop once the byte length of the formatted chunk fits the cap.
            if len(bytearray(formatted_chunk, "utf-8")) <= MAX_CHUNK_SIZE_BYTES:
                break
            # Otherwise reduce the content size gradually and retry.
            estimated_chunk_size -= 100
            count += 1
            # ten_env.log_debug(f"chunk estimate guess: {count}")

        chunks.append(formatted_chunk)
        # Advance past the content consumed by this chunk.
        current_position += estimated_chunk_size

    # Now that the total number of parts is known, patch the placeholder.
    # msg_id here is a UUID fragment and "?" never occurs in base64 output,
    # so the replacement cannot corrupt the content.
    total_parts = len(chunks)
    updated_chunks = [chunk.replace("???", str(total_parts)) for chunk in chunks]
    return updated_chunks
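
# Example with hypothetical values: a payload split into two parts for
# msg_id "abc12345" is emitted as
#   "abc12345|1|2|eyJpc19maW5hbCI6..."
#   "abc12345|2|2|...ifQ=="
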
class MessageCollectorExtension(Extension):
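    """
    Extension that collects text from the graph, wraps it in a JSON
    envelope, splits the envelope into size-bounded base64 chunks, and
    paces the chunks out through an asyncio queue.
    """
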
    def __init__(self, name: str):
        super().__init__(name)
        self.queue = asyncio.Queue()  # Outbound chunks, drained by _process_queue
        self.loop = None  # Event loop running on a dedicated background thread
        self.cached_text_map = {}  # stream_id -> text accumulated for the open segment

    def on_init(self, ten_env: TenEnv) -> None:
        ten_env.log_info("on_init")
        ten_env.on_init_done()

    def on_start(self, ten_env: TenEnv) -> None:
        ten_env.log_info("on_start")
        # TODO: read properties, initialize resources

        # Run a private event loop on a daemon thread so queued chunks can
        # be sent without blocking the extension callbacks.
        self.loop = asyncio.new_event_loop()

        def start_loop():
            asyncio.set_event_loop(self.loop)
            self.loop.run_forever()

        threading.Thread(target=start_loop, daemon=True).start()

        # Schedule the consumer on the loop thread-safely, since the loop
        # runs on a different thread than this callback.
        asyncio.run_coroutine_threadsafe(self._process_queue(ten_env), self.loop)

        ten_env.on_start_done()

    def on_stop(self, ten_env: TenEnv) -> None:
        ten_env.log_info("on_stop")
        # Stop the background event loop so its thread can exit.
        if self.loop is not None:
            self.loop.call_soon_threadsafe(self.loop.stop)
        ten_env.on_stop_done()

    def on_deinit(self, ten_env: TenEnv) -> None:
        ten_env.log_info("on_deinit")
        ten_env.on_deinit_done()

    def on_cmd(self, ten_env: TenEnv, cmd: Cmd) -> None:
        cmd_name = cmd.get_name()
        ten_env.log_info(f"on_cmd name {cmd_name}")
        # TODO: process cmd
        cmd_result = CmdResult.create(StatusCode.OK)
        ten_env.return_result(cmd_result, cmd)

    def on_data(self, ten_env: TenEnv, data: Data) -> None:
        """
        on_data receives data from the ten graph.
        Currently supported data:
          - name: text_data
            example:
            {"name": "text_data", "properties": {"text": "hello", "is_final": true, "stream_id": 123, "end_of_segment": true}}
        """
        # ten_env.log_debug("on_data")
        text = ""
        final = True
        stream_id = 0
        end_of_segment = False

        if data.get_name() == "text_data":
            try:
                text = data.get_property_string(TEXT_DATA_TEXT_FIELD)
            except Exception as e:
                ten_env.log_error(
                    f"on_data get_property_string {TEXT_DATA_TEXT_FIELD} error: {e}"
                )

            try:
                final = data.get_property_bool(TEXT_DATA_FINAL_FIELD)
            except Exception:
                pass

            try:
                stream_id = data.get_property_int(TEXT_DATA_STREAM_ID_FIELD)
            except Exception:
                pass

            try:
                end_of_segment = data.get_property_bool(TEXT_DATA_END_OF_SEGMENT_FIELD)
            except Exception as e:
                ten_env.log_warn(
                    f"on_data get_property_bool {TEXT_DATA_END_OF_SEGMENT_FIELD} error: {e}"
                )

            ten_env.log_info(
                f"on_data {TEXT_DATA_TEXT_FIELD}: {text} {TEXT_DATA_FINAL_FIELD}: {final} {TEXT_DATA_STREAM_ID_FIELD}: {stream_id} {TEXT_DATA_END_OF_SEGMENT_FIELD}: {end_of_segment}"
            )

            # Cache final text per stream until the end of the segment, then
            # flush the accumulated text; non-final text passes through
            # without being accumulated.
            if end_of_segment:
                if stream_id in self.cached_text_map:
                    text = self.cached_text_map[stream_id] + text
                    del self.cached_text_map[stream_id]
            else:
                if final:
                    if stream_id in self.cached_text_map:
                        text = self.cached_text_map[stream_id] + text
                    self.cached_text_map[stream_id] = text

            # Generate a unique message ID for this batch of parts.
            message_id = str(uuid.uuid4())[:8]

            # Wrap the text in the JSON envelope shared by all parts.
            base_msg_data = {
                "is_final": end_of_segment,
                "stream_id": stream_id,
                "message_id": message_id,  # Identifies the split message
                "data_type": "transcribe",
                "text_ts": int(time.time() * 1000),  # Timestamp in milliseconds
                "text": text,
            }

            try:
                chunks = _text_to_base64_chunks(
                    ten_env, json.dumps(base_msg_data), message_id
                )
                for chunk in chunks:
                    asyncio.run_coroutine_threadsafe(
                        self._queue_message(chunk), self.loop
                    )
            except Exception as e:
                ten_env.log_warn(f"on_data new_data error: {e}")
        elif data.get_name() == "content_data":
            try:
                text = data.get_property_string(TEXT_DATA_TEXT_FIELD)
            except Exception as e:
                ten_env.log_error(
                    f"on_data get_property_string {TEXT_DATA_TEXT_FIELD} error: {e}"
                )

            try:
                end_of_segment = data.get_property_bool(TEXT_DATA_END_OF_SEGMENT_FIELD)
            except Exception as e:
                ten_env.log_warn(
                    f"on_data get_property_bool {TEXT_DATA_END_OF_SEGMENT_FIELD} error: {e}"
                )

            ten_env.log_info(f"on_data {TEXT_DATA_TEXT_FIELD}: {text}")

            # Generate a unique message ID for this batch of parts.
            message_id = str(uuid.uuid4())[:8]

            # Wrap the raw content in the same JSON envelope, tagged as "raw".
            base_msg_data = {
                "is_final": end_of_segment,
                "stream_id": stream_id,
                "message_id": message_id,  # Identifies the split message
                "data_type": "raw",
                "text_ts": int(time.time() * 1000),  # Timestamp in milliseconds
                "text": text,
            }

            try:
                chunks = _text_to_base64_chunks(
                    ten_env, json.dumps(base_msg_data), message_id
                )
                for chunk in chunks:
                    asyncio.run_coroutine_threadsafe(
                        self._queue_message(chunk), self.loop
                    )
            except Exception as e:
                ten_env.log_warn(f"on_data new_data error: {e}")

    def on_audio_frame(self, ten_env: TenEnv, audio_frame: AudioFrame) -> None:
        # TODO: process pcm frame
        pass

    def on_video_frame(self, ten_env: TenEnv, video_frame: VideoFrame) -> None:
        # TODO: process image frame
        pass

    async def _queue_message(self, data: str):
        await self.queue.put(data)

    async def _process_queue(self, ten_env: TenEnv):
        while True:
            data = await self.queue.get()
            if data is None:
                break
            # Forward the chunk to the graph as a binary "data" payload.
            ten_data = Data.create("data")
            ten_data.set_property_buf("data", data.encode())
            ten_env.send_data(ten_data)
            self.queue.task_done()
            # Pace sends at roughly 25 chunks per second.
            await asyncio.sleep(0.04)
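
# Note: a downstream client is expected to reassemble the payload by
# splitting each chunk on "|", grouping by msg_id, ordering by part_index,
# concatenating the base64 parts once all total_parts have arrived, and
# base64-decoding the result back into the JSON envelope. (Sketch of the
# expected client behavior; the client itself is not part of this file.)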