import json
import re
from pathlib import Path
from typing import List

import json_repair
from omagent_core.models.llms.base import BaseLLMBackend
from omagent_core.models.llms.prompt import PromptTemplate
from omagent_core.tool_system.base import ArgSchema, BaseTool
from omagent_core.utils.logger import logging
from omagent_core.utils.registry import registry
from pydantic import Field
from scenedetect import FrameTimecode

from ...misc.scene import VideoScenes

# Directory containing this module; prompt template files live next to it.
CURRENT_PATH = Path(__file__).parents[0]

# JSON-schema-style argument declaration consumed by ArgSchema below.
ARGSCHEMA = {
    "start_time": {
        "type": "number",
        "description": "Start time (in seconds) of the video to extract frames from.",
        "required": True,
    },
    "end_time": {
        "type": "number",
        "description": "End time (in seconds) of the video to extract frames from.",
        "required": True,
    },
    "number": {
        "type": "number",
        "description": "Number of frames of extraction. More frames means more details but more cost. Do not exceed 10.",
        "required": True,
    },
}


@registry.register_tool()
class Rewinder(BaseTool, BaseLLMBackend):
    """Re-extract frames from the video already loaded in shared memory (STM).

    Given a time window, the tool samples up to 10 frames from the loaded
    ``VideoScenes``, sends them (with timestamps) to the configured LLM via
    the system/user prompt templates, and returns the LLM's textual
    description of the extracted frames.
    """

    args_schema: ArgSchema = ArgSchema(**ARGSCHEMA)
    description: str = (
        "Rollback and extract frames from video which is already loaded to get more specific details for further analysis."
    )
    prompts: List[PromptTemplate] = Field(
        default=[
            PromptTemplate.from_file(
                CURRENT_PATH.joinpath("rewinder_sys_prompt.prompt"),
                role="system",
            ),
            PromptTemplate.from_file(
                CURRENT_PATH.joinpath("rewinder_user_prompt.prompt"),
                role="user",
            ),
        ]
    )

    def _run(
        self, start_time: float = 0.0, end_time: float = None, number: int = 1
    ) -> str:
        """Extract frames from the loaded video window and describe them via the LLM.

        Args:
            start_time: Window start, in seconds.
            end_time: Window end, in seconds; ``None`` means the video's end.
            number: Desired number of frames; clamped to the range [1, 10].

        Returns:
            A string embedding the LLM's JSON description of the frames.

        Raises:
            ValueError: If no video has been loaded into STM for this workflow.
        """
        if self.stm(self.workflow_instance_id).get("video", None) is None:
            raise ValueError("No video is loaded.")
        video: VideoScenes = VideoScenes.from_serializable(
            self.stm(self.workflow_instance_id)["video"]
        )

        if number > 10:
            # Cap the frame count to keep the multimodal LLM payload bounded.
            logging.warning("Number of frames exceeds 10. Will extract 10 frames.")
            number = 10
        if number < 1:
            # Guard against a zero/negative request, which would otherwise
            # cause a ZeroDivisionError when computing the sampling interval.
            logging.warning("Number of frames is less than 1. Will extract 1 frame.")
            number = 1

        start = FrameTimecode(timecode=start_time, fps=video.stream.frame_rate)
        if end_time is None:
            end = video.stream.duration
        else:
            end = FrameTimecode(timecode=end_time, fps=video.stream.frame_rate)

        if start_time == end_time:
            # Degenerate zero-length window: widen by one frame and sample at
            # one frame per second. NOTE(review): the general branch below
            # handles start != end; confirm this special case matches callers'
            # expectations for instantaneous timestamps.
            frames, time_stamps = video.get_video_frames(
                (start, end + 1), video.stream.frame_rate
            )
        else:
            # Even sampling across the window; never let the step drop to 0
            # when the window holds fewer frames than `number`.
            interval = max(
                1, int((end.get_frames() - start.get_frames()) / number)
            )
            frames, time_stamps = video.get_video_frames((start, end), interval)

        # Interleave "timestamp_<t>" markers with their frames so the prompt
        # template can associate each image with its position in the video.
        payload = []
        for frame, time_stamp in zip(frames, time_stamps):
            payload.append(f"timestamp_{time_stamp}")
            payload.append(frame)

        res = self.infer(input_list=[{"timestamp_with_images": payload}])[0]["choices"][
            0
        ]["message"]["content"]
        # The LLM reply may be slightly malformed JSON; json_repair tolerates that.
        image_contents = json_repair.loads(res)
        # Drop any previously cached images for this workflow instance.
        self.stm(self.workflow_instance_id)["image_cache"] = {}
        return f"extracted_frames described as: {image_contents}."

    async def _arun(
        self, start_time: float = 0.0, end_time: float = None, number: int = 1
    ) -> str:
        """Async entry point; delegates to the synchronous ``_run``."""
        return self._run(start_time, end_time, number=number)