File size: 10,862 Bytes
87337b1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
#
#
# Agora Real Time Engagement
# Created by Wei Hu in 2024-08.
# Copyright (c) 2024 Agora IO. All rights reserved.
#
#

from ten import (
    AudioFrame,
    VideoFrame,
    Extension,
    TenEnv,
    Cmd,
    StatusCode,
    CmdResult,
    Data,
)
import firebase_admin
from firebase_admin import credentials
from firebase_admin import firestore
import datetime
import asyncio
import queue
import threading
import json
from typing import List, Any

# Property names read from incoming Data messages (see on_data).
DATA_IN_TEXT_DATA_PROPERTY_IS_FINAL = "is_final"
DATA_IN_TEXT_DATA_PROPERTY_STREAM_ID = "stream_id"
DATA_IN_TEXT_DATA_PROPERTY_TEXT = "text"
DATA_IN_TEXT_DATA_PROPERTY_ROLE = "role"

# Names of the extension's configuration properties.
PROPERTY_CREDENTIALS = "credentials"
PROPERTY_CHANNEL_NAME = "channel_name"
PROPERTY_COLLECTION_NAME = "collection_name"
PROPERTY_TTL = "ttl"

# Command name handled by on_cmd and the result property that carries its
# JSON-encoded response.
RETRIEVE_CMD = "retrieve"
CMD_OUT_PROPERTY_RESPONSE = "response"
# Top-level fields of the per-channel Firestore document.
DOC_EXPIRE_PATH = "expireAt"
DOC_CONTENTS_PATH = "contents"
# Keys of each JSON-encoded entry appended to the document's contents array.
CONTENT_ROLE_PATH = "role"
CONTENT_TS_PATH = "ts"
CONTENT_STREAM_ID_PATH = "stream_id"
CONTENT_INPUT_PATH = "input"
# Default document time-to-live; used as timedelta(days=...) in on_start.
DEFAULT_TTL = 1  # days


def get_current_time():
    """Return the current local time as whole microseconds since the Unix epoch."""
    # timestamp() yields float seconds; scale to microseconds and truncate.
    return int(datetime.datetime.now().timestamp() * 1_000_000)


def order_by_ts(contents: List[str]) -> List[Any]:
    """Decode JSON-encoded chat entries and return them ordered by timestamp.

    Args:
        contents: JSON strings, each an object holding at least the role,
            input and timestamp keys.

    Returns:
        A list of dicts in ascending timestamp order containing only the
        role, input and stream id of each entry; the stream id falls back
        to 0 when the entry does not carry one.

    Raises:
        json.JSONDecodeError: If an entry is not valid JSON.
        KeyError: If an entry lacks the timestamp, role or input key.
    """
    entries = [json.loads(raw) for raw in contents]
    # list.sort is stable, so entries sharing a timestamp keep their order.
    entries.sort(key=lambda entry: entry[CONTENT_TS_PATH])
    return [
        {
            CONTENT_ROLE_PATH: entry[CONTENT_ROLE_PATH],
            CONTENT_INPUT_PATH: entry[CONTENT_INPUT_PATH],
            CONTENT_STREAM_ID_PATH: entry.get(CONTENT_STREAM_ID_PATH, 0),
        }
        for entry in entries
    ]


@firestore.transactional
def update_in_transaction(transaction, doc_ref, content):
    """Apply ``content`` as an update to ``doc_ref`` within ``transaction``.

    Args:
        transaction: Firestore transaction object the update is issued on.
        doc_ref: Reference of the document to update.
        content: Mapping of field paths to new values (or field transforms).
    """
    transaction.update(doc_ref, content)


@firestore.transactional
def read_in_transaction(transaction, doc_ref):
    """Read ``doc_ref`` within ``transaction`` and return its fields.

    Args:
        transaction: Firestore transaction object the read is issued on.
        doc_ref: Reference of the document to read.

    Returns:
        Whatever ``to_dict()`` yields for the fetched snapshot.
    """
    return doc_ref.get(transaction=transaction).to_dict()


class TSDBFirestoreExtension(Extension):
    """Store final chat text in a Firestore document and return it on request.

    ``on_data`` queues each final, non-empty text together with its role and
    stream id; a worker thread (``async_handle``) appends the entries to the
    document named after the configured channel. The ``retrieve`` command is
    served on a separate thread running an asyncio loop and returns the stored
    entries ordered by timestamp.
    """

    def __init__(self, name: str):
        """Initialise state; no Firestore access happens until ``on_start``."""
        super().__init__(name)
        # Data path: worker thread and the queue feeding it.
        self.stopped = False
        self.thread = None
        self.queue = queue.Queue()
        # Command path: asyncio loop thread and the event that ends it.
        self.stopEvent = asyncio.Event()
        self.cmd_thread = None
        self.loop = None
        # Configuration, filled in by on_start.
        self.credentials = None
        self.channel_name = ""
        self.collection_name = ""
        self.ttl = DEFAULT_TTL
        # Firestore handles, created by on_start.
        self.client = None
        self.document_ref = None

        self.current_stream_id = 0
        self.cache = ""

    async def __thread_routine(self, ten_env: TenEnv):
        """Publish the running loop, signal start completion, then idle until stopped."""
        ten_env.log_info("__thread_routine start")
        self.loop = asyncio.get_running_loop()
        # on_start_done is reported from here, once the command loop is ready.
        ten_env.on_start_done()
        await self.stopEvent.wait()

    async def stop_thread(self):
        """Release ``__thread_routine``; meant to be scheduled on ``self.loop``."""
        self.stopEvent.set()

    def on_init(self, ten_env: TenEnv) -> None:
        """Log and acknowledge initialisation."""
        ten_env.log_info("TSDBFirestoreExtension on_init")
        ten_env.on_init_done()

    def on_start(self, ten_env: TenEnv) -> None:
        """Read configuration, connect to Firestore and start both threads.

        If a required property (credentials, channel name, collection name)
        cannot be read, the error is logged and the method returns without
        starting anything and without calling ``on_start_done``.
        The ``ttl`` property is optional; ``DEFAULT_TTL`` is kept when it is
        missing or not a positive number.
        """
        ten_env.log_info("TSDBFirestoreExtension on_start")

        try:
            self.credentials = ten_env.get_property_to_json(PROPERTY_CREDENTIALS)
        except Exception as err:
            ten_env.log_error(
                f"GetProperty required {PROPERTY_CREDENTIALS} failed, err: {err}"
            )
            return

        try:
            self.channel_name = ten_env.get_property_string(PROPERTY_CHANNEL_NAME)
        except Exception as err:
            ten_env.log_error(
                f"GetProperty required {PROPERTY_CHANNEL_NAME} failed, err: {err}"
            )
            return

        try:
            self.collection_name = ten_env.get_property_string(PROPERTY_COLLECTION_NAME)
        except Exception as err:
            ten_env.log_error(
                f"GetProperty required {PROPERTY_COLLECTION_NAME} failed, err: {err}"
            )
            return

        # Optional ttl (days). Only a positive value overrides the default so
        # that a missing or zero property never makes documents expire at once.
        try:
            ttl = ten_env.get_property_int(PROPERTY_TTL)
            if ttl > 0:
                self.ttl = ttl
        except Exception as err:
            ten_env.log_info(
                f"GetProperty optional {PROPERTY_TTL} failed, err: {err}, "
                f"use default {DEFAULT_TTL} day(s)"
            )

        # start firestore db
        cred = credentials.Certificate(json.loads(self.credentials))
        firebase_admin.initialize_app(cred)
        self.client = firestore.client()

        self.document_ref = self.client.collection(self.collection_name).document(
            self.channel_name
        )
        # update ttl
        expiration_time = datetime.datetime.now() + datetime.timedelta(days=self.ttl)
        exists = self.document_ref.get().exists
        if exists:
            self.document_ref.update({DOC_EXPIRE_PATH: expiration_time})
            ten_env.log_info(
                f"reset document ttl, {self.ttl} day(s), for the channel {self.channel_name}"
            )
        else:
            # not exists yet, set to create one
            self.document_ref.set({DOC_EXPIRE_PATH: expiration_time})
            ten_env.log_info(
                f"create new document and set ttl, {self.ttl} day(s), for the channel {self.channel_name}"
            )

        # start the loop to handle data in
        self.thread = threading.Thread(target=self.async_handle, args=[ten_env])
        self.thread.start()

        # start the loop to handle cmd in
        self.cmd_thread = threading.Thread(
            target=asyncio.run, args=(self.__thread_routine(ten_env),)
        )
        self.cmd_thread.start()

    def async_handle(self, ten_env: TenEnv) -> None:
        """Worker loop: append queued entries to the Firestore document.

        Blocks on the queue; a ``None`` item ends the loop. A failure while
        storing one entry is logged and the loop continues with the next.
        """
        while not self.stopped:
            try:
                value = self.queue.get()
                if value is None:
                    ten_env.log_info("exit handle loop")
                    break
                ts, input_path, role, stream_id = value
                # Entries are stored as JSON strings inside an array field.
                content_str = json.dumps(
                    {
                        CONTENT_ROLE_PATH: role,
                        CONTENT_INPUT_PATH: input_path,
                        CONTENT_TS_PATH: ts,
                        CONTENT_STREAM_ID_PATH: stream_id,
                    }
                )
                update_in_transaction(
                    self.client.transaction(),
                    self.document_ref,
                    {DOC_CONTENTS_PATH: firestore.ArrayUnion([content_str])},
                )
                ten_env.log_info(
                    f"append {content_str} to firestore document {self.channel_name}"
                )
            except Exception as err:
                # Keep the worker alive, but do not hide why the write failed.
                ten_env.log_error(f"Failed to store chat contents, err: {err}")

    def on_stop(self, ten_env: TenEnv) -> None:
        """Stop both threads and acknowledge the stop.

        Entries still waiting in the queue are discarded, not written.
        """
        ten_env.log_info("TSDBFirestoreExtension on_stop")

        # clear the queue and stop the thread to process data in
        self.stopped = True
        while not self.queue.empty():
            self.queue.get()
        # The sentinel wakes the worker if it is blocked in queue.get().
        self.queue.put(None)
        if self.thread is not None:
            self.thread.join()
            self.thread = None

        # stop the thread to process cmd in
        if self.cmd_thread is not None and self.cmd_thread.is_alive():
            asyncio.run_coroutine_threadsafe(self.stop_thread(), self.loop)
            self.cmd_thread.join()
            self.cmd_thread = None

        ten_env.on_stop_done()

    def on_deinit(self, ten_env: TenEnv) -> None:
        """Log and acknowledge de-initialisation."""
        ten_env.log_info("TSDBFirestoreExtension on_deinit")
        ten_env.on_deinit_done()

    def on_cmd(self, ten_env: TenEnv, cmd: Cmd) -> None:
        """Dispatch an incoming command.

        ``retrieve`` is scheduled on the command loop; any other command, or a
        failure while dispatching, is answered with an ERROR result.
        """
        try:
            cmd_name = cmd.get_name()
            ten_env.log_info(f"on_cmd name {cmd_name}")
            if cmd_name == RETRIEVE_CMD:
                asyncio.run_coroutine_threadsafe(self.retrieve(ten_env, cmd), self.loop)
            else:
                ten_env.log_info(f"unknown cmd name {cmd_name}")
                cmd_result = CmdResult.create(StatusCode.ERROR)
                ten_env.return_result(cmd_result, cmd)
        except Exception as err:
            ten_env.log_error(f"on_cmd failed, err: {err}")
            ten_env.return_result(CmdResult.create(StatusCode.ERROR), cmd)

    async def retrieve(self, ten_env: TenEnv, cmd: Cmd):
        """Answer ``cmd`` with the stored entries as a JSON string.

        Returns an OK result whose response property holds the entries sorted
        by timestamp, or an ERROR result when the document has no contents or
        reading/decoding fails.
        """
        try:
            doc_dict = read_in_transaction(self.client.transaction(), self.document_ref)
            # doc_dict may be None (no such document); treat it as "no contents".
            if doc_dict is not None and DOC_CONTENTS_PATH in doc_dict:
                contents = doc_dict[DOC_CONTENTS_PATH]
                ten_env.log_info(f"after retrieve {contents}")
                ret = CmdResult.create(StatusCode.OK)
                ret.set_property_string(
                    CMD_OUT_PROPERTY_RESPONSE, json.dumps(order_by_ts(contents))
                )
                ten_env.return_result(ret, cmd)
            else:
                ten_env.log_info(f"no contents for the channel {self.channel_name} yet")
                ten_env.return_result(CmdResult.create(StatusCode.ERROR), cmd)
        except Exception as err:
            ten_env.log_error(
                f"Failed to read the document for the channel {self.channel_name}, err: {err}"
            )
            ten_env.return_result(CmdResult.create(StatusCode.ERROR), cmd)

    def on_data(self, ten_env: TenEnv, data: Data) -> None:
        """Queue a final, non-empty text message for storage.

        Non-final input, empty text and empty or unreadable role are ignored.
        A missing ``is_final`` property is logged and the message is still
        processed; a missing ``stream_id`` defaults to 0.
        """
        ten_env.log_info("TSDBFirestoreExtension on_data")

        # assume 'data' is an object from which we can get properties
        is_final = False
        try:
            is_final = data.get_property_bool(DATA_IN_TEXT_DATA_PROPERTY_IS_FINAL)
            if not is_final:
                ten_env.log_info("ignore non-final input")
                return
        except Exception as err:
            ten_env.log_info(
                f"OnData GetProperty {DATA_IN_TEXT_DATA_PROPERTY_IS_FINAL} failed, err: {err}"
            )

        # get stream id; it is a number, so it must be read as an int rather
        # than a bool to keep its value.
        stream_id = 0
        try:
            stream_id = data.get_property_int(DATA_IN_TEXT_DATA_PROPERTY_STREAM_ID)
        except Exception as err:
            ten_env.log_info(
                f"OnData GetProperty {DATA_IN_TEXT_DATA_PROPERTY_STREAM_ID} failed, err: {err}"
            )

        # get input text
        try:
            input_text = data.get_property_string(DATA_IN_TEXT_DATA_PROPERTY_TEXT)
            if not input_text:
                ten_env.log_info("ignore empty text")
                return
            ten_env.log_info(f"OnData input text: [{input_text}]")
        except Exception as err:
            ten_env.log_info(
                f"OnData GetProperty {DATA_IN_TEXT_DATA_PROPERTY_TEXT} failed, err: {err}"
            )
            return
        # get role
        try:
            role = data.get_property_string(DATA_IN_TEXT_DATA_PROPERTY_ROLE)
            if not role:
                ten_env.log_warn("ignore empty role")
                return
        except Exception as err:
            ten_env.log_info(
                f"OnData GetProperty {DATA_IN_TEXT_DATA_PROPERTY_ROLE} failed, err: {err}"
            )
            return

        # Timestamp at enqueue time; it defines the order used by retrieve.
        ts = get_current_time()
        self.queue.put((ts, input_text, role, stream_id))

    def on_audio_frame(self, ten_env: TenEnv, audio_frame: AudioFrame) -> None:
        """Audio frames are ignored by this extension."""
        pass

    def on_video_frame(self, ten_env: TenEnv, video_frame: VideoFrame) -> None:
        """Video frames are ignored by this extension."""
        pass