|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from ten import ( |
|
AudioFrame, |
|
VideoFrame, |
|
Extension, |
|
TenEnv, |
|
Cmd, |
|
StatusCode, |
|
CmdResult, |
|
Data, |
|
) |
|
import firebase_admin |
|
from firebase_admin import credentials |
|
from firebase_admin import firestore |
|
import datetime |
|
import asyncio |
|
import queue |
|
import threading |
|
import json |
|
from typing import List, Any |
|
|
|
# Property names read from incoming text `Data` messages in `on_data`.
DATA_IN_TEXT_DATA_PROPERTY_IS_FINAL = "is_final"
DATA_IN_TEXT_DATA_PROPERTY_STREAM_ID = "stream_id"
DATA_IN_TEXT_DATA_PROPERTY_TEXT = "text"
DATA_IN_TEXT_DATA_PROPERTY_ROLE = "role"

# Extension configuration property names.
PROPERTY_CREDENTIALS = "credentials"
PROPERTY_CHANNEL_NAME = "channel_name"
PROPERTY_COLLECTION_NAME = "collection_name"
PROPERTY_TTL = "ttl"

# Name of the supported command and of the property carrying its response.
RETRIEVE_CMD = "retrieve"
CMD_OUT_PROPERTY_RESPONSE = "response"

# Field names of the Firestore document.
DOC_EXPIRE_PATH = "expireAt"
DOC_CONTENTS_PATH = "contents"

# Keys of each JSON-encoded entry stored in the document's contents array.
CONTENT_ROLE_PATH = "role"
CONTENT_TS_PATH = "ts"
CONTENT_STREAM_ID_PATH = "stream_id"
CONTENT_INPUT_PATH = "input"

# Default document time-to-live, in days (used with `timedelta(days=...)`).
DEFAULT_TTL = 1
|
|
|
|
|
def get_current_time():
    """Return the current wall-clock time as an integer count of Unix microseconds."""
    return int(datetime.datetime.now().timestamp() * 1_000_000)
|
|
|
|
|
def order_by_ts(contents: List[str]) -> List[Any]:
    """Decode JSON content entries and return them ordered by timestamp.

    Args:
        contents: JSON strings, each an object holding at least the
            ``ts``, ``role`` and ``input`` keys.

    Returns:
        A list of dicts with the ``role``, ``input`` and ``stream_id`` keys
        (``stream_id`` defaults to 0 when absent), sorted by ascending ``ts``.

    Raises:
        json.JSONDecodeError: If an entry is not valid JSON.
        KeyError: If an entry lacks the ``ts``, ``role`` or ``input`` key.
    """
    entries = [json.loads(raw) for raw in contents]
    # list.sort is stable, so entries sharing a timestamp keep input order.
    entries.sort(key=lambda entry: entry[CONTENT_TS_PATH])
    return [
        {
            CONTENT_ROLE_PATH: entry[CONTENT_ROLE_PATH],
            CONTENT_INPUT_PATH: entry[CONTENT_INPUT_PATH],
            CONTENT_STREAM_ID_PATH: entry.get(CONTENT_STREAM_ID_PATH, 0),
        }
        for entry in entries
    ]
|
|
|
|
|
@firestore.transactional
def update_in_transaction(transaction, doc_ref, content):
    """Apply ``content`` as a field update to ``doc_ref`` within ``transaction``.

    Wrapped by ``firestore.transactional``; callers pass a fresh transaction
    object as the first argument.

    Args:
        transaction: Firestore transaction the update is recorded on.
        doc_ref: Reference of the document to update.
        content: Mapping of field paths to new values (or transforms).
    """
    transaction.update(doc_ref, content)
|
|
|
|
|
@firestore.transactional
def read_in_transaction(transaction, doc_ref):
    """Read ``doc_ref`` inside ``transaction`` and return its fields as a dict.

    Wrapped by ``firestore.transactional``; callers pass a fresh transaction
    object as the first argument.

    Args:
        transaction: Firestore transaction used for the read.
        doc_ref: Reference of the document to fetch.

    Returns:
        Whatever ``to_dict()`` yields for the fetched snapshot.
        NOTE(review): per the Firestore client docs this is ``None`` when the
        document does not exist -- callers should be prepared for that.
    """
    return doc_ref.get(transaction=transaction).to_dict()
|
|
|
|
|
class TSDBFirestoreExtension(Extension):
    """Extension that stores final chat text in a Firestore document.

    ``on_data`` queues every final, non-empty text message; a worker thread
    (``async_handle``) appends each one, JSON-encoded, to the ``contents``
    array of the document ``<collection_name>/<channel_name>``. The
    ``retrieve`` command reads that array back ordered by timestamp; it is
    served on a dedicated asyncio loop running in its own thread.
    """

    def __init__(self, name: str):
        super().__init__(name)
        # Set by on_stop so the worker loop does not start another iteration.
        self.stopped = False
        # Worker thread running async_handle.
        self.thread = None
        # Pending (ts, text, role, stream_id) tuples; None is the stop sentinel.
        self.queue = queue.Queue()
        # Created inside the command loop thread (see __thread_routine) so the
        # event is always bound to the loop that awaits it.
        self.stopEvent = None
        # Thread hosting the asyncio loop that serves commands.
        self.cmd_thread = None
        # The asyncio loop running in cmd_thread, published by __thread_routine.
        self.loop = None
        self.credentials = None
        self.channel_name = ""
        self.collection_name = ""
        # Document time-to-live in days.
        self.ttl = DEFAULT_TTL
        self.client = None
        self.document_ref = None

        # Not used by any method of this class; kept for compatibility.
        self.current_stream_id = 0
        self.cache = ""

    async def __thread_routine(self, ten_env: TenEnv):
        """Body of the command thread.

        Publishes the running loop, reports start completion to the runtime,
        then parks until ``stop_thread`` sets the stop event.
        """
        ten_env.log_info("__thread_routine start")
        # Create the event on the loop that waits on it, before the loop is
        # made visible to other threads.
        self.stopEvent = asyncio.Event()
        self.loop = asyncio.get_running_loop()
        ten_env.on_start_done()
        await self.stopEvent.wait()

    async def stop_thread(self):
        """Release ``__thread_routine``; must be scheduled on the command loop."""
        if self.stopEvent is not None:
            self.stopEvent.set()

    def on_init(self, ten_env: TenEnv) -> None:
        """Log and acknowledge initialization."""
        ten_env.log_info("TSDBFirestoreExtension on_init")
        ten_env.on_init_done()

    def on_start(self, ten_env: TenEnv) -> None:
        """Read configuration, connect to Firestore and start both threads.

        The required properties are ``credentials``, ``channel_name`` and
        ``collection_name``; ``ttl`` (days) is optional and defaults to
        ``DEFAULT_TTL``. ``on_start_done`` is reported from the command
        thread once its loop is running.
        """
        ten_env.log_info("TSDBFirestoreExtension on_start")

        # NOTE(review): each early return below leaves on_start_done()
        # uncalled, exactly as before -- confirm that a missing required
        # property is meant to keep the extension from finishing start-up.
        try:
            self.credentials = ten_env.get_property_to_json(PROPERTY_CREDENTIALS)
        except Exception as err:
            ten_env.log_error(
                f"GetProperty required {PROPERTY_CREDENTIALS} failed, err: {err}"
            )
            return

        try:
            self.channel_name = ten_env.get_property_string(PROPERTY_CHANNEL_NAME)
        except Exception as err:
            ten_env.log_error(
                f"GetProperty required {PROPERTY_CHANNEL_NAME} failed, err: {err}"
            )
            return

        try:
            self.collection_name = ten_env.get_property_string(PROPERTY_COLLECTION_NAME)
        except Exception as err:
            ten_env.log_error(
                f"GetProperty required {PROPERTY_COLLECTION_NAME} failed, err: {err}"
            )
            return

        # Optional: honour a configured TTL; anything missing, non-numeric or
        # non-positive keeps the current (default) value.
        try:
            configured_ttl = ten_env.get_property_int(PROPERTY_TTL)
            if configured_ttl > 0:
                self.ttl = configured_ttl
            else:
                ten_env.log_info(
                    f"ignore non-positive {PROPERTY_TTL} {configured_ttl}, use {self.ttl} day(s)"
                )
        except Exception as err:
            ten_env.log_info(
                f"GetProperty optional {PROPERTY_TTL} failed, err: {err}, use {self.ttl} day(s)"
            )

        cred = credentials.Certificate(json.loads(self.credentials))
        # initialize_app raises ValueError when the default app already exists
        # (e.g. the extension is started again in the same process), so reuse
        # an existing default app instead of failing.
        try:
            firebase_admin.get_app()
        except ValueError:
            firebase_admin.initialize_app(cred)
        self.client = firestore.client()

        self.document_ref = self.client.collection(self.collection_name).document(
            self.channel_name
        )

        # Timezone-aware so the stored expiry does not depend on how a naive
        # datetime would be interpreted.
        expiration_time = datetime.datetime.now(
            datetime.timezone.utc
        ) + datetime.timedelta(days=self.ttl)
        exists = self.document_ref.get().exists
        if exists:
            self.document_ref.update({DOC_EXPIRE_PATH: expiration_time})
            ten_env.log_info(
                f"reset document ttl, {self.ttl} day(s), for the channel {self.channel_name}"
            )
        else:
            self.document_ref.set({DOC_EXPIRE_PATH: expiration_time})
            ten_env.log_info(
                f"create new document and set ttl, {self.ttl} day(s), for the channel {self.channel_name}"
            )

        # Worker thread that drains the queue into Firestore.
        self.thread = threading.Thread(target=self.async_handle, args=[ten_env])
        self.thread.start()

        # Thread hosting the asyncio loop used to serve commands.
        self.cmd_thread = threading.Thread(
            target=asyncio.run, args=(self.__thread_routine(ten_env),)
        )
        self.cmd_thread.start()

    def async_handle(self, ten_env: TenEnv) -> None:
        """Worker loop: append each queued entry to the Firestore document.

        Blocks on the queue; a ``None`` item ends the loop. A failure to
        store one entry is logged and the loop continues with the next.
        """
        while not self.stopped:
            try:
                value = self.queue.get()
                if value is None:
                    ten_env.log_info("exit handle loop")
                    break
                ts, input_path, role, stream_id = value
                content_str = json.dumps(
                    {
                        CONTENT_ROLE_PATH: role,
                        CONTENT_INPUT_PATH: input_path,
                        CONTENT_TS_PATH: ts,
                        CONTENT_STREAM_ID_PATH: stream_id,
                    }
                )
                update_in_transaction(
                    self.client.transaction(),
                    self.document_ref,
                    {DOC_CONTENTS_PATH: firestore.ArrayUnion([content_str])},
                )
                ten_env.log_info(
                    f"append {content_str} to firestore document {self.channel_name}"
                )
            except Exception as err:
                ten_env.log_error(f"Failed to store chat contents, err: {err}")

    def on_stop(self, ten_env: TenEnv) -> None:
        """Stop the worker and command threads, then acknowledge the stop.

        Entries still waiting in the queue are discarded.
        """
        ten_env.log_info("TSDBFirestoreExtension on_stop")

        self.stopped = True
        # Drain without blocking: the worker may take the last item between an
        # emptiness check and a blocking get(), which would hang shutdown.
        while True:
            try:
                self.queue.get_nowait()
            except queue.Empty:
                break
        # Wake the worker if it is blocked on an empty queue.
        self.queue.put(None)
        if self.thread is not None:
            self.thread.join()
            self.thread = None

        if self.cmd_thread is not None and self.cmd_thread.is_alive():
            asyncio.run_coroutine_threadsafe(self.stop_thread(), self.loop)
            self.cmd_thread.join()
        self.cmd_thread = None

        ten_env.on_stop_done()

    def on_deinit(self, ten_env: TenEnv) -> None:
        """Log and acknowledge de-initialization."""
        ten_env.log_info("TSDBFirestoreExtension on_deinit")
        ten_env.on_deinit_done()

    def on_cmd(self, ten_env: TenEnv, cmd: Cmd) -> None:
        """Dispatch a command.

        ``retrieve`` is scheduled on the command loop; any other command, and
        any failure while dispatching, is answered with an ERROR result.
        """
        try:
            cmd_name = cmd.get_name()
            ten_env.log_info(f"on_cmd name {cmd_name}")
            if cmd_name == RETRIEVE_CMD:
                asyncio.run_coroutine_threadsafe(self.retrieve(ten_env, cmd), self.loop)
            else:
                ten_env.log_info(f"unknown cmd name {cmd_name}")
                cmd_result = CmdResult.create(StatusCode.ERROR)
                ten_env.return_result(cmd_result, cmd)
        except Exception as err:
            ten_env.log_error(f"on_cmd failed, err: {err}")
            ten_env.return_result(CmdResult.create(StatusCode.ERROR), cmd)

    async def retrieve(self, ten_env: TenEnv, cmd: Cmd):
        """Answer ``cmd`` with the stored contents ordered by timestamp.

        On success the result is OK and its ``response`` property holds a JSON
        array of ``{role, input, stream_id}`` objects. When the document has
        no contents, or reading fails, the result is ERROR.
        """
        try:
            doc_dict = read_in_transaction(self.client.transaction(), self.document_ref)
            # doc_dict may be None when the document is missing; treat that
            # the same as a document without contents.
            if doc_dict is not None and DOC_CONTENTS_PATH in doc_dict:
                contents = doc_dict[DOC_CONTENTS_PATH]
                ten_env.log_info(f"after retrieve {contents}")
                ret = CmdResult.create(StatusCode.OK)
                ret.set_property_string(
                    CMD_OUT_PROPERTY_RESPONSE, json.dumps(order_by_ts(contents))
                )
                ten_env.return_result(ret, cmd)
            else:
                ten_env.log_info(f"no contents for the channel {self.channel_name} yet")
                ten_env.return_result(CmdResult.create(StatusCode.ERROR), cmd)
        except Exception as err:
            ten_env.log_error(
                f"Failed to read the document for the channel {self.channel_name}, err: {err}"
            )
            ten_env.return_result(CmdResult.create(StatusCode.ERROR), cmd)

    def on_data(self, ten_env: TenEnv, data: Data) -> None:
        """Queue a final, non-empty text message for storage.

        Messages explicitly marked non-final, and messages whose ``text`` or
        ``role`` is empty or unreadable, are ignored. An unreadable
        ``is_final`` does not stop processing; an unreadable ``stream_id``
        falls back to 0.
        """
        ten_env.log_info("TSDBFirestoreExtension on_data")

        is_final = False
        try:
            is_final = data.get_property_bool(DATA_IN_TEXT_DATA_PROPERTY_IS_FINAL)
            if not is_final:
                ten_env.log_info("ignore non-final input")
                return
        except Exception as err:
            ten_env.log_info(
                f"OnData GetProperty {DATA_IN_TEXT_DATA_PROPERTY_IS_FINAL} failed, err: {err}"
            )

        stream_id = 0
        try:
            # The stream id is numeric (it defaults to 0), so read it as an
            # int rather than a bool.
            stream_id = data.get_property_int(DATA_IN_TEXT_DATA_PROPERTY_STREAM_ID)
        except Exception as err:
            ten_env.log_info(
                f"OnData GetProperty {DATA_IN_TEXT_DATA_PROPERTY_STREAM_ID} failed, err: {err}"
            )

        try:
            input_text = data.get_property_string(DATA_IN_TEXT_DATA_PROPERTY_TEXT)
            if not input_text:
                ten_env.log_info("ignore empty text")
                return
            ten_env.log_info(f"OnData input text: [{input_text}]")
        except Exception as err:
            ten_env.log_info(
                f"OnData GetProperty {DATA_IN_TEXT_DATA_PROPERTY_TEXT} failed, err: {err}"
            )
            return

        try:
            role = data.get_property_string(DATA_IN_TEXT_DATA_PROPERTY_ROLE)
            if not role:
                ten_env.log_warn("ignore empty role")
                return
        except Exception as err:
            ten_env.log_info(
                f"OnData GetProperty {DATA_IN_TEXT_DATA_PROPERTY_ROLE} failed, err: {err}"
            )
            return

        ts = get_current_time()
        self.queue.put((ts, input_text, role, stream_id))

    def on_audio_frame(self, ten_env: TenEnv, audio_frame: AudioFrame) -> None:
        """Audio frames are ignored."""
        pass

    def on_video_frame(self, ten_env: TenEnv, video_frame: VideoFrame) -> None:
        """Video frames are ignored."""
        pass
|
|