File size: 5,920 Bytes
87337b1 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 |
from ten import (
AsyncExtension,
AsyncTenEnv,
Cmd,
Data,
AudioFrame,
StatusCode,
CmdResult,
)
import asyncio
from deepgram import (
AsyncListenWebSocketClient,
DeepgramClientOptions,
LiveTranscriptionEvents,
LiveOptions,
)
from dataclasses import dataclass
from ten_ai_base.config import BaseConfig
# Property names written on the outgoing "text_data" message built in
# DeepgramASRExtension._send_text.
# Transcript text (string property).
DATA_OUT_TEXT_DATA_PROPERTY_TEXT = "text"
# Whether the transcript is final (bool property).
DATA_OUT_TEXT_DATA_PROPERTY_IS_FINAL = "is_final"
# Stream id taken from the most recent audio frame (int property).
DATA_OUT_TEXT_DATA_PROPERTY_STREAM_ID = "stream_id"
# End-of-segment flag (bool property); _send_text sets it to the same value
# as "is_final".
DATA_OUT_TEXT_DATA_PROPERTY_END_OF_SEGMENT = "end_of_segment"
@dataclass
class DeepgramASRConfig(BaseConfig):
    """Settings used by DeepgramASRExtension.

    An instance is obtained in ``on_start`` through
    ``DeepgramASRConfig.create_async(ten_env=...)`` (inherited from
    ``BaseConfig``; presumably it fills the fields from the extension's
    properties -- confirm against ``ten_ai_base``). Every field except
    ``api_key`` is passed unchanged to Deepgram's ``LiveOptions``.
    """

    # Deepgram API key; on_start does not start listening when this is empty.
    api_key: str = ""
    # Forwarded as LiveOptions(language=...).
    language: str = "en-US"
    # Forwarded as LiveOptions(model=...).
    model: str = "nova-2"
    # Forwarded as LiveOptions(sample_rate=...).
    sample_rate: int = 16000
    # Forwarded as LiveOptions(channels=...).
    channels: int = 1
    # Forwarded as LiveOptions(encoding=...).
    encoding: str = "linear16"
    # Forwarded as LiveOptions(interim_results=...).
    interim_results: bool = True
    # Forwarded as LiveOptions(punctuate=...).
    punctuate: bool = True
class DeepgramASRExtension(AsyncExtension):
    """Forward audio frames to Deepgram live transcription and emit text.

    ``on_start`` loads the configuration and, when an API key is present,
    opens a Deepgram websocket client in a background task. Each incoming
    audio frame buffer is sent to that client, and every non-empty transcript
    is published as a ``text_data`` message. After an unexpected close or a
    failed connect the connection is retried until ``on_stop`` has been
    called.
    """

    # Pause before a reconnect attempt, in seconds.
    _RECONNECT_DELAY_SECONDS = 0.2

    def __init__(self, name: str):
        """Initialise idle state; the real setup happens in ``on_start``."""
        super().__init__(name)
        self.stopped = False
        self.connected = False
        self.client: AsyncListenWebSocketClient = None
        self.config: DeepgramASRConfig = None
        self.ten_env: AsyncTenEnv = None
        self.loop = None
        self.stream_id = -1
        # Strong references to in-flight background tasks. The event loop
        # only keeps weak references, so an unreferenced task may be
        # garbage-collected before it finishes.
        self._background_tasks = set()

    def _spawn(self, coro) -> None:
        """Run *coro* as a background task and keep it alive until done."""
        task = asyncio.create_task(coro)
        self._background_tasks.add(task)
        task.add_done_callback(self._on_background_task_done)

    def _on_background_task_done(self, task) -> None:
        """Drop the finished task and log its exception, if any."""
        self._background_tasks.discard(task)
        if task.cancelled():
            return
        exc = task.exception()
        if exc is not None and self.ten_env is not None:
            self.ten_env.log_error(f"deepgram background task failed: {exc}")

    async def on_init(self, ten_env: AsyncTenEnv) -> None:
        """Log that the extension is being initialised."""
        ten_env.log_info("DeepgramASRExtension on_init")

    async def on_start(self, ten_env: AsyncTenEnv) -> None:
        """Load the config and start the Deepgram listener in the background.

        Nothing is started when the configured ``api_key`` is empty; an
        error is logged instead.
        """
        ten_env.log_info("on_start")
        # Inside a coroutine the running loop is the one to use.
        self.loop = asyncio.get_running_loop()
        self.ten_env = ten_env

        self.config = await DeepgramASRConfig.create_async(ten_env=ten_env)
        ten_env.log_info(f"config: {self.config}")

        if not self.config.api_key:
            ten_env.log_error("get property api_key")
            return

        self._spawn(self._start_listen())
        ten_env.log_info("starting async_deepgram_wrapper thread")

    async def on_audio_frame(self, _: AsyncTenEnv, frame: AudioFrame) -> None:
        """Send one audio frame buffer to Deepgram.

        Empty buffers and frames that arrive while the websocket is not
        connected are dropped. The frame's ``stream_id`` property is
        remembered so transcripts can be tagged with it.
        """
        frame_buf = frame.get_buf()
        if not frame_buf:
            self.ten_env.log_warn("send_frame: empty pcm_frame detected.")
            return

        if not self.connected:
            self.ten_env.log_debug("send_frame: deepgram not connected.")
            return

        self.stream_id = frame.get_property_int("stream_id")
        if self.client:
            await self.client.send(frame_buf)

    async def on_stop(self, ten_env: AsyncTenEnv) -> None:
        """Mark the extension as stopped and finish the Deepgram client."""
        ten_env.log_info("on_stop")
        # Set the flag first so close/failed-connect handlers do not retry.
        self.stopped = True
        if self.client:
            await self.client.finish()

    async def on_cmd(self, ten_env: AsyncTenEnv, cmd: Cmd) -> None:
        """Log the command and answer it with an OK result."""
        cmd_json = cmd.to_json()
        ten_env.log_info(f"on_cmd json: {cmd_json}")

        cmd_result = CmdResult.create(StatusCode.OK)
        cmd_result.set_property_string("detail", "success")
        await ten_env.return_result(cmd_result, cmd)

    async def _start_listen(self) -> None:
        """Create a Deepgram client, register handlers and connect.

        Returns immediately when the extension has already been stopped, so
        a reconnect scheduled just before ``on_stop`` does not open a new
        connection. A failed connect is retried after a short delay unless
        the extension was stopped in the meantime.
        """
        if self.stopped:
            return

        self.ten_env.log_info("start and listen deepgram")

        self.client = AsyncListenWebSocketClient(
            config=DeepgramClientOptions(
                api_key=self.config.api_key, options={"keepalive": "true"}
            )
        )

        async def on_open(_, event):
            self.ten_env.log_info(f"deepgram event callback on_open: {event}")
            self.connected = True

        async def on_close(_, event):
            self.ten_env.log_info(f"deepgram event callback on_close: {event}")
            self.connected = False
            if self.stopped:
                # Expected close triggered by on_stop: do not reconnect.
                return
            self.ten_env.log_warn(
                "Deepgram connection closed unexpectedly. Reconnecting..."
            )
            await asyncio.sleep(self._RECONNECT_DELAY_SECONDS)
            self._spawn(self._start_listen())

        async def on_message(_, result):
            alternatives = result.channel.alternatives
            if not alternatives:
                # Nothing to index; avoids an IndexError on an empty list.
                return
            sentence = alternatives[0].transcript
            if not sentence:
                return

            is_final = result.is_final
            self.ten_env.log_info(
                f"deepgram got sentence: [{sentence}], is_final: {is_final}, stream_id: {self.stream_id}"
            )

            await self._send_text(
                text=sentence, is_final=is_final, stream_id=self.stream_id
            )

        async def on_error(_, error):
            self.ten_env.log_error(f"deepgram event callback on_error: {error}")

        self.client.on(LiveTranscriptionEvents.Open, on_open)
        self.client.on(LiveTranscriptionEvents.Close, on_close)
        self.client.on(LiveTranscriptionEvents.Transcript, on_message)
        self.client.on(LiveTranscriptionEvents.Error, on_error)

        options = LiveOptions(
            language=self.config.language,
            model=self.config.model,
            sample_rate=self.config.sample_rate,
            channels=self.config.channels,
            encoding=self.config.encoding,
            interim_results=self.config.interim_results,
            punctuate=self.config.punctuate,
        )

        self.ten_env.log_info(f"deepgram options: {options}")

        # connect to websocket
        result = await self.client.start(options)
        if result:
            self.ten_env.log_info("successfully connected to deepgram")
            return

        self.ten_env.log_error("failed to connect to deepgram")
        if self.stopped:
            # Do not keep retrying once the extension has been stopped.
            return
        await asyncio.sleep(self._RECONNECT_DELAY_SECONDS)
        self._spawn(self._start_listen())

    async def _send_text(self, text: str, is_final: bool, stream_id: int) -> None:
        """Publish a transcript as a ``text_data`` message.

        Args:
            text: Transcript text.
            is_final: Whether the transcript is final; also used as the
                ``end_of_segment`` value.
            stream_id: Integer stream id written with ``set_property_int``.
        """
        stable_data = Data.create("text_data")
        stable_data.set_property_bool(DATA_OUT_TEXT_DATA_PROPERTY_IS_FINAL, is_final)
        stable_data.set_property_string(DATA_OUT_TEXT_DATA_PROPERTY_TEXT, text)
        stable_data.set_property_int(DATA_OUT_TEXT_DATA_PROPERTY_STREAM_ID, stream_id)
        stable_data.set_property_bool(
            DATA_OUT_TEXT_DATA_PROPERTY_END_OF_SEGMENT, is_final
        )
        # Sent in the background as before, but with a retained reference so
        # the task cannot be collected before the data has been sent.
        self._spawn(self.ten_env.send_data(stable_data))
|