import uuid
from dataclasses import dataclass
from pathlib import Path
from typing import Literal

import gradio as gr
import numpy as np
import rerun as rr
from gradio_rerun.events import SelectionChange
from typing_extensions import TypedDict


def get_recording(recording_id) -> rr.RecordingStream:
    """Build a RecordingStream for the ``multiview_sam_annotate`` application.

    A new ``rr.RecordingStream`` object is constructed on every call; callers pass
    the same ``recording_id`` each time so that all of their logging targets the
    same recording.
    """
    return rr.RecordingStream(application_id="multiview_sam_annotate", recording_id=recording_id)


class RerunLogPaths(TypedDict):
    """Timeline name and entity paths used when logging to Rerun."""

    timeline_name: str  # timeline on which keypoints are logged
    parent_log_path: Path  # NOTE(review): presumably the root entity of the scene — confirm
    cam_log_path_list: list[Path]  # NOTE(review): presumably one entity path per camera — confirm


@dataclass
class KeypointsContainer:
    """Include / exclude keypoints collected for a single camera view."""

    include_points: np.ndarray  # shape (n, 2)
    exclude_points: np.ndarray  # shape (m, 2)

    @classmethod
    def empty(cls) -> "KeypointsContainer":
        """Create an empty keypoints container (both arrays have shape (0, 2))."""
        return cls(include_points=np.zeros((0, 2), dtype=float), exclude_points=np.zeros((0, 2), dtype=float))

    @staticmethod
    def _append_row(points: np.ndarray, row: np.ndarray) -> np.ndarray:
        """Return ``points`` with the (1, 2) ``row`` stacked underneath it."""
        # An empty array is replaced by the row rather than vstacked, which also
        # tolerates containers constructed with a 1-D empty array such as np.array([]).
        return np.vstack([points, row]) if points.shape[0] > 0 else row

    def add_point(self, point: tuple[float, float], label: Literal["include", "exclude"]) -> None:
        """Append ``point`` to the include or the exclude set.

        Args:
            point: An ``(x, y)`` coordinate pair.
            label: ``"include"`` or ``"exclude"``; selects which array receives the point.

        Raises:
            ValueError: If ``label`` is not ``"include"``/``"exclude"``, or ``point``
                does not hold exactly two numbers.
        """
        # np.array copies (so later mutation of the caller's object cannot leak in);
        # reshape fails loudly on anything that is not exactly two values instead of
        # storing a mis-shaped row that would break the next vstack.
        point_array = np.array(point, dtype=float).reshape(1, 2)
        if label == "include":
            self.include_points = self._append_row(self.include_points, point_array)
        elif label == "exclude":
            self.exclude_points = self._append_row(self.exclude_points, point_array)
        else:
            raise ValueError(f"label must be 'include' or 'exclude', got {label!r}")

    def clear(self) -> None:
        """Remove all points, resetting both arrays to shape (0, 2)."""
        self.include_points = np.zeros((0, 2), dtype=float)
        self.exclude_points = np.zeros((0, 2), dtype=float)


# In this function, the `request` and `change` parameters will be automatically injected by Gradio when this event listener is fired.
#
# `SelectionChange` is a subclass of `EventData`: https://www.gradio.app/docs/gradio/eventdata
# `gr.Request`: https://www.gradio.app/main/docs/gradio/request
def update_keypoints(
    active_recording_id: uuid.UUID,
    point_type: Literal["include", "exclude"],
    mv_keypoint_dict: dict[str, KeypointsContainer],
    log_paths: RerunLogPaths,
    request: gr.Request,
    change: SelectionChange,
):
    """Store a clicked keypoint for the selected camera and re-log that camera's points.

    This is a generator. After a valid single-entity selection that carries a
    position, it yields once a tuple ``(rrd_bytes, mv_keypoint_dict)``.
    ``rrd_bytes`` is whatever was written to the recording's binary stream during
    this call. On any early exit it yields nothing.

    Args:
        active_recording_id: Id of the recording to log into. The empty string is
            treated as "no active recording" (NOTE(review): this conflicts with the
            ``uuid.UUID`` annotation — confirm what the caller's state really holds).
        point_type: Whether the click is an include or an exclude prompt.
        mv_keypoint_dict: Per-camera keypoint containers keyed by camera name. The
            container of the matched camera is mutated in place.
        log_paths: Supplies ``timeline_name``, the timeline the points are logged on.
        request: Injected by Gradio; not used here.
        change: Injected by Gradio; the selection event coming from the Rerun viewer.
    """
    if active_recording_id == "":
        return

    evt = change.payload

    # We can only log a keypoint if the user selected only a single item.
    if len(evt.items) != 1:
        return
    item = evt.items[0]

    # If the selected item isn't an entity, or we don't have its position, then bail out.
    if item.type != "entity" or item.position is None:
        return

    # Now we can produce a valid keypoint.
    rec: rr.RecordingStream = get_recording(active_recording_id)
    stream: rr.BinaryStream = rec.binary_stream()

    current_keypoint: tuple[float, float] = (item.position[0], item.position[1])

    # Resolve the single camera this entity belongs to. A plain substring test can
    # match several cameras when one name is a prefix of another (e.g. "cam1" and
    # "cam10"). Among substring matches, prefer a name that is a whole path
    # segment, then the longest name.
    path_segments = set(item.entity_path.strip("/").split("/"))
    candidates = [name for name in mv_keypoint_dict if name in item.entity_path]
    if candidates:
        cam_name = max(candidates, key=lambda name: (name in path_segments, len(name)))

        # Update the keypoints for the specific camera
        container: KeypointsContainer = mv_keypoint_dict[cam_name]
        container.add_point(current_keypoint, point_type)

        rec.set_time_nanos(log_paths["timeline_name"], nanos=0)
        # Log include points if any exist
        if container.include_points.shape[0] > 0:
            rec.log(
                f"{item.entity_path}/include",
                rr.Points2D(container.include_points, colors=(0, 255, 0), radii=5),
            )
        # Log exclude points if any exist
        if container.exclude_points.shape[0] > 0:
            rec.log(
                f"{item.entity_path}/exclude",
                rr.Points2D(container.exclude_points, colors=(255, 0, 0), radii=5),
            )

    # Ensure we consume everything from the recording.
    stream.flush()
    yield stream.read(), mv_keypoint_dict