|
import asyncio |
|
|
|
from ten import TenEnv, Data |
|
|
|
from amazon_transcribe.auth import StaticCredentialResolver |
|
from amazon_transcribe.client import TranscribeStreamingClient |
|
from amazon_transcribe.handlers import TranscriptResultStreamHandler |
|
from amazon_transcribe.model import ( |
|
TranscriptEvent, |
|
TranscriptResultStream, |
|
StartStreamTranscriptionEventStream, |
|
) |
|
|
|
from .transcribe_config import TranscribeConfig |
|
|
|
# Property names written on the outgoing "text_data" message
# (see create_and_send_data for how each one is populated).
# Transcript text (string property).
DATA_OUT_TEXT_DATA_PROPERTY_TEXT = "text"
# Whether the transcript is final (bool property).
DATA_OUT_TEXT_DATA_PROPERTY_IS_FINAL = "is_final"
# Id of the audio stream the transcript belongs to (int property).
DATA_OUT_TEXT_DATA_PROPERTY_STREAM_ID = "stream_id"
# End-of-segment marker (bool property); currently set to the same value as "is_final".
DATA_OUT_TEXT_DATA_PROPERTY_END_OF_SEGMENT = "end_of_segment"
|
|
|
def create_and_send_data(ten: TenEnv, text_result: str, is_final: bool, stream_id: int = 0):
    """Wrap a transcript in a ``text_data`` message and send it through ``ten``.

    Args:
        ten: Environment object whose ``send_data`` delivers the message.
        text_result: Transcript text to publish.
        is_final: Stored as both the ``is_final`` and the ``end_of_segment``
            property of the message.
        stream_id: Identifier stored in the ``stream_id`` property.
    """
    message = Data.create("text_data")

    # (setter, property name, value) triples, applied in this order.
    assignments = (
        (message.set_property_bool, DATA_OUT_TEXT_DATA_PROPERTY_IS_FINAL, is_final),
        (message.set_property_string, DATA_OUT_TEXT_DATA_PROPERTY_TEXT, text_result),
        (message.set_property_int, DATA_OUT_TEXT_DATA_PROPERTY_STREAM_ID, stream_id),
        (message.set_property_bool, DATA_OUT_TEXT_DATA_PROPERTY_END_OF_SEGMENT, is_final),
    )
    for apply_property, property_name, value in assignments:
        apply_property(property_name, value)

    ten.send_data(message)
|
|
|
|
|
class AsyncTranscribeWrapper:
    """Pump queued PCM frames into an Amazon Transcribe streaming session.

    Frames are taken from ``queue``. A transcription stream is opened lazily
    when the first non-empty frame arrives and is closed again once the queue
    has been idle for ``FRAME_WAIT_TIMEOUT_S`` seconds. ``run()`` drives the
    whole thing on ``loop`` until a ``None`` item is dequeued, ``stop()`` has
    been observed, or an unexpected error ends the loop.
    """

    # Seconds to wait for the next frame before the stream is considered idle.
    FRAME_WAIT_TIMEOUT_S = 3.0

    def __init__(
        self,
        config: TranscribeConfig,
        queue: asyncio.Queue,
        ten: TenEnv,
        loop: asyncio.AbstractEventLoop,
    ):
        """Create the Transcribe client and bind the wrapper to ``loop``.

        Args:
            config: Provides region, optional static credentials and the
                stream parameters (language code, sample rate, encoding).
            queue: Source of PCM frames; a ``None`` item ends ``send_frame``.
            ten: Used for logging and handed to the event handler.
            loop: Event loop that ``run()`` executes and finally closes.
        """
        self.queue = queue
        self.ten = ten
        self.stopped = False
        self.config = config
        self.loop = loop
        self.stream = None
        self.handler = None
        self.event_handler_task = None

        if config.access_key and config.secret_key:
            # Only a short suffix of the key id is logged; credentials do not
            # belong in log files.
            masked_key = f"****{config.access_key[-4:]}"
            ten.log_info(f"init trascribe client with access key: {masked_key}")
            self.transcribe_client = TranscribeStreamingClient(
                region=config.region,
                credential_resolver=StaticCredentialResolver(
                    access_key_id=config.access_key,
                    secret_access_key=config.secret_key,
                ),
            )
        else:
            ten.log_info(
                "init trascribe client without access key, using default credentials provider chain."
            )
            self.transcribe_client = TranscribeStreamingClient(region=config.region)

        asyncio.set_event_loop(self.loop)
        self.reset_stream()

    def reset_stream(self):
        """Forget the current stream, its handler and the handler task."""
        self.stream = None
        self.handler = None
        self.event_handler_task = None

    async def cleanup(self):
        """End the current stream, wait for its event handler, then reset state.

        The state is reset even if ending the stream or awaiting the handler
        raises, so a failed stream is never reused; the exception itself still
        propagates to the caller.
        """
        stream = self.stream
        handler_task = self.event_handler_task
        try:
            if stream:
                await stream.input_stream.end_stream()
                self.ten.log_info("cleanup: stream ended.")

            if handler_task:
                await handler_task
                self.ten.log_info("cleanup: event handler ended.")
        except Exception:
            # The stream is being abandoned; do not leave its handler pending.
            if handler_task and not handler_task.done():
                handler_task.cancel()
            raise
        finally:
            self.reset_stream()

    async def create_stream(self, stream_id) -> bool:
        """Open a transcription stream and start its event handler task.

        State is committed to ``self`` only when every step succeeded, so a
        partial failure cannot leave a stream without a handler behind.

        Args:
            stream_id: Passed to the event handler, which attaches it to the
                text data it sends.

        Returns:
            True on success; False if any step raised (the error is logged).
        """
        stream = None
        try:
            stream = await self.get_transcribe_stream()
            handler = TranscribeEventHandler(stream.output_stream, self.ten, stream_id)
            handler_task = asyncio.create_task(handler.handle_events())
        except Exception as e:
            self.ten.log_error(str(e))
            if stream is not None:
                # Best effort: close the stream that was opened before the failure.
                try:
                    await stream.input_stream.end_stream()
                except Exception as close_error:
                    self.ten.log_error(str(close_error))
            return False

        self.stream = stream
        self.handler = handler
        self.event_handler_task = handler_task
        return True

    async def send_frame(self) -> None:
        """Forward frames from the queue to the stream until stopped.

        Returns when a ``None`` item is dequeued or when ``self.stopped`` is
        seen at the top of an iteration. ``IOError`` is logged and the loop
        continues; any other exception is logged and re-raised. Every item
        taken from the queue is acknowledged with ``task_done()`` exactly
        once, whichever path handles it.
        """
        while not self.stopped:
            dequeued = False
            try:
                pcm_frame = await asyncio.wait_for(
                    self.queue.get(), timeout=self.FRAME_WAIT_TIMEOUT_S
                )
                dequeued = True

                if pcm_frame is None:
                    self.ten.log_warn("send_frame: exit due to None value got.")
                    return

                frame_buf = pcm_frame.get_buf()
                if not frame_buf:
                    self.ten.log_warn("send_frame: empty pcm_frame detected.")
                    continue

                stream_id = pcm_frame.get_property_int("stream_id")
                if not self.stream:
                    self.ten.log_info("lazy init stream.")
                    if not await self.create_stream(stream_id):
                        # The frame is dropped; the next one retries the stream.
                        continue

                await self.stream.input_stream.send_audio_event(audio_chunk=frame_buf)
            except asyncio.TimeoutError:
                if self.stream:
                    await self.cleanup()
                    self.ten.log_info(
                        f"send_frame: no data for {self.FRAME_WAIT_TIMEOUT_S:g}s, will close current stream and create a new one when receving new frame."
                    )
                else:
                    self.ten.log_info("send_frame: waiting for pcm frame.")
            except IOError as e:
                self.ten.log_error(f"Error in send_frame: {e}")
            except Exception as e:
                self.ten.log_error(f"Error in send_frame: {e}")
                raise
            finally:
                # Acknowledge the item on every exit path (return, continue,
                # error) so queue.join() on the producer side can complete.
                if dequeued:
                    self.queue.task_done()

        self.ten.log_info("send_frame: exit due to self.stopped == True")

    async def transcribe_loop(self) -> None:
        """Run ``send_frame`` to completion, log any error, then clean up."""
        try:
            await self.send_frame()
        except Exception as e:
            self.ten.log_error(str(e))
        finally:
            await self.cleanup()

    async def get_transcribe_stream(self) -> StartStreamTranscriptionEventStream:
        """Start a streaming transcription using the parameters in ``config``."""
        stream = await self.transcribe_client.start_stream_transcription(
            language_code=self.config.lang_code,
            media_sample_rate_hz=self.config.sample_rate,
            media_encoding=self.config.media_encoding,
        )
        return stream

    def run(self) -> None:
        """Run the transcribe loop on ``self.loop`` and close the loop afterwards.

        The loop is closed even if the coroutine raises.
        """
        try:
            self.loop.run_until_complete(self.transcribe_loop())
        finally:
            self.loop.close()
        self.ten.log_info("async_transcribe_wrapper: thread completed.")

    def stop(self) -> None:
        """Ask ``send_frame`` to exit.

        The flag is checked at the top of each iteration, so the loop may
        take up to ``FRAME_WAIT_TIMEOUT_S`` seconds to notice it when idle.
        """
        self.stopped = True
|
|
|
|
|
class TranscribeEventHandler(TranscriptResultStreamHandler):
    """Result-stream handler that republishes transcripts as ``text_data``."""

    def __init__(self, transcript_result_stream: TranscriptResultStream, ten: TenEnv, stream_id: int = 0):
        """Remember the TEN environment and the stream id used for outgoing data.

        Args:
            transcript_result_stream: Output stream handed to the base handler.
            ten: Used for logging and for sending the text data.
            stream_id: Attached to every message this handler sends.
        """
        super().__init__(transcript_result_stream)
        self.stream_id = stream_id
        self.ten = ten

    async def handle_transcript_event(self, transcript_event: TranscriptEvent) -> None:
        """Concatenate the transcripts in one event and send them on.

        The text of every alternative of every result is joined in order.
        The message is marked final only if no result in the event is
        partial. Events that yield no text are ignored.
        """
        fragments = []
        saw_partial = False

        for entry in transcript_event.transcript.results:
            if entry.is_partial:
                saw_partial = True
            fragments.extend(candidate.transcript for candidate in entry.alternatives)

        text_result = "".join(fragments)
        if not text_result:
            return

        is_final = not saw_partial
        self.ten.log_info(f"got transcript: [{text_result}], is_final: [{is_final}]")

        create_and_send_data(
            ten=self.ten,
            text_result=text_result,
            is_final=is_final,
            stream_id=self.stream_id,
        )
|
|