3v324v23's picture
Зафиксирована рабочая версия TEN-Agent для HuggingFace Space
87337b1
import asyncio
from ten import TenEnv, Data
from amazon_transcribe.auth import StaticCredentialResolver
from amazon_transcribe.client import TranscribeStreamingClient
from amazon_transcribe.handlers import TranscriptResultStreamHandler
from amazon_transcribe.model import (
TranscriptEvent,
TranscriptResultStream,
StartStreamTranscriptionEventStream,
)
from .transcribe_config import TranscribeConfig
# Property names set on the outgoing "text_data" Data object
# (see create_and_send_data).
# Transcribed text (string property).
DATA_OUT_TEXT_DATA_PROPERTY_TEXT = "text"
# Whether the transcript is final rather than partial (bool property).
DATA_OUT_TEXT_DATA_PROPERTY_IS_FINAL = "is_final"
# Stream identifier the transcript belongs to (int property).
DATA_OUT_TEXT_DATA_PROPERTY_STREAM_ID = "stream_id"
# End-of-segment flag (bool property); currently set to the same value as "is_final".
DATA_OUT_TEXT_DATA_PROPERTY_END_OF_SEGMENT = "end_of_segment"
def create_and_send_data(ten: TenEnv, text_result: str, is_final: bool, stream_id: int = 0):
    """Build a "text_data" message from a transcript and send it through *ten*.

    Args:
        ten: Environment used to send the resulting data object.
        text_result: Transcribed text to publish.
        is_final: Whether the transcript is final; also used as the
            end-of-segment flag.
        stream_id: Identifier of the stream the transcript belongs to.
    """
    # A final transcript is treated as the end of its segment.
    end_of_segment = is_final

    message = Data.create("text_data")
    message.set_property_bool(DATA_OUT_TEXT_DATA_PROPERTY_IS_FINAL, is_final)
    message.set_property_string(DATA_OUT_TEXT_DATA_PROPERTY_TEXT, text_result)
    message.set_property_int(DATA_OUT_TEXT_DATA_PROPERTY_STREAM_ID, stream_id)
    message.set_property_bool(
        DATA_OUT_TEXT_DATA_PROPERTY_END_OF_SEGMENT, end_of_segment
    )
    ten.send_data(message)
class AsyncTranscribeWrapper:
    """Forward PCM frames from an asyncio queue to an Amazon Transcribe stream.

    The transcription stream is created lazily when the first non-empty frame
    arrives and is closed again after the queue has been idle for
    ``IDLE_TIMEOUT_SECONDS``. Transcripts are delivered by a
    ``TranscribeEventHandler`` task started alongside each stream.

    ``run()`` drives the whole pipeline on the event loop passed to the
    constructor and blocks until the loop finishes; ``stop()`` asks the
    pipeline to shut down.
    """

    # Seconds to wait for a frame before the current stream is considered idle.
    IDLE_TIMEOUT_SECONDS = 3.0

    def __init__(
        self,
        config: TranscribeConfig,
        queue: asyncio.Queue,
        ten: TenEnv,
        loop: asyncio.BaseEventLoop,
    ):
        """Create the Transcribe client and bind the wrapper to *loop*.

        Args:
            config: Transcribe settings (region, credentials, language code,
                sample rate, media encoding).
            queue: Queue of PCM frames; a ``None`` item makes ``send_frame``
                exit.
            ten: Environment used for logging and for sending results.
            loop: Event loop on which ``run()`` executes the pipeline.
        """
        self.queue = queue
        self.ten = ten
        self.stopped = False
        self.config = config
        self.loop = loop
        self.stream = None
        self.handler = None
        self.event_handler_task = None

        if config.access_key and config.secret_key:
            # Do not write credential material (not even the key id) to logs.
            ten.log_info("init transcribe client with static access key credentials.")
            self.transcribe_client = TranscribeStreamingClient(
                region=config.region,
                credential_resolver=StaticCredentialResolver(
                    access_key_id=config.access_key,
                    secret_access_key=config.secret_key,
                ),
            )
        else:
            ten.log_info(
                "init trascribe client without access key, using default credentials provider chain."
            )
            self.transcribe_client = TranscribeStreamingClient(region=config.region)

        asyncio.set_event_loop(self.loop)
        self.reset_stream()

    def reset_stream(self):
        """Forget the current stream, its handler and the handler task."""
        self.stream = None
        self.handler = None
        self.event_handler_task = None

    async def cleanup(self):
        """End the current stream (if any) and wait for its event handler.

        The stream state is always reset, even when ending the stream or
        awaiting the handler raises, so a later ``cleanup()`` never operates
        on the same stream twice. Exceptions still propagate to the caller.
        """
        try:
            if self.stream:
                await self.stream.input_stream.end_stream()
                self.ten.log_info("cleanup: stream ended.")
            if self.event_handler_task:
                await self.event_handler_task
                self.ten.log_info("cleanup: event handler ended.")
        finally:
            self.reset_stream()

    async def create_stream(self, stream_id) -> bool:
        """Open a transcription stream and start its event-handler task.

        Args:
            stream_id: Identifier passed to the event handler and attached to
                every transcript it emits.

        Returns:
            True on success; False if anything raised (the error is logged).
        """
        try:
            self.stream = await self.get_transcribe_stream()
            self.handler = TranscribeEventHandler(
                self.stream.output_stream, self.ten, stream_id
            )
            self.event_handler_task = asyncio.create_task(
                self.handler.handle_events()
            )
        except Exception as e:
            self.ten.log_error(str(e))
            return False
        return True

    async def send_frame(self) -> None:
        """Pump frames from the queue into the stream until stopped.

        Exits when ``stop()`` has been called or a ``None`` item is dequeued.
        Every dequeued item is acknowledged with ``queue.task_done()`` exactly
        once, whatever happens to it, so ``queue.join()`` cannot hang on
        skipped or failed frames.

        Raises:
            Exception: Any non-I/O error is logged and re-raised.
        """
        while not self.stopped:
            # True once an item has actually been taken off the queue.
            frame_taken = False
            try:
                pcm_frame = await asyncio.wait_for(
                    self.queue.get(), timeout=self.IDLE_TIMEOUT_SECONDS
                )
                frame_taken = True

                if pcm_frame is None:
                    self.ten.log_warn("send_frame: exit due to None value got.")
                    return

                frame_buf = pcm_frame.get_buf()
                if not frame_buf:
                    self.ten.log_warn("send_frame: empty pcm_frame detected.")
                    continue

                stream_id = pcm_frame.get_property_int("stream_id")
                if not self.stream:
                    self.ten.log_info("lazy init stream.")
                    if not await self.create_stream(stream_id):
                        # Stream could not be opened; drop this frame and
                        # retry with the next one.
                        continue

                await self.stream.input_stream.send_audio_event(audio_chunk=frame_buf)
            # Must precede IOError: on Python 3.11+ asyncio.TimeoutError is the
            # builtin TimeoutError, which is an OSError subclass.
            except asyncio.TimeoutError:
                if self.stream:
                    await self.cleanup()
                    self.ten.log_info(
                        f"send_frame: no data for {self.IDLE_TIMEOUT_SECONDS:g}s, "
                        "will close current stream and create a new one when receiving new frame."
                    )
                else:
                    self.ten.log_info("send_frame: waiting for pcm frame.")
            except IOError as e:
                # I/O errors are logged and the loop keeps going.
                self.ten.log_error(f"Error in send_frame: {e}")
            except Exception as e:
                self.ten.log_error(f"Error in send_frame: {e}")
                raise
            finally:
                if frame_taken:
                    self.queue.task_done()

        self.ten.log_info("send_frame: exit due to self.stopped == True")

    async def transcribe_loop(self) -> None:
        """Run ``send_frame`` to completion, then release the stream."""
        try:
            await self.send_frame()
        except Exception as e:
            self.ten.log_error(str(e))
        finally:
            await self.cleanup()

    async def get_transcribe_stream(self) -> StartStreamTranscriptionEventStream:
        """Start a streaming transcription using the configured parameters."""
        stream = await self.transcribe_client.start_stream_transcription(
            language_code=self.config.lang_code,
            media_sample_rate_hz=self.config.sample_rate,
            media_encoding=self.config.media_encoding,
        )
        return stream

    def run(self) -> None:
        """Run the transcription pipeline on the event loop until it finishes.

        Blocks the calling thread. The loop is closed even if the pipeline
        raises (for example from the final cleanup).
        """
        try:
            self.loop.run_until_complete(self.transcribe_loop())
        finally:
            self.loop.close()
        self.ten.log_info("async_transcribe_wrapper: thread completed.")

    def stop(self) -> None:
        """Request shutdown; ``send_frame`` exits on its next loop iteration."""
        self.stopped = True
class TranscribeEventHandler(TranscriptResultStreamHandler):
    """Turn transcript events from a result stream into "text_data" messages."""

    def __init__(self, transcript_result_stream: TranscriptResultStream, ten: TenEnv, stream_id: int = 0):
        """Remember the environment and stream id used for outgoing data.

        Args:
            transcript_result_stream: Output stream handed to the base handler.
            ten: Environment used for logging and for sending transcripts.
            stream_id: Identifier attached to every transcript that is sent.
        """
        super().__init__(transcript_result_stream)
        self.stream_id = stream_id
        self.ten = ten

    async def handle_transcript_event(self, transcript_event: TranscriptEvent) -> None:
        """Concatenate the event's transcripts and forward them if non-empty.

        The text of every alternative of every result is joined in order.
        The transcript counts as final only when no result is partial.
        """
        segments = transcript_event.transcript.results

        # Final only if none of the results is still partial.
        final = not any(segment.is_partial for segment in segments)
        transcript_text = "".join(
            alternative.transcript
            for segment in segments
            for alternative in segment.alternatives
        )

        if not transcript_text:
            return

        self.ten.log_info(f"got transcript: [{transcript_text}], is_final: [{final}]")
        create_and_send_data(
            ten=self.ten,
            text_result=transcript_text,
            is_final=final,
            stream_id=self.stream_id,
        )