File size: 10,862 Bytes
87337b1 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 |
#
#
# Agora Real Time Engagement
# Created by Wei Hu in 2024-08.
# Copyright (c) 2024 Agora IO. All rights reserved.
#
#
from ten import (
AudioFrame,
VideoFrame,
Extension,
TenEnv,
Cmd,
StatusCode,
CmdResult,
Data,
)
import firebase_admin
from firebase_admin import credentials
from firebase_admin import firestore
import datetime
import asyncio
import queue
import threading
import json
from typing import List, Any
# Property names read from incoming Data messages in on_data.
DATA_IN_TEXT_DATA_PROPERTY_IS_FINAL = "is_final"
DATA_IN_TEXT_DATA_PROPERTY_STREAM_ID = "stream_id"
DATA_IN_TEXT_DATA_PROPERTY_TEXT = "text"
DATA_IN_TEXT_DATA_PROPERTY_ROLE = "role"

# Extension property names (credentials / channel / collection are required
# at start-up).
PROPERTY_CREDENTIALS = "credentials"
PROPERTY_CHANNEL_NAME = "channel_name"
PROPERTY_COLLECTION_NAME = "collection_name"
PROPERTY_TTL = "ttl"

# Command handled by on_cmd, and the result property carrying its JSON payload.
RETRIEVE_CMD = "retrieve"
CMD_OUT_PROPERTY_RESPONSE = "response"

# Field names on the per-channel Firestore document.
# DOC_EXPIRE_PATH holds the expiration timestamp (presumably the field a
# Firestore TTL policy is configured on -- confirm in the project setup).
DOC_EXPIRE_PATH = "expireAt"
# DOC_CONTENTS_PATH holds an array of JSON-encoded chat turns.
DOC_CONTENTS_PATH = "contents"

# Keys inside each JSON-encoded chat turn.
CONTENT_ROLE_PATH = "role"
CONTENT_TS_PATH = "ts"
CONTENT_STREAM_ID_PATH = "stream_id"
CONTENT_INPUT_PATH = "input"

# Default document lifetime used when computing the expiration timestamp.
DEFAULT_TTL = 1 # days
def get_current_time():
    """Return the current wall-clock time as integer microseconds since the Unix epoch."""
    now = datetime.datetime.now()
    # timestamp() interprets the naive local time correctly; scale seconds
    # to microseconds and truncate to an int.
    return int(now.timestamp() * 1_000_000)
def order_by_ts(contents: List[str]) -> List[Any]:
    """Decode JSON-encoded chat turns and return them ordered by timestamp.

    Args:
        contents: JSON strings, each an object with role, input and ts keys;
            stream_id is optional.

    Returns:
        One dict per turn, ascending by ``ts``, holding only role, input and
        stream_id (0 when absent). The ``ts`` key itself is not included.
    """
    decoded = [json.loads(raw) for raw in contents]
    # list.sort is stable, matching sorted(): equal timestamps keep input order.
    decoded.sort(key=lambda entry: entry[CONTENT_TS_PATH])
    return [
        {
            CONTENT_ROLE_PATH: entry[CONTENT_ROLE_PATH],
            CONTENT_INPUT_PATH: entry[CONTENT_INPUT_PATH],
            CONTENT_STREAM_ID_PATH: entry.get(CONTENT_STREAM_ID_PATH, 0),
        }
        for entry in decoded
    ]
@firestore.transactional
def update_in_transaction(transaction, doc_ref, content):
    """Apply the field update ``content`` to ``doc_ref`` within ``transaction``.

    The ``firestore.transactional`` decorator drives the transaction
    lifecycle around this body (see the google-cloud-firestore docs for its
    begin/commit/retry behavior).
    """
    transaction.update(doc_ref, content)
@firestore.transactional
def read_in_transaction(transaction, doc_ref):
    """Read ``doc_ref`` within ``transaction`` and return its fields as a dict.

    The value is the snapshot's ``to_dict()``, which is ``None`` when the
    document does not exist.
    """
    return doc_ref.get(transaction=transaction).to_dict()
class TSDBFirestoreExtension(Extension):
    """TEN extension that stores a channel's chat turns in Firestore.

    On start it opens (or creates) the document
    ``<collection_name>/<channel_name>`` and refreshes its expiration field.
    Final text ``Data`` messages are queued by ``on_data`` and appended to the
    document's contents array by a worker thread (``async_handle``). The
    ``retrieve`` command runs on a dedicated asyncio loop thread and returns
    the stored turns, ordered by timestamp, as a JSON string.
    """

    def __init__(self, name: str):
        super().__init__(name)
        # Set by on_stop to end the storage worker loop.
        self.stopped = False
        # Storage worker thread running async_handle.
        self.thread = None
        # Pending (ts, text, role, stream_id) tuples; None is the exit sentinel.
        self.queue = queue.Queue()
        # Awaited by __thread_routine; set to shut the command loop down.
        self.stopEvent = asyncio.Event()
        # Thread hosting the asyncio loop that serves commands.
        self.cmd_thread = None
        # That loop, published once it is running.
        self.loop = None
        self.credentials = None
        self.channel_name = ""
        self.collection_name = ""
        # Document lifetime in days.
        self.ttl = DEFAULT_TTL
        self.client = None
        self.document_ref = None
        self.current_stream_id = 0
        self.cache = ""

    async def __thread_routine(self, ten_env: TenEnv):
        """Run the command event loop until ``stop_thread`` is awaited on it."""
        ten_env.log_info("__thread_routine start")
        self.loop = asyncio.get_running_loop()
        # Start is reported only after the loop exists, so on_cmd can always
        # schedule coroutines onto it once the extension is started.
        ten_env.on_start_done()
        await self.stopEvent.wait()

    async def stop_thread(self):
        """Wake ``__thread_routine`` so the command loop can exit."""
        self.stopEvent.set()

    def on_init(self, ten_env: TenEnv) -> None:
        ten_env.log_info("TSDBFirestoreExtension on_init")
        ten_env.on_init_done()

    def on_start(self, ten_env: TenEnv) -> None:
        """Load properties, open the channel document and start both threads.

        On success ``on_start_done`` is reported from the command-loop thread.
        On failure it is reported here so the runtime is not left waiting on
        this extension; commands then answer with an error.
        """
        ten_env.log_info("TSDBFirestoreExtension on_start")
        if not self._load_properties(ten_env):
            ten_env.on_start_done()
            return
        try:
            self._open_channel_document(ten_env)
        except Exception as err:
            ten_env.log_error(
                f"Failed to open firestore document for the channel {self.channel_name}, err: {err}"
            )
            ten_env.on_start_done()
            return
        # start the loop to handle data in
        self.thread = threading.Thread(target=self.async_handle, args=[ten_env])
        self.thread.start()
        # start the loop to handle cmd in; it reports on_start_done once running
        self.cmd_thread = threading.Thread(
            target=asyncio.run, args=(self.__thread_routine(ten_env),)
        )
        self.cmd_thread.start()

    def _load_properties(self, ten_env: TenEnv) -> bool:
        """Read extension properties; return False if a required one is missing."""
        try:
            self.credentials = ten_env.get_property_to_json(PROPERTY_CREDENTIALS)
        except Exception as err:
            ten_env.log_error(
                f"GetProperty required {PROPERTY_CREDENTIALS} failed, err: {err}"
            )
            return False
        try:
            self.channel_name = ten_env.get_property_string(PROPERTY_CHANNEL_NAME)
        except Exception as err:
            ten_env.log_error(
                f"GetProperty required {PROPERTY_CHANNEL_NAME} failed, err: {err}"
            )
            return False
        try:
            self.collection_name = ten_env.get_property_string(PROPERTY_COLLECTION_NAME)
        except Exception as err:
            ten_env.log_error(
                f"GetProperty required {PROPERTY_COLLECTION_NAME} failed, err: {err}"
            )
            return False
        # ttl (days) is optional; keep DEFAULT_TTL when absent or non-positive.
        try:
            ttl = ten_env.get_property_int(PROPERTY_TTL)
        except Exception as err:
            ten_env.log_info(
                f"GetProperty optional {PROPERTY_TTL} failed, use default {DEFAULT_TTL} day(s), err: {err}"
            )
        else:
            if ttl > 0:
                self.ttl = ttl
            else:
                ten_env.log_warn(
                    f"invalid {PROPERTY_TTL} {ttl}, use default {DEFAULT_TTL} day(s)"
                )
        return True

    def _open_channel_document(self, ten_env: TenEnv) -> None:
        """Connect to Firestore and create the channel document or refresh its TTL."""
        cred = credentials.Certificate(json.loads(self.credentials))
        # NOTE(review): initialize_app registers the default app and raises if
        # one already exists in this process (e.g. a second instance).
        firebase_admin.initialize_app(cred)
        self.client = firestore.client()
        self.document_ref = self.client.collection(self.collection_name).document(
            self.channel_name
        )
        # Firestore interprets naive datetimes as UTC, so use an aware UTC
        # timestamp instead of naive local time to get the intended expiry.
        expiration_time = datetime.datetime.now(
            datetime.timezone.utc
        ) + datetime.timedelta(days=self.ttl)
        if self.document_ref.get().exists:
            self.document_ref.update({DOC_EXPIRE_PATH: expiration_time})
            ten_env.log_info(
                f"reset document ttl, {self.ttl} day(s), for the channel {self.channel_name}"
            )
        else:
            # not exists yet, set to create one
            self.document_ref.set({DOC_EXPIRE_PATH: expiration_time})
            ten_env.log_info(
                f"create new document and set ttl, {self.ttl} day(s), for the channel {self.channel_name}"
            )

    def async_handle(self, ten_env: TenEnv) -> None:
        """Worker loop: append queued chat turns to the channel document.

        Runs on ``self.thread`` until ``None`` is dequeued or ``stopped`` is
        set. A failed write is logged and the loop continues.
        """
        while not self.stopped:
            try:
                value = self.queue.get()
                if value is None:
                    ten_env.log_info("exit handle loop")
                    break
                ts, input_path, role, stream_id = value
                content_str = json.dumps(
                    {
                        CONTENT_ROLE_PATH: role,
                        CONTENT_INPUT_PATH: input_path,
                        CONTENT_TS_PATH: ts,
                        CONTENT_STREAM_ID_PATH: stream_id,
                    }
                )
                update_in_transaction(
                    self.client.transaction(),
                    self.document_ref,
                    {DOC_CONTENTS_PATH: firestore.ArrayUnion([content_str])},
                )
                ten_env.log_info(
                    f"append {content_str} to firestore document {self.channel_name}"
                )
            except Exception as err:
                ten_env.log_error(f"Failed to store chat contents, err: {err}")

    def on_stop(self, ten_env: TenEnv) -> None:
        """Stop both threads; turns still queued and not yet written are discarded."""
        ten_env.log_info("TSDBFirestoreExtension on_stop")
        # clear the queue and stop the thread to process data in
        self.stopped = True
        # Drain without blocking: the worker may take the last item between an
        # empty() check and a blocking get(), which would hang this call.
        while True:
            try:
                self.queue.get_nowait()
            except queue.Empty:
                break
        self.queue.put(None)
        if self.thread is not None:
            self.thread.join()
            self.thread = None
        # stop the thread to process cmd in
        if self.cmd_thread is not None and self.cmd_thread.is_alive():
            asyncio.run_coroutine_threadsafe(self.stop_thread(), self.loop)
            self.cmd_thread.join()
            self.cmd_thread = None
        ten_env.on_stop_done()

    def on_deinit(self, ten_env: TenEnv) -> None:
        ten_env.log_info("TSDBFirestoreExtension on_deinit")
        ten_env.on_deinit_done()

    def on_cmd(self, ten_env: TenEnv, cmd: Cmd) -> None:
        """Dispatch ``retrieve`` to the command loop; answer ERROR otherwise."""
        try:
            cmd_name = cmd.get_name()
            ten_env.log_info(f"on_cmd name {cmd_name}")
            if cmd_name == RETRIEVE_CMD:
                if self.loop is None:
                    # start failed, so no command loop exists to run on
                    ten_env.log_error(f"cannot handle {cmd_name}, command loop is not running")
                    ten_env.return_result(CmdResult.create(StatusCode.ERROR), cmd)
                    return
                asyncio.run_coroutine_threadsafe(self.retrieve(ten_env, cmd), self.loop)
            else:
                ten_env.log_info(f"unknown cmd name {cmd_name}")
                cmd_result = CmdResult.create(StatusCode.ERROR)
                ten_env.return_result(cmd_result, cmd)
        except Exception as err:
            ten_env.log_error(f"on_cmd failed, err: {err}")
            ten_env.return_result(CmdResult.create(StatusCode.ERROR), cmd)

    async def retrieve(self, ten_env: TenEnv, cmd: Cmd):
        """Return the channel's stored turns, ordered by timestamp, as JSON.

        Answers OK with the JSON list in ``CMD_OUT_PROPERTY_RESPONSE``, or
        ERROR when the document (or its contents field) is missing or the
        read fails.
        """
        try:
            doc_dict = read_in_transaction(self.client.transaction(), self.document_ref)
            # to_dict() is None when the document no longer exists (e.g. expired).
            if doc_dict and DOC_CONTENTS_PATH in doc_dict:
                contents = doc_dict[DOC_CONTENTS_PATH]
                ten_env.log_info(f"after retrieve {contents}")
                ret = CmdResult.create(StatusCode.OK)
                ret.set_property_string(
                    CMD_OUT_PROPERTY_RESPONSE, json.dumps(order_by_ts(contents))
                )
                ten_env.return_result(ret, cmd)
            else:
                ten_env.log_info(f"no contents for the channel {self.channel_name} yet")
                ten_env.return_result(CmdResult.create(StatusCode.ERROR), cmd)
        except Exception as err:
            ten_env.log_error(
                f"Failed to read the document for the channel {self.channel_name}, err: {err}"
            )
            ten_env.return_result(CmdResult.create(StatusCode.ERROR), cmd)

    def on_data(self, ten_env: TenEnv, data: Data) -> None:
        """Queue a final, non-empty chat turn for storage.

        Non-final data is ignored; a missing ``is_final`` flag is treated as
        final. ``stream_id`` defaults to 0. ``text`` and ``role`` are required
        and must be non-empty.
        """
        ten_env.log_info("TSDBFirestoreExtension on_data")
        # assume 'data' is an object from which we can get properties
        try:
            is_final = data.get_property_bool(DATA_IN_TEXT_DATA_PROPERTY_IS_FINAL)
            if not is_final:
                ten_env.log_info("ignore non-final input")
                return
        except Exception as err:
            ten_env.log_info(
                f"OnData GetProperty {DATA_IN_TEXT_DATA_PROPERTY_IS_FINAL} failed, err: {err}"
            )
        # get stream id (an integer property, not a bool)
        stream_id = 0
        try:
            stream_id = data.get_property_int(DATA_IN_TEXT_DATA_PROPERTY_STREAM_ID)
        except Exception as err:
            ten_env.log_info(
                f"OnData GetProperty {DATA_IN_TEXT_DATA_PROPERTY_STREAM_ID} failed, err: {err}"
            )
        # get input text
        try:
            input_text = data.get_property_string(DATA_IN_TEXT_DATA_PROPERTY_TEXT)
            if not input_text:
                ten_env.log_info("ignore empty text")
                return
            ten_env.log_info(f"OnData input text: [{input_text}]")
        except Exception as err:
            ten_env.log_info(
                f"OnData GetProperty {DATA_IN_TEXT_DATA_PROPERTY_TEXT} failed, err: {err}"
            )
            return
        # get role
        try:
            role = data.get_property_string(DATA_IN_TEXT_DATA_PROPERTY_ROLE)
            if not role:
                ten_env.log_warn("ignore empty role")
                return
        except Exception as err:
            ten_env.log_info(
                f"OnData GetProperty {DATA_IN_TEXT_DATA_PROPERTY_ROLE} failed, err: {err}"
            )
            return
        ts = get_current_time()
        self.queue.put((ts, input_text, role, stream_id))

    def on_audio_frame(self, ten_env: TenEnv, audio_frame: AudioFrame) -> None:
        pass

    def on_video_frame(self, ten_env: TenEnv, video_frame: VideoFrame) -> None:
        pass
|