#
#
# Agora Real Time Engagement
# Created by Wei Hu in 2024-08.
# Copyright (c) 2024 Agora IO. All rights reserved.
#
#
import asyncio
import base64
import json
import threading
import time
import uuid

from ten import (
    AudioFrame,
    VideoFrame,
    Extension,
    TenEnv,
    Cmd,
    StatusCode,
    CmdResult,
    Data,
)

MAX_SIZE = 800  # byte budget reserved for the text payload
OVERHEAD_ESTIMATE = 200  # estimated metadata overhead in the JSON envelope

CMD_NAME_FLUSH = "flush"

TEXT_DATA_TEXT_FIELD = "text"
TEXT_DATA_FINAL_FIELD = "is_final"
TEXT_DATA_STREAM_ID_FIELD = "stream_id"
TEXT_DATA_END_OF_SEGMENT_FIELD = "end_of_segment"

MAX_CHUNK_SIZE_BYTES = 1024
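
# Each outgoing chunk is a single UTF-8 string of the form
#   "<msg_id>|<part_index>|<total_parts>|<base64 payload>"
# (values like "a1b2c3d4|1|3|eyJpc19maW5hbCI6..." are illustrative only),
# kept within MAX_CHUNK_SIZE_BYTES.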


def _text_to_base64_chunks(_: TenEnv, text: str, msg_id: str) -> list:
    # Ensure msg_id does not exceed 36 characters (the length of a UUID string)
    if len(msg_id) > 36:
        raise ValueError("msg_id cannot exceed 36 characters.")

    # Convert text to bytearray
    byte_array = bytearray(text, "utf-8")

    # Encode the bytearray into base64
    base64_encoded = base64.b64encode(byte_array).decode("utf-8")

    # Initialize list to hold the final chunks
    chunks = []

    # Split the base64 string dynamically based on the final byte size
    part_index = 0
    total_parts = None  # Known only once all chunks have been created

    # Process the base64-encoded content in chunks
    current_position = 0
    total_length = len(base64_encoded)

    while current_position < total_length:
        part_index += 1

        # Start from the full byte budget and shrink the content part until
        # the formatted chunk fits
        estimated_chunk_size = MAX_CHUNK_SIZE_BYTES
        content_chunk = ""
        count = 0
        while True:
            # Create the content part of the chunk
            content_chunk = base64_encoded[
                current_position : current_position + estimated_chunk_size
            ]

            # Format the chunk
            formatted_chunk = f"{msg_id}|{part_index}|{total_parts if total_parts else '???'}|{content_chunk}"

            # Check if the byte length of the formatted chunk exceeds the max allowed size
            if len(bytearray(formatted_chunk, "utf-8")) <= MAX_CHUNK_SIZE_BYTES:
                break
            else:
                # Reduce the estimated chunk size if the formatted chunk is too large
                estimated_chunk_size -= 100  # Reduce content size gradually
                count += 1

        # ten_env.log_debug(f"chunk estimate guess: {count}")

        # Add the current chunk to the list
        chunks.append(formatted_chunk)
        # Move to the next part of the content
        current_position += estimated_chunk_size

    # Now that the total number of parts is known, patch the placeholder.
    # '?' never appears in base64 output (and the msg_id used here is a UUID
    # fragment), so only the placeholder field is replaced.
    total_parts = len(chunks)
    updated_chunks = [chunk.replace("???", str(total_parts)) for chunk in chunks]

    return updated_chunks
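

# A minimal client-side reassembly sketch for the chunk format above. This
# helper is hypothetical (the extension does not use it) and assumes every
# part of a message arrives exactly once. base64 output never contains '|',
# so splitting on the first three separators is safe.
def _reassemble_chunks(chunks: list) -> str:
    parts = {}
    total_parts = 0
    for chunk in chunks:
        _msg_id, part_index, total, payload = chunk.split("|", 3)
        parts[int(part_index)] = payload
        total_parts = int(total)
    base64_encoded = "".join(parts[i] for i in range(1, total_parts + 1))
    return base64.b64decode(base64_encoded).decode("utf-8")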


class MessageCollectorExtension(Extension):
    def __init__(self, name: str):
        super().__init__(name)
        self.queue = asyncio.Queue()  # outgoing chunks awaiting transmission
        self.loop = None  # event loop owned by the background sender thread
        self.cached_text_map = {}  # stream_id -> text cached until end_of_segment

    def on_init(self, ten_env: TenEnv) -> None:
        ten_env.log_info("on_init")
        ten_env.on_init_done()

    def on_start(self, ten_env: TenEnv) -> None:
        ten_env.log_info("on_start")

        # TODO: read properties, initialize resources
        self.loop = asyncio.new_event_loop()

        def start_loop():
            asyncio.set_event_loop(self.loop)
            self.loop.run_forever()

        # Daemon thread so a lingering loop cannot block process shutdown
        threading.Thread(target=start_loop, daemon=True).start()

        self.loop.create_task(self._process_queue(ten_env))

        ten_env.on_start_done()

    def on_stop(self, ten_env: TenEnv) -> None:
        ten_env.log_info("on_stop")

        # Unblock the queue consumer with a sentinel, then stop the loop
        if self.loop is not None:
            asyncio.run_coroutine_threadsafe(self.queue.put(None), self.loop)
            self.loop.call_soon_threadsafe(self.loop.stop)

        ten_env.on_stop_done()

    def on_deinit(self, ten_env: TenEnv) -> None:
        ten_env.log_info("on_deinit")
        ten_env.on_deinit_done()

    def on_cmd(self, ten_env: TenEnv, cmd: Cmd) -> None:
        cmd_name = cmd.get_name()
        ten_env.log_info("on_cmd name {}".format(cmd_name))

        # TODO: process cmd

        cmd_result = CmdResult.create(StatusCode.OK)
        ten_env.return_result(cmd_result, cmd)

    def on_data(self, ten_env: TenEnv, data: Data) -> None:
        """
        on_data receives data from ten graph.
        current suppotend data:
          - name: text_data
            example:
            {"name": "text_data", "properties": {"text": "hello", "is_final": true, "stream_id": 123, "end_of_segment": true}}
        """
        # ten_env.log_debug(f"on_data")
        text = ""
        final = True
        stream_id = 0
        end_of_segment = False

        # Transcription text arrives as "text_data"
        if data.get_name() == "text_data":
            try:
                text = data.get_property_string(TEXT_DATA_TEXT_FIELD)
            except Exception as e:
                ten_env.log_error(
                    f"on_data get_property_string {TEXT_DATA_TEXT_FIELD} error: {e}"
                )

            try:
                final = data.get_property_bool(TEXT_DATA_FINAL_FIELD)
            except Exception:
                pass

            try:
                stream_id = data.get_property_int(TEXT_DATA_STREAM_ID_FIELD)
            except Exception:
                pass

            try:
                end_of_segment = data.get_property_bool(TEXT_DATA_END_OF_SEGMENT_FIELD)
            except Exception as e:
                ten_env.log_warn(
                    f"on_data get_property_bool {TEXT_DATA_END_OF_SEGMENT_FIELD} error: {e}"
                )

            ten_env.log_info(
                f"on_data {TEXT_DATA_TEXT_FIELD}: {text} {TEXT_DATA_FINAL_FIELD}: {final} {TEXT_DATA_STREAM_ID_FIELD}: {stream_id} {TEXT_DATA_END_OF_SEGMENT_FIELD}: {end_of_segment}"
            )

            # Cache finalized text per stream and prepend it to incoming text;
            # flush the cache when the segment ends.
            if end_of_segment:
                if stream_id in self.cached_text_map:
                    text = self.cached_text_map[stream_id] + text
                    del self.cached_text_map[stream_id]
            else:
                if final:
                    if stream_id in self.cached_text_map:
                        text = self.cached_text_map[stream_id] + text

                    self.cached_text_map[stream_id] = text

            # Generate a unique message ID for this batch of parts
            message_id = str(uuid.uuid4())[:8]

            # Prepare the JSON payload for this message
            base_msg_data = {
                "is_final": end_of_segment,
                "stream_id": stream_id,
                "message_id": message_id,  # Add message_id to identify the split message
                "data_type": "transcribe",
                "text_ts": int(time.time() * 1000),  # Convert to milliseconds
                "text": text,
            }

            try:
                chunks = _text_to_base64_chunks(
                    ten_env, json.dumps(base_msg_data), message_id
                )
                for chunk in chunks:
                    asyncio.run_coroutine_threadsafe(
                        self._queue_message(chunk), self.loop
                    )

            except Exception as e:
                ten_env.log_warn(f"on_data new_data error: {e}")
        elif data.get_name() == "content_data":
            try:
                text = data.get_property_string(TEXT_DATA_TEXT_FIELD)
            except Exception as e:
                ten_env.log_error(
                    f"on_data get_property_string {TEXT_DATA_TEXT_FIELD} error: {e}"
                )

            try:
                end_of_segment = data.get_property_bool(TEXT_DATA_END_OF_SEGMENT_FIELD)
            except Exception as e:
                ten_env.log_warn(
                    f"on_data get_property_bool {TEXT_DATA_END_OF_SEGMENT_FIELD} error: {e}"
                )

            ten_env.log_info(f"on_data {TEXT_DATA_TEXT_FIELD}: {text}")

            # Generate a unique message ID for this batch of parts
            message_id = str(uuid.uuid4())[:8]

            # Prepare the JSON payload for this message
            base_msg_data = {
                "is_final": end_of_segment,
                "stream_id": stream_id,
                "message_id": message_id,  # Add message_id to identify the split message
                "data_type": "raw",
                "text_ts": int(time.time() * 1000),  # Convert to milliseconds
                "text": text,
            }

            try:
                chunks = _text_to_base64_chunks(
                    ten_env, json.dumps(base_msg_data), message_id
                )
                for chunk in chunks:
                    asyncio.run_coroutine_threadsafe(
                        self._queue_message(chunk), self.loop
                    )

            except Exception as e:
                ten_env.log_warn(f"on_data new_data error: {e}")

    def on_audio_frame(self, ten_env: TenEnv, audio_frame: AudioFrame) -> None:
        # TODO: process pcm frame
        pass

    def on_video_frame(self, ten_env: TenEnv, video_frame: VideoFrame) -> None:
        # TODO: process image frame
        pass

    async def _queue_message(self, data: str):
        await self.queue.put(data)

    async def _process_queue(self, ten_env: TenEnv):
        while True:
            data = await self.queue.get()
            if data is None:
                break
            # Forward the chunk to the graph as a binary buffer
            ten_data = Data.create("data")
            ten_data.set_property_buf("data", data.encode())
            ten_env.send_data(ten_data)
            self.queue.task_done()
            # Pace outgoing chunks to avoid flooding the data channel
            await asyncio.sleep(0.04)
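

# A quick round-trip sanity check; runnable only where the `ten` package is
# importable. _text_to_base64_chunks never touches its TenEnv argument, so
# None is passed in its place. Test values below are hypothetical.
if __name__ == "__main__":
    sample = json.dumps({"data_type": "transcribe", "text": "hello " * 400})
    parts = _text_to_base64_chunks(None, sample, str(uuid.uuid4())[:8])
    assert all(len(p.encode("utf-8")) <= MAX_CHUNK_SIZE_BYTES for p in parts)
    assert _reassemble_chunks(parts) == sample
    print(f"round-trip OK across {len(parts)} chunk(s)")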