import asyncio
import time
import traceback
from dataclasses import dataclass
from enum import Enum
from typing import Optional, List, Dict

import boto3

from ten import (
    AsyncTenEnv,
    Cmd,
    StatusCode,
    CmdResult,
    Data,
)
from ten_ai_base.config import BaseConfig
from ten_ai_base.llm import AsyncLLMBaseExtension

from .utils import (
    rgb2base64jpeg,
    filter_images,
    parse_sentence,
    get_greeting_text,
    merge_images,
)

# Rolling frame buffer: at most MAX_IMAGE_COUNT frames are retained, and up to
# ONE_BATCH_SEND_COUNT of them are sent with a single model request. Frames are
# consumed from the queue every VIDEO_FRAME_INTERVAL seconds.
MAX_IMAGE_COUNT = 20
ONE_BATCH_SEND_COUNT = 6
VIDEO_FRAME_INTERVAL = 0.5

CMD_IN_FLUSH = "flush"
CMD_IN_ON_USER_JOINED = "on_user_joined"
CMD_IN_ON_USER_LEFT = "on_user_left"
CMD_OUT_FLUSH = "flush"

DATA_IN_TEXT_DATA_PROPERTY_IS_FINAL = "is_final"
DATA_IN_TEXT_DATA_PROPERTY_TEXT = "text"
DATA_OUT_TEXT_DATA_PROPERTY_TEXT = "text"
DATA_OUT_TEXT_DATA_PROPERTY_TEXT_END_OF_SEGMENT = "end_of_segment"


class Role(str, Enum):
    """Role definitions for chat participants."""

    User = "user"
    Assistant = "assistant"
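

# A minimal sketch of how this extension might be configured. The field names
# below mirror the BedrockLLMConfig dataclass; the exact property layout depends
# on how BaseConfig.create_async resolves properties in your TEN graph, so treat
# this as illustrative rather than authoritative:
#
#   {
#     "region": "us-east-1",
#     "model_id": "us.amazon.nova-lite-v1:0",
#     "access_key_id": "<AWS_ACCESS_KEY_ID>",
#     "secret_access_key": "<AWS_SECRET_ACCESS_KEY>",
#     "is_enable_video": true,
#     "is_memory_enabled": true
#   }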
@dataclass
class BedrockLLMConfig(BaseConfig):
    """Configuration for the BedrockV2V extension."""

    region: str = "us-east-1"
    model_id: str = "us.amazon.nova-lite-v1:0"
    access_key_id: str = ""
    secret_access_key: str = ""
    language: str = "en-US"
    prompt: str = (
        "You are an intelligent assistant with real-time interaction"
        " capabilities. You will be presented with a series of images that"
        " represent a video sequence. Describe what you see directly, as if"
        " you were observing the scene in real time. Do not mention that you"
        " are looking at images or a video. Instead, narrate the scene and"
        " actions as they unfold. Engage in conversation with the user based"
        " on this visual input and their questions, keeping your responses"
        " concise and clear."
    )
    temperature: float = 0.7
    max_tokens: int = 256
    topP: float = 0.5  # nucleus sampling cap for the Bedrock inference config
    topK: int = 10  # top-k sampling, passed via additionalModelRequestFields
    max_duration: int = 30
    vendor: str = ""
    stream_id: int = 0
    dump: bool = False
    max_memory_length: int = 10
    is_memory_enabled: bool = False
    is_enable_video: bool = False
    greeting: str = "Hello, I'm here to help you. How can I assist you today?"

    def build_ctx(self) -> dict:
        """Build a context dictionary from the configuration."""
        return {
            "language": self.language,
            "model": self.model_id,
        }


class BedrockLLMExtension(AsyncLLMBaseExtension):
    """Extension for handling video-to-video processing using AWS Bedrock."""

    def __init__(self, name: str):
        super().__init__(name)
        self.config: Optional[BedrockLLMConfig] = None
        self.stopped: bool = False
        self.memory: list = []
        self.users_count: int = 0
        self.bedrock_client = None
        self.image_buffers: list = []
        self.image_queue = asyncio.Queue()
        self.text_buffer: str = ""
        self.input_start_time: float = 0
        self.processing_times: list = []
        self.loop: Optional[asyncio.AbstractEventLoop] = None
        self.ten_env = None
        self.ctx = None

    async def on_init(self, ten_env: AsyncTenEnv) -> None:
        """Initialize the extension."""
        await super().on_init(ten_env)
        ten_env.log_info("BedrockV2VExtension initialized")

    async def on_start(self, ten_env: AsyncTenEnv) -> None:
        """Start the extension and set up required components."""
        await super().on_start(ten_env)
        ten_env.log_info("BedrockV2VExtension starting")

        try:
            self.config = await BedrockLLMConfig.create_async(ten_env=ten_env)
            ten_env.log_info(f"Configuration: {self.config}")

            if not self.config.access_key_id or not self.config.secret_access_key:
                ten_env.log_error(
                    "AWS credentials (access_key_id and secret_access_key) are required"
                )
                return

            await self._setup_components(ten_env)
        except Exception as e:
            traceback.print_exc()
            ten_env.log_error(f"Failed to initialize: {e}")

    async def _setup_components(self, ten_env: AsyncTenEnv) -> None:
        """Set up extension components."""
        self.memory = []
        self.ctx = self.config.build_ctx()
        self.ten_env = ten_env

        # Start the background consumer that drains the video-frame queue.
        self.loop = asyncio.get_running_loop()
        self.loop.create_task(self._on_video(ten_env))

    async def on_stop(self, ten_env: AsyncTenEnv) -> None:
        """Stop the extension."""
        await super().on_stop(ten_env)
        ten_env.log_info("BedrockV2VExtension stopping")
        self.stopped = True
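
    # Text typically arrives from an upstream speech-to-text extension; only
    # results flagged `is_final` trigger a model call, interim hypotheses are
    # dropped.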
    async def on_data(self, ten_env: AsyncTenEnv, data) -> None:
        """Handle incoming data."""
        ten_env.log_info("on_data receive begin...")
        data_name = data.get_name()
        ten_env.log_info(f"on_data name {data_name}")

        try:
            is_final = data.get_property_bool(DATA_IN_TEXT_DATA_PROPERTY_IS_FINAL)
            input_text = data.get_property_string(DATA_IN_TEXT_DATA_PROPERTY_TEXT)

            if not is_final:
                ten_env.log_info("ignore non-final input")
                return

            if not input_text:
                ten_env.log_info("ignore empty text")
                return

            ten_env.log_info(f"OnData input text: [{input_text}]")
            self.text_buffer = input_text
            await self._handle_input_truncation("is_final")
        except Exception as err:
            ten_env.log_error(f"Error processing data: {err}")

    async def on_video_frame(self, _: AsyncTenEnv, video_frame) -> None:
        """Handle incoming video frames."""
        if not self.config or not self.config.is_enable_video:
            return
        image_data = video_frame.get_buf()
        image_width = video_frame.get_width()
        image_height = video_frame.get_height()
        await self.image_queue.put([image_data, image_width, image_height])
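
    # Background frame sampler: one frame is encoded per VIDEO_FRAME_INTERVAL,
    # and the rest of the queue is discarded so the buffer always tracks the
    # most recent MAX_IMAGE_COUNT frames instead of falling behind the stream.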
    async def _on_video(self, ten_env: AsyncTenEnv):
        """Process video frames from the queue."""
        while True:
            try:
                [image_data, image_width, image_height] = await self.image_queue.get()

                frame_buffer = rgb2base64jpeg(image_data, image_width, image_height)
                self.image_buffers.append(frame_buffer)

                # Keep only the newest MAX_IMAGE_COUNT frames.
                while len(self.image_buffers) > MAX_IMAGE_COUNT:
                    self.image_buffers.pop(0)

                # Drop any frames that queued up while this one was processed.
                while not self.image_queue.empty():
                    await self.image_queue.get()

                await asyncio.sleep(VIDEO_FRAME_INTERVAL)
            except Exception as e:
                traceback.print_exc()
                ten_env.log_error(f"Error processing video frame: {e}")

    async def on_cmd(self, ten_env: AsyncTenEnv, cmd: Cmd) -> None:
        """Handle incoming commands."""
        cmd_name = cmd.get_name()
        ten_env.log_info(f"Command received: {cmd_name}")

        try:
            if cmd_name == CMD_IN_FLUSH:
                await ten_env.send_cmd(Cmd.create(CMD_OUT_FLUSH))
            elif cmd_name == CMD_IN_ON_USER_JOINED:
                await self._handle_user_joined()
            elif cmd_name == CMD_IN_ON_USER_LEFT:
                await self._handle_user_left()
            else:
                await super().on_cmd(ten_env, cmd)
                return

            cmd_result = CmdResult.create(StatusCode.OK)
            cmd_result.set_property_string("detail", "success")
            await ten_env.return_result(cmd_result, cmd)
        except Exception as e:
            traceback.print_exc()
            ten_env.log_error(f"Error handling command {cmd_name}: {e}")
            cmd_result = CmdResult.create(StatusCode.ERROR)
            cmd_result.set_property_string("detail", str(e))
            await ten_env.return_result(cmd_result, cmd)
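
    # User-count bookkeeping: conversational state is reset once the room is
    # empty, and the count is clamped at zero to tolerate duplicate leave
    # events.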
    async def _handle_user_left(self) -> None:
        """Handle user left event."""
        self.users_count -= 1
        if self.users_count == 0:
            self._reset_state()

        if self.users_count < 0:
            self.users_count = 0

    async def _handle_user_joined(self) -> None:
        """Handle user joined event."""
        self.users_count += 1
        if self.users_count == 1:
            await self._greeting()

    async def _handle_input_truncation(self, reason: str):
        """Handle input truncation events."""
        try:
            self.ten_env.log_info(f"Input truncated due to: {reason}")

            if self.text_buffer:
                await self._call_nova_model(self.text_buffer, self.image_buffers)

            self._reset_state()
        except Exception as e:
            traceback.print_exc()
            self.ten_env.log_error(f"Error handling input truncation: {e}")

    def _reset_state(self):
        """Reset internal state."""
        self.text_buffer = ""
        self.image_buffers = []
        self.input_start_time = 0
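
    # The Bedrock runtime client is created lazily on the first model call,
    # using the static credentials supplied in the extension configuration.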
    async def _initialize_aws_clients(self):
        """Initialize AWS clients."""
        try:
            if not self.bedrock_client:
                self.bedrock_client = boto3.client(
                    "bedrock-runtime",
                    aws_access_key_id=self.config.access_key_id,
                    aws_secret_access_key=self.config.secret_access_key,
                    region_name=self.config.region,
                )
        except Exception as e:
            traceback.print_exc()
            self.ten_env.log_error(f"Error initializing AWS clients: {e}")
            raise

    async def _greeting(self) -> None:
        """Send a greeting message to the user."""
        if self.users_count == 1:
            text = self.config.greeting or get_greeting_text(self.config.language)
            self.ten_env.log_info(f"send greeting {text}")
            await self._send_text_data(text, True, Role.Assistant)

    async def _send_text_data(self, text: str, end_of_segment: bool, role: Role):
        """Send text data to the user."""
        try:
            d = Data.create("text_data")
            d.set_property_string(DATA_OUT_TEXT_DATA_PROPERTY_TEXT, text)
            d.set_property_bool(
                DATA_OUT_TEXT_DATA_PROPERTY_TEXT_END_OF_SEGMENT, end_of_segment
            )
            d.set_property_string("role", role)
            # Fire-and-forget: sending must not block the streaming loop.
            asyncio.create_task(self.ten_env.send_data(d))
        except Exception as e:
            self.ten_env.log_error(f"Error sending text data: {e}")
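
    # Request shape: prior turns come from `self.memory`; the current user turn
    # bundles up to ONE_BATCH_SEND_COUNT sampled frames plus the final text.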
    async def _call_nova_model(self, input_text: str, image_buffers: List[bytes]) -> None:
        """Call Bedrock's Nova model with text and video input."""
        try:
            if not self.bedrock_client:
                await self._initialize_aws_clients()

            if not input_text:
                self.ten_env.log_info("Text input is empty")
                return

            contents = []

            if image_buffers:
                filtered_buffers = filter_images(image_buffers, ONE_BATCH_SEND_COUNT)
                for image_data in filtered_buffers:
                    contents.append({
                        "image": {
                            "format": "jpeg",
                            "source": {
                                "bytes": image_data
                            }
                        }
                    })

            # Trim memory to the configured length, then make sure the history
            # starts with a user turn and ends with an assistant turn, since
            # the Converse API expects alternating roles.
            while len(self.memory) > self.config.max_memory_length:
                self.memory.pop(0)
            while len(self.memory) > 0 and self.memory[0]["role"] == "assistant":
                self.memory.pop(0)
            while len(self.memory) > 0 and self.memory[-1]["role"] == "user":
                self.memory.pop(-1)

            contents.append({"text": input_text})
            messages = []
            for m in self.memory:
                # Older memory entries may store plain-string content.
                m_content = m["content"]
                if isinstance(m_content, str):
                    m_content = [{"text": m_content}]
                messages.append({
                    "role": m["role"],
                    "content": m_content
                })
            messages.append({
                "role": "user",
                "content": contents
            })

            inf_params = {
                "maxTokens": self.config.max_tokens,
                "topP": self.config.topP,
                "temperature": self.config.temperature
            }

            additional_config = {
                "inferenceConfig": {
                    "topK": self.config.topK
                }
            }

            system = [{
                "text": self.config.prompt
            }]

            start_time = time.time()
            # Note: converse_stream is a synchronous boto3 call; it blocks the
            # event loop until the stream is opened.
            response = self.bedrock_client.converse_stream(
                modelId=self.config.model_id,
                system=system,
                messages=messages,
                inferenceConfig=inf_params,
                additionalModelRequestFields=additional_config,
            )
            full_content = await self._process_stream_response(response, start_time)

            async def async_append_memory():
                if not self.config.is_memory_enabled:
                    return
                # Collapse the frame batch into a single image so memory stays
                # compact across turns.
                image = merge_images(image_buffers)
                contents = []
                if image:
                    contents.append({
                        "image": {
                            "format": "jpeg",
                            "source": {
                                "bytes": image
                            }
                        }
                    })
                contents.append({"text": input_text})
                self.memory.append({"role": Role.User, "content": contents})
                self.memory.append({"role": Role.Assistant, "content": [{"text": full_content}]})

            asyncio.create_task(async_append_memory())
        except Exception as e:
            traceback.print_exc()
            self.ten_env.log_error(f"Error calling Nova model: {e}")
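
    # The stream is flushed sentence by sentence so downstream consumers (for
    # example, a TTS extension) can start speaking before the full reply lands.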
    async def _process_stream_response(self, response: Dict, start_time: float):
        """Process the streaming response from the Nova model."""
        sentence = ""
        full_content = ""
        first_sentence_sent = False

        for event in response.get("stream", []):
            if "contentBlockDelta" in event:
                if "text" in event["contentBlockDelta"]["delta"]:
                    content = event["contentBlockDelta"]["delta"]["text"]
                    full_content += content

                    # Emit every complete sentence in the accumulated text.
                    while True:
                        sentence, content, sentence_is_final = parse_sentence(sentence, content)
                        if not sentence or not sentence_is_final:
                            break

                        self.ten_env.log_info(f"Processing sentence: [{sentence}]")
                        await self._send_text_data(sentence, False, Role.Assistant)

                        if not first_sentence_sent:
                            first_sentence_sent = True
                            self.ten_env.log_info(
                                f"First sentence latency: {(time.time() - start_time) * 1000:.0f}ms"
                            )

                        sentence = ""
            elif any(key in event for key in ["internalServerException", "modelStreamErrorException",
                                              "throttlingException", "validationException"]):
                self.ten_env.log_error(f"Stream error: {event}")
                break
            elif "metadata" in event:
                if "metrics" in event["metadata"]:
                    self.ten_env.log_info(f"Nova model latency: {event['metadata']['metrics']['latencyMs']}ms")

        # Flush whatever remains and mark the end of the segment.
        await self._send_text_data(sentence, True, Role.Assistant)
        self.ten_env.log_info(f"Final sentence sent: [{sentence}]")

        self.processing_times.append(time.time() - start_time)
        return full_content
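
    # These AsyncLLMBaseExtension hooks are not used by this extension;
    # Bedrock is driven directly in _call_nova_model.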
    async def on_call_chat_completion(self, async_ten_env, **kwargs):
        raise NotImplementedError

    async def on_data_chat_completion(self, async_ten_env, **kwargs):
        raise NotImplementedError

    async def on_tools_update(self, ten_env: AsyncTenEnv, tool) -> None:
        """Called when a new tool is registered. Implement this method to process the new tool."""
        ten_env.log_info(f"on tools update {tool}")