pablovela5620 committed on
Commit
be9b1db
·
verified ·
1 Parent(s): 5528471

Upload folder using huggingface_hub

Browse files
annotation_example/__init__.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
# Package initializer for "annotation_example": registers the beartype import hook
# so that submodules imported after this point are instrumented by beartype.
from beartype import BeartypeConf
from beartype.claw import beartype_this_package

beartype_this_package(
    conf=BeartypeConf(
        # Submodules named here are passed to beartype as packages to skip.
        claw_skip_package_names=("annotation_example.gradio_ui.callbacks", "annotation_example.gradio_ui.sv_sam")
    )
)
annotation_example/__pycache__/__init__.cpython-311.opt-beartype0v20v2.pyc ADDED
Binary file (598 Bytes). View file
 
annotation_example/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (540 Bytes). View file
 
annotation_example/__pycache__/op.cpython-311.opt-beartype0v20v2.pyc ADDED
Binary file (8.39 kB). View file
 
annotation_example/__pycache__/op.cpython-311.pyc ADDED
Binary file (4.76 kB). View file
 
annotation_example/gradio_ui/__pycache__/callbacks.cpython-311.pyc ADDED
Binary file (5.5 kB). View file
 
annotation_example/gradio_ui/__pycache__/mv_sam.cpython-311.opt-beartype0v20v2.pyc ADDED
Binary file (50.6 kB). View file
 
annotation_example/gradio_ui/__pycache__/mv_sam.cpython-311.pyc ADDED
Binary file (35.4 kB). View file
 
annotation_example/gradio_ui/__pycache__/mv_sam_callbacks.cpython-311.opt-beartype0v20v2.pyc ADDED
Binary file (7.31 kB). View file
 
annotation_example/gradio_ui/__pycache__/sv_sam.cpython-311.opt-beartype0v20v2.pyc ADDED
Binary file (36.8 kB). View file
 
annotation_example/gradio_ui/__pycache__/sv_sam.cpython-311.pyc ADDED
Binary file (26.5 kB). View file
 
annotation_example/gradio_ui/__pycache__/vggt_sam.cpython-311.pyc ADDED
Binary file (26.2 kB). View file
 
annotation_example/gradio_ui/mv_sam.py ADDED
@@ -0,0 +1,670 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import shutil
2
+ import uuid
3
+ from pathlib import Path
4
+ from typing import Literal, assert_never, no_type_check
5
+
6
+ import cv2
7
+ import gradio as gr
8
+ import numpy as np
9
+ import open3d as o3d
10
+ import rerun as rr
11
+ import rerun.blueprint as rrb
12
+ import torch
13
+ from einops import rearrange
14
+ from gradio_rerun import Rerun
15
+ from jaxtyping import Bool, Float, Float32, Int, UInt8, UInt16
16
+ from sam2.sam2_image_predictor import SAM2ImagePredictor
17
+ from sam2.sam2_video_predictor import SAM2VideoPredictor
18
+ from simplecv.camera_parameters import PinholeParameters
19
+ from simplecv.conversion_utils import save_to_nerfstudio
20
+ from simplecv.data.exoego.assembly_101 import Assembely101Sequence
21
+ from simplecv.data.exoego.hocap import ExoCameraIDs, HOCapSequence
22
+ from simplecv.ops.triangulate import batch_triangulate, projectN3
23
+ from simplecv.ops.tsdf_depth_fuser import Open3DFuser
24
+ from simplecv.video_io import MultiVideoReader
25
+
26
+ from annotation_example.gradio_ui.mv_sam_callbacks import (
27
+ KeypointsContainer,
28
+ RerunLogPaths,
29
+ get_recording,
30
+ update_keypoints,
31
+ )
32
+
33
# Load the SAM2 predictors once at import time. The `gr.NO_RELOAD` guard is Gradio's
# mechanism for keeping expensive setup from being re-run by its hot-reload mode.
# NOTE(review): VIDEO_SAM_PREDICTOR is not referenced by the code visible in this module — confirm it is needed.
if gr.NO_RELOAD:
    VIDEO_SAM_PREDICTOR: SAM2VideoPredictor = SAM2VideoPredictor.from_pretrained("facebook/sam2-hiera-tiny")
    IMG_SAM_PREDICTOR: SAM2ImagePredictor = SAM2ImagePredictor.from_pretrained("facebook/sam2-hiera-tiny")
36
+
37
+
38
def create_blueprint(exo_video_log_paths: list[Path], num_videos_to_log: Literal[4, 8] = 8) -> rrb.Blueprint:
    """Build the Rerun viewer layout for the multi-view annotation session.

    The layout is a 3D view stacked above a row of per-camera tab groups for the
    first four cameras. When ``num_videos_to_log`` is 8, a second column holding
    tab groups for the remaining cameras is placed to the right.

    Parameters:
        exo_video_log_paths (list[Path]): Entity paths of the per-camera video entities.
        num_videos_to_log (Literal[4, 8], optional): Whether to show only the first
            four cameras or also the additional column. Defaults to 8.

    Returns:
        rrb.Blueprint: The assembled blueprint with panels collapsed.
    """
    # Index of the tab shown first in every tab group: 0 is the view rooted at the
    # parent of the video entity, 1 is the matching "depth" entity.
    active_tab: int = 0

    def camera_tabs(video_log_path):
        # One tab group per camera; the depth path is derived from the video path.
        return rrb.Tabs(
            rrb.Spatial2DView(origin=f"{video_log_path.parent}"),
            rrb.Spatial2DView(origin=f"{video_log_path}".replace("video", "depth")),
            active_tab=active_tab,
        )

    first_four_tabs = [camera_tabs(video_log_path) for video_log_path in exo_video_log_paths[:4]]
    main_view = rrb.Vertical(
        contents=[
            rrb.Spatial3DView(origin="/"),
            rrb.Horizontal(contents=first_four_tabs),
        ],
        row_shares=[3, 1],
    )

    # Everything after the first four cameras goes into the side column.
    remaining_tabs = [camera_tabs(video_log_path) for video_log_path in exo_video_log_paths[4:]]
    additional_views = rrb.Vertical(contents=remaining_tabs)

    columns = [main_view, additional_views] if num_videos_to_log == 8 else [main_view]
    return rrb.Blueprint(
        rrb.Horizontal(contents=columns, column_shares=[4, 1]),
        collapse_panels=True,
    )
84
+
85
+
86
def log_pinhole_rec(
    rec: rr.RecordingStream,
    camera: PinholeParameters,
    cam_log_path: Path,
    image_plane_distance: float = 0.5,
    static: bool = False,
) -> None:
    """
    Log a pinhole camera (intrinsics and extrinsics) to a recording stream.

    Parameters:
        rec (rr.RecordingStream): The recording stream to log into.
        camera (PinholeParameters): The pinhole camera parameters including intrinsics and extrinsics.
        cam_log_path (Path): Entity path of the camera; intrinsics go to ``<cam_log_path>/pinhole``.
        image_plane_distance (float, optional): The distance of the image plane from the camera. Defaults to 0.5.
        static (bool, optional): If True, the log data will be marked as static. Defaults to False.

    Returns:
        None
    """
    # Intrinsics: logged on the "pinhole" child entity.
    intrinsics = camera.intrinsics
    pinhole = rr.Pinhole(
        image_from_camera=intrinsics.k_matrix,
        height=intrinsics.height,
        width=intrinsics.width,
        # The convention is stored as the name of a rr.ViewCoordinates attribute.
        camera_xyz=getattr(rr.ViewCoordinates, intrinsics.camera_conventions),
        image_plane_distance=image_plane_distance,
    )
    rec.log(f"{cam_log_path}/pinhole", pinhole, static=static)

    # Extrinsics: logged on the camera entity itself, with from_parent=True.
    extrinsics = camera.extrinsics
    cam_from_world = rr.Transform3D(
        translation=extrinsics.cam_t_world,
        mat3x3=extrinsics.cam_R_world,
        from_parent=True,
    )
    rec.log(str(cam_log_path), cam_from_world, static=static)
130
+
131
+
132
def log_video_rec(
    rec: rr.RecordingStream,
    video_path: Path,
    video_log_path: Path,
    timeline: str = "video_time",
) -> Int[np.ndarray, "num_frames"]:
    """
    Log a video asset and a frame reference for each of its frame timestamps.

    Parameters:
        rec (rr.RecordingStream): The recording stream to log into.
        video_path (Path): The path to the video file.
        video_log_path (Path): Entity path the video asset and its frame references are logged to.
        timeline (str, optional): Name of the timeline the frame timestamps are sent on.
            Defaults to "video_time".

    Returns:
        The frame timestamps, in nanoseconds, read from the video asset.
    """
    entity_path = str(video_log_path)

    # The asset itself is static; individual frames are referenced by timestamp.
    asset = rr.AssetVideo(path=video_path)
    rec.log(entity_path, asset, static=True)

    timestamps_ns: Int[np.ndarray, "num_frames"] = asset.read_frame_timestamps_ns()  # noqa: UP037

    # Note timeline values don't have to be the same as the video timestamps.
    time_column = rr.TimeNanosColumn(timeline, timestamps_ns)
    frame_reference_columns = rr.VideoFrameReference.columns_nanoseconds(timestamps_ns)
    rec.send_columns(
        entity_path,
        indexes=[time_column],
        columns=frame_reference_columns,
    )
    return timestamps_ns
163
+
164
+
165
def rescale_img(img_hw3: UInt8[np.ndarray, "h w 3"], max_dim: int) -> UInt8[np.ndarray, "... 3"]:
    """Shrink an image so that its longest side is at most ``max_dim`` pixels.

    The aspect ratio is preserved. Images that already fit are returned unchanged
    (the same array object, not a copy).

    Parameters:
        img_hw3: Image of shape (height, width, 3).
        max_dim: Maximum allowed size of the longest side, in pixels. Must be >= 1.

    Returns:
        The resized image, or ``img_hw3`` itself when no resize is needed.

    Raises:
        ValueError: If ``max_dim`` is smaller than 1.
    """
    if max_dim < 1:
        raise ValueError(f"max_dim must be a positive integer, got {max_dim}")

    height, width, _ = img_hw3.shape
    longest_side = max(height, width)

    # Nothing to do when the image already fits.
    if longest_side <= max_dim:
        return img_hw3

    scale_factor = max_dim / longest_side
    # Truncation can round the short side of a very elongated image down to 0,
    # which cv2.resize rejects; keep every side at least one pixel.
    new_height = max(1, int(height * scale_factor))
    new_width = max(1, int(width * scale_factor))

    # cv2.resize takes the target size as (width, height).
    return cv2.resize(img_hw3, (new_width, new_height), interpolation=cv2.INTER_AREA)
182
+
183
+
184
@no_type_check
def reset_keypoints(
    active_recording_id: uuid.UUID, mv_keypoint_dict: dict[str, KeypointsContainer], log_paths: RerunLogPaths
):
    """Gradio-facing entry point that forwards to ``_reset_keypoints``.

    Yields exactly what ``_reset_keypoints`` yields. The function is marked
    ``@no_type_check``; presumably so that the package-wide runtime type checker
    skips this callback while the private implementation stays checked — confirm.
    """
    yield from _reset_keypoints(active_recording_id, mv_keypoint_dict, log_paths)
193
+
194
+
195
def _reset_keypoints(
    active_recording_id: uuid.UUID, mv_keypoint_dict: dict[str, KeypointsContainer], log_paths: RerunLogPaths
):
    """Clear every annotation entity in the viewer and reset the keypoint state.

    Parameters:
        active_recording_id: Id of the recording to log the clears into.
        mv_keypoint_dict: Current per-camera keypoints; only its keys are used.
        log_paths: Timeline name, parent entity path and per-camera entity paths.

    Yields:
        A tuple of (binary stream bytes, fresh dict with an empty
        ``KeypointsContainer`` per camera, empty dict).
    """
    rec: rr.RecordingStream = get_recording(active_recording_id)
    # The stream must exist before logging so the clears below are captured.
    stream: rr.BinaryStream = rec.binary_stream()

    # Same cameras as before, each with a brand-new empty container.
    cleared_keypoints: dict[str, KeypointsContainer] = {
        cam_name: KeypointsContainer.empty() for cam_name in mv_keypoint_dict
    }

    # Entities below each camera's pinhole path that hold annotation output.
    per_camera_entities = (
        "video/include",
        "video/exclude",
        "video/bbox",
        "video/bbox_center",
        "segmentation",
        "depth",
    )

    rec.set_time_nanos(log_paths["timeline_name"], nanos=0)
    for cam_log_path in log_paths["cam_log_path_list"]:
        pinhole_path: Path = cam_log_path / "pinhole"
        print(pinhole_path)
        for entity_suffix in per_camera_entities:
            rec.log(f"{pinhole_path}/{entity_suffix}", rr.Clear(recursive=True))

    # The triangulated 3D point is shared by all cameras and cleared once.
    rec.log(
        f"{log_paths['parent_log_path']}/triangulated",
        rr.Clear(recursive=True),
    )

    # Ensure we consume everything from the recording.
    stream.flush()
    yield stream.read(), cleared_keypoints, {}
244
+
245
+
246
@no_type_check
def get_initial_mask(
    recording_id: uuid.UUID,
    inference_state: dict,
    mv_keypoints_dict: dict[str, KeypointsContainer],
    log_paths: RerunLogPaths,
    rgb_list: list[UInt8[np.ndarray, "h w 3"]],
    keypoint_centers_dict: dict[str, Float32[np.ndarray, "3"]],
):
    """Gradio-facing entry point that forwards to ``_get_initial_mask``.

    Yields exactly what ``_get_initial_mask`` yields. The function is marked
    ``@no_type_check``; presumably so that the package-wide runtime type checker
    skips this callback while the private implementation stays checked — confirm.
    """
    yield from _get_initial_mask(
        recording_id,
        inference_state,
        mv_keypoints_dict,
        log_paths,
        rgb_list,
        keypoint_centers_dict,
    )
263
+
264
+
265
def _get_initial_mask(
    recording_id: uuid.UUID,
    inference_state: dict,
    mv_keypoints_dict: dict[str, KeypointsContainer],
    log_paths: RerunLogPaths,
    rgb_list: list[UInt8[np.ndarray, "h w 3"]],
    keypoint_centers_dict: dict[str, Float32[np.ndarray, "3"]],
):
    """Segment each camera image from its clicked points and record the mask's bounding-box center.

    For every camera that has at least one include/exclude point, the image
    predictor is prompted with those points, the resulting mask is logged as a
    segmentation image, and (when the mask is non-empty) its bounding box and
    box center are logged. The center is stored in ``keypoint_centers_dict``.

    Parameters:
        recording_id: Id of the recording to log into.
        inference_state: Not used by this function.
        mv_keypoints_dict: Per-camera include/exclude points; iterated in lockstep with ``rgb_list``.
        log_paths: Timeline name and parent entity path.
        rgb_list: One RGB image per camera, in the same order as ``mv_keypoints_dict``.
        keypoint_centers_dict: Per-camera centers ``(x, y, 1)``; updated in place.

    Yields:
        A tuple of (binary stream bytes, ``keypoint_centers_dict``).
    """
    rec = get_recording(recording_id)
    stream = rec.binary_stream()

    rec.set_time_nanos(log_paths["timeline_name"], nanos=0)

    for (cam_name, keypoint_container), rgb in zip(mv_keypoints_dict.items(), rgb_list, strict=True):
        include_points = keypoint_container.include_points
        exclude_points = keypoint_container.exclude_points
        points: Float32[np.ndarray, "num_points 2"] = np.vstack([include_points, exclude_points]).astype(np.float32)

        # Check for prompts before touching the predictor: embedding an image is
        # the expensive step and is pointless for a camera nobody clicked on.
        if points.shape[0] == 0:
            rec.log(
                "logs",
                rr.TextLog("No points selected, skipping segmentation.", level="info"),
            )
            continue

        pinhole_log_path: Path = log_paths["parent_log_path"] / cam_name / "pinhole"

        # Labels follow the stacking order above: 1 for include points, 0 for exclude points.
        labels: Int[np.ndarray, "num_points"] = np.concatenate(  # noqa: UP037
            [
                np.ones(len(include_points), dtype=np.int32),
                np.zeros(len(exclude_points), dtype=np.int32),
            ]
        )

        IMG_SAM_PREDICTOR.set_image(rgb)
        with torch.inference_mode():
            masks, _, _ = IMG_SAM_PREDICTOR.predict(
                point_coords=points,
                point_labels=labels,
                multimask_output=False,
            )
        # The predictor is shared module state; drop this image before moving on.
        IMG_SAM_PREDICTOR.reset_predictor()

        binary_masks: Bool[np.ndarray, "1 h w"] = masks > 0.0
        mask = binary_masks[0]

        rec.log(
            f"{pinhole_log_path}/segmentation",
            rr.SegmentationImage(mask.astype(np.uint8)),
        )

        # An empty mask has no bounding box; leave any previously stored center untouched.
        if not mask.any():
            continue

        # First and last row/column that contain any mask pixel.
        y_min, y_max = np.where(mask.any(axis=1))[0][[0, -1]]
        x_min, x_max = np.where(mask.any(axis=0))[0][[0, -1]]
        bbox = np.array([x_min, y_min, x_max, y_max], dtype=np.float32)
        rec.log(
            f"{pinhole_log_path}/video/bbox",
            rr.Boxes2D(array=bbox, array_format=rr.Box2DFormat.XYXY, colors=(0, 0, 255)),
        )

        # Center of the bounding box, stored as (x, y, 1).
        center_xyc: Float32[np.ndarray, "3"] = np.array(  # noqa: UP037
            [(x_min + x_max) / 2, (y_min + y_max) / 2, 1], dtype=np.float32
        )
        rec.log(
            f"{pinhole_log_path}/video/bbox_center",
            rr.Points2D(positions=(center_xyc[0], center_xyc[1]), colors=(0, 0, 255), radii=5),
        )
        keypoint_centers_dict[cam_name] = center_xyc

    yield stream.read(), keypoint_centers_dict
330
+
331
+
332
@no_type_check
def triangulate_centers(
    recording_id: uuid.UUID,
    center_xyc_dict: dict[str, Float32[np.ndarray, "3"]],
    exo_cam_list: list[PinholeParameters],
    log_paths: RerunLogPaths,
    rgb_list: list[UInt8[np.ndarray, "h w 3"]],
):
    """Gradio-facing entry point that forwards to ``_triangulate_centers``.

    Yields exactly what ``_triangulate_centers`` yields. The function is marked
    ``@no_type_check``; presumably so that the package-wide runtime type checker
    skips this callback while the private implementation stays checked — confirm.
    """
    yield from _triangulate_centers(
        recording_id,
        center_xyc_dict,
        exo_cam_list,
        log_paths,
        rgb_list,
    )
347
+
348
+
349
def _triangulate_centers(
    recording_id: uuid.UUID,
    center_xyc_dict: dict[str, Float32[np.ndarray, "3"]],
    exo_cam_list: list[PinholeParameters],
    log_paths: RerunLogPaths,
    rgb_list: list[UInt8[np.ndarray, "h w 3"]],
):
    """Triangulate the per-camera centers into 3D and re-segment every view from the reprojection.

    The 2D centers of all cameras that have one are triangulated, the 3D point
    is projected into every camera in ``exo_cam_list``, and each projected point
    is used as a single include prompt for the image predictor. Masks, bounding
    boxes and box centers are logged per camera.

    Parameters:
        recording_id: Id of the recording to log into.
        center_xyc_dict: Per-camera centers ``(x, y, 1)`` keyed by camera name.
        exo_cam_list: Cameras; iterated in lockstep with ``rgb_list`` and ``log_paths["cam_log_path_list"]``.
        log_paths: Timeline name, parent entity path and per-camera entity paths.
        rgb_list: One RGB image per camera.

    Yields:
        A tuple of (binary stream bytes, list of uint8 masks — empty when fewer
        than two cameras have a center).
    """
    rec = get_recording(recording_id)
    stream = rec.binary_stream()

    masks_list: list[UInt8[np.ndarray, "h w"]] = []

    rec.set_time_nanos(log_paths["timeline_name"], nanos=0)

    # Cameras that actually have a center, in camera-list order. Centers and
    # projection matrices are both derived from this one list so that row i of
    # each always refers to the same camera, whatever order the dict was filled in.
    observing_cams: list[PinholeParameters] = [
        exo_cam for exo_cam in exo_cam_list if center_xyc_dict.get(exo_cam.name) is not None
    ]

    if len(observing_cams) >= 2:
        centers_xyc: Float32[np.ndarray, "num_views 3"] = np.stack(
            [center_xyc_dict[exo_cam.name] for exo_cam in observing_cams], axis=0
        ).astype(np.float32)
        # A single keypoint per view.
        centers_xyc_n1 = rearrange(centers_xyc, "num_views xyc -> num_views 1 xyc")

        proj_matrices_observed: Float32[np.ndarray, "num_views 3 4"] = np.stack(
            [exo_cam.projection_matrix for exo_cam in observing_cams], axis=0
        ).astype(np.float32)
        # Reprojection uses every camera, including those without a center.
        proj_matrices_all: Float32[np.ndarray, "all_views 3 4"] = np.stack(
            [exo_cam.projection_matrix for exo_cam in exo_cam_list], axis=0
        ).astype(np.float32)

        xyzc: Float[np.ndarray, "n_points 4"] = batch_triangulate(
            keypoints_2d=centers_xyc_n1, projection_matrices=proj_matrices_observed
        )
        rec.log(
            f"{log_paths['parent_log_path']}/triangulated", rr.Points3D(xyzc[:, 0:3], colors=(0, 0, 255), radii=0.1)
        )

        projected_xyc = projectN3(
            xyzc,
            proj_matrices_all,
        )

        for rgb, cam_log_path, xyc in zip(rgb_list, log_paths["cam_log_path_list"], projected_xyc, strict=True):
            pinhole_log_path: Path = cam_log_path / "pinhole"
            xy = xyc[:, 0:2]
            rec.log(
                f"{pinhole_log_path}/video/bbox_center",
                rr.Points2D(positions=xy, colors=(0, 0, 255), radii=5),
            )

            # Every reprojected point is an include prompt.
            labels: Int[np.ndarray, "num_points"] = np.ones(len(xyc), dtype=np.int32)  # noqa: UP037
            IMG_SAM_PREDICTOR.set_image(rgb)
            with torch.inference_mode():
                masks, _, _ = IMG_SAM_PREDICTOR.predict(
                    point_coords=xy,
                    point_labels=labels,
                    multimask_output=False,
                )
            # The predictor is shared module state; drop this image before moving on.
            IMG_SAM_PREDICTOR.reset_predictor()

            binary_masks: Bool[np.ndarray, "1 h w"] = masks > 0.0
            mask = binary_masks[0].astype(np.uint8)
            masks_list.append(mask)
            rec.log(
                f"{pinhole_log_path}/segmentation",
                rr.SegmentationImage(mask),
            )

            if mask.any():
                # First and last row/column that contain any mask pixel.
                y_min, y_max = np.where(binary_masks[0].any(axis=1))[0][[0, -1]]
                x_min, x_max = np.where(binary_masks[0].any(axis=0))[0][[0, -1]]
                bbox = np.array([x_min, y_min, x_max, y_max], dtype=np.float32)
                rec.log(
                    f"{pinhole_log_path}/video/bbox",
                    rr.Boxes2D(array=bbox, array_format=rr.Box2DFormat.XYXY, colors=(0, 0, 255)),
                )

                # Replace the reprojected point with the center of the new bounding box.
                center_xyc: Float32[np.ndarray, "3"] = np.array(  # noqa: UP037
                    [(x_min + x_max) / 2, (y_min + y_max) / 2, 1], dtype=np.float32
                )
                rec.log(
                    f"{pinhole_log_path}/video/bbox_center",
                    rr.Points2D(positions=(center_xyc[0], center_xyc[1]), colors=(0, 0, 255), radii=5),
                )
    else:
        rec.log(
            "logs",
            rr.TextLog("No points selected, skipping segmentation.", level="info"),
        )
        gr.Info("Not enough points to triangulate.")

    yield stream.read(), masks_list
439
+
440
+
441
@no_type_check
def log_dataset(dataset_name: Literal["hocap", "assembly101"]):
    """Gradio-facing entry point that forwards to ``_log_dataset``.

    Yields exactly what ``_log_dataset`` yields. The function is marked
    ``@no_type_check``; presumably so that the package-wide runtime type checker
    skips this callback while the private implementation stays checked — confirm.
    """
    yield from _log_dataset(dataset_name=dataset_name)
444
+
445
+
446
+ def _log_dataset(dataset_name: Literal["hocap", "assembly101"]):
447
+ recording_id: uuid.UUID = uuid.uuid4()
448
+ rec: rr.RecordingStream = get_recording(recording_id)
449
+ stream: rr.BinaryStream = rec.binary_stream()
450
+
451
+ match dataset_name:
452
+ case "hocap":
453
+ sequence: HOCapSequence = HOCapSequence(
454
+ data_path=Path("data/hocap/sample"),
455
+ sequence_name="20231024_180733",
456
+ subject_id="8",
457
+ load_labels=False,
458
+ )
459
+ case "assembly101":
460
+ # raise NotImplementedError("Assembly101 is not implemented yet.")
461
+ sequence: Assembely101Sequence = Assembely101Sequence(
462
+ data_path=Path("data/assembly101-sample"),
463
+ sequence_name="nusar-2021_action_both_9015-b05b_9015_user_id_2021-02-02_161800",
464
+ subject_id=None,
465
+ load_labels=False,
466
+ )
467
+ case _:
468
+ assert_never(dataset_name)
469
+
470
+ parent_log_path: Path = Path("world")
471
+ timeline_name: str = "frame_idx"
472
+
473
+ images_to_log: int = 8
474
+
475
+ exo_video_readers: MultiVideoReader = sequence.exo_video_readers
476
+ # exo_video_files: list[Path] = exo_video_readers.video_paths[0:images_to_log]
477
+ exo_cam_log_paths: list[Path] = [parent_log_path / exo_cam.name for exo_cam in sequence.exo_cam_list][
478
+ 0:images_to_log
479
+ ]
480
+ exo_video_log_paths: list[Path] = [cam_log_paths / "pinhole" / "video" for cam_log_paths in exo_cam_log_paths][
481
+ 0:images_to_log
482
+ ]
483
+
484
+ initial_blueprint = create_blueprint(exo_video_log_paths, num_videos_to_log=8)
485
+ rec.send_blueprint(initial_blueprint)
486
+ rec.log("/", sequence.world_coordinate_system, static=True)
487
+
488
+ bgr_list: list[UInt8[np.ndarray, "h w 3"]] = exo_video_readers[0][0:images_to_log]
489
+ rgb_list: list[UInt8[np.ndarray, "h w 3"]] = [cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB) for bgr in bgr_list]
490
+ # check if depth images exist
491
+ if not sequence.depth_paths:
492
+ depth_paths = None
493
+ else:
494
+ depth_paths: dict[ExoCameraIDs, Path] = sequence.depth_paths[0]
495
+ exo_cam_list: list[PinholeParameters] = sequence.exo_cam_list[0:images_to_log]
496
+
497
+ cam_log_path_list: list[Path] = []
498
+ fuser = Open3DFuser(fusion_resolution=0.01, max_fusion_depth=1.25)
499
+ # log stationary exo cameras and video assets
500
+ for exo_cam in exo_cam_list:
501
+ cam_log_path: Path = parent_log_path / exo_cam.name
502
+ cam_log_path_list.append(cam_log_path)
503
+ image_plane_distance: float = 0.1 if dataset_name == "hocap" else 100.0
504
+ log_pinhole_rec(
505
+ rec=rec,
506
+ camera=exo_cam,
507
+ cam_log_path=cam_log_path,
508
+ image_plane_distance=image_plane_distance,
509
+ static=True,
510
+ )
511
+
512
+ for rgb, cam_log_path, exo_cam in zip(rgb_list, cam_log_path_list, exo_cam_list, strict=True):
513
+ pinhole_log_path: Path = cam_log_path / "pinhole"
514
+ rec.log(f"{pinhole_log_path}/video", rr.Image(rgb, color_model=rr.ColorModel.RGB), static=True)
515
+ # rec.log(f"{pinhole_log_path}/depth", rr.DepthImage(depth_image, meter=1000))
516
+ if depth_paths is not None:
517
+ depth_path: Path = depth_paths[cam_log_path.name]
518
+ depth_image: UInt16[np.ndarray, "480 640"] = cv2.imread(str(depth_path), cv2.IMREAD_ANYDEPTH)
519
+ fuser.fuse_frames(
520
+ depth_image,
521
+ exo_cam.intrinsics.k_matrix,
522
+ exo_cam.extrinsics.cam_T_world,
523
+ rgb,
524
+ )
525
+
526
+ if depth_paths is not None:
527
+ mesh: o3d.geometry.TriangleMesh = fuser.get_mesh()
528
+ mesh.compute_vertex_normals()
529
+
530
+ rec.log(
531
+ f"{parent_log_path}/mesh",
532
+ rr.Mesh3D(
533
+ vertex_positions=mesh.vertices,
534
+ triangle_indices=mesh.triangles,
535
+ vertex_normals=mesh.vertex_normals,
536
+ vertex_colors=mesh.vertex_colors,
537
+ ),
538
+ static=True,
539
+ )
540
+
541
+ pcd: o3d.geometry.PointCloud = mesh.sample_points_poisson_disk(
542
+ number_of_points=20_000,
543
+ )
544
+
545
+ log_paths = RerunLogPaths(
546
+ timeline_name=timeline_name,
547
+ parent_log_path=parent_log_path,
548
+ cam_log_path_list=cam_log_path_list,
549
+ )
550
+
551
+ mv_keypoint_dict: dict[str, KeypointsContainer] = {
552
+ cam_log_path.name: KeypointsContainer.empty() for cam_log_path in cam_log_path_list
553
+ }
554
+
555
+ yield stream.read(), recording_id, log_paths, mv_keypoint_dict, rgb_list, exo_cam_list, pcd
556
+
557
+
558
def handle_export(
    exo_cam_list: list[PinholeParameters],
    rgb_list: list[UInt8[np.ndarray, "h w 3"]],
    masks_list: list[UInt8[np.ndarray, "h w"]],
    pointcloud: o3d.geometry.PointCloud,
):
    """Export the session through ``save_to_nerfstudio`` and zip the result.

    Parameters:
        exo_cam_list: Cameras to export.
        rgb_list: RGB images; converted to BGR before being handed to the exporter.
        masks_list: Per-camera masks; an empty list is exported as ``None``.
        pointcloud: Point cloud passed through to the exporter.

    Returns:
        A tuple of (``gr.Tabs`` update selecting tab 1, path of the created zip file).
    """
    export_dir: Path = Path("data/nerfstudio-export")
    archive_base = Path("data/nerfstudio-output")

    bgr_list: list[UInt8[np.ndarray, "h w 3"]] = [cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR) for rgb in rgb_list]
    # The exporter receives None rather than an empty list when there are no masks.
    export_masks: list[UInt8[np.ndarray, "h w"]] | None = None if len(masks_list) == 0 else masks_list

    save_to_nerfstudio(
        ns_save_path=export_dir,
        pinhole_param_list=exo_cam_list,
        bgr_list=bgr_list,
        pointcloud=pointcloud,
        masks_list=export_masks,
    )

    # make_archive appends the ".zip" suffix to the base name and returns the full path.
    zip_file_path: str = shutil.make_archive(str(archive_base), "zip", str(export_dir))

    # Return the path to the zip file and switch tabs
    return gr.Tabs(selected=1), zip_file_path
581
+
582
+
583
# Gradio UI for multi-view SAM annotation: a control/output tab group next to a
# streaming Rerun viewer, plus the event wiring between them.
# NOTE(review): nesting below was reconstructed from a source without indentation — confirm against the repo.
with gr.Blocks() as mv_sam_block:
    # Session state passed between the callbacks wired up at the bottom of this block.
    mv_keypoint_dict: dict[str, KeypointsContainer] | gr.State = gr.State({})
    inference_state: dict | gr.State = gr.State({})
    rgb_list: list[UInt8[np.ndarray, "h w 3"]] | gr.State = gr.State()
    masks_list: list[UInt8[np.ndarray, "h w"]] | gr.State = gr.State([])
    exo_cam_list: list[PinholeParameters] | gr.State = gr.State([])
    pointcloud: o3d.geometry.PointCloud | gr.State = gr.State()
    centers_xyc_dict: dict[str, Float32[np.ndarray, "3"]] | gr.State = gr.State({})

    with gr.Row():
        with gr.Tabs() as main_tabs:
            # Tab 0: dataset selection and annotation controls.
            with gr.TabItem("Controls", id=0):
                with gr.Column(scale=1):
                    dataset_dropdown = gr.Dropdown(
                        label="Dataset",
                        choices=["hocap", "assembly101"],
                        value="hocap",
                    )
                    load_dataset_btn = gr.Button("Load Dataset")

                    # Whether the next selected point is an include or an exclude prompt.
                    point_type = gr.Radio(
                        label="point type",
                        choices=["include", "exclude"],
                        value="include",
                        scale=1,
                    )
                    clear_points_btn = gr.Button("Clear Points", scale=1)
                    get_initial_mask_btn = gr.Button("Get Initial Mask", scale=1)
                    triangulate_btn = gr.Button("Triangulate Center", scale=1)
                    export_btn = gr.Button("Export", scale=1)
            # Tab 1: download of the exported zip; handle_export switches to this tab.
            with gr.TabItem("Output", id=1):
                gr.Markdown("here you can see the output of the selected video")
                output_zip = gr.File(label="Exported Zip File", file_count="single", type="filepath")
        with gr.Column(scale=4):
            viewer = Rerun(
                streaming=True,
                panel_states={
                    "time": "collapsed",
                    "blueprint": "hidden",
                    "selection": "hidden",
                },
                height=700,
            )

    # The recording id and log paths are filled in by log_dataset and kept in session state.
    recording_id = gr.State()
    log_paths = gr.State({})

    # Loading a dataset (re)initialises the viewer and all per-session state.
    load_dataset_btn.click(
        fn=log_dataset,
        inputs=[dataset_dropdown],
        outputs=[viewer, recording_id, log_paths, mv_keypoint_dict, rgb_list, exo_cam_list, pointcloud],
    )

    # A selection change in the viewer is handled by update_keypoints.
    viewer.selection_change(
        update_keypoints,
        inputs=[
            recording_id,
            point_type,
            mv_keypoint_dict,
            log_paths,
        ],
        outputs=[viewer, mv_keypoint_dict],
    )

    clear_points_btn.click(
        fn=reset_keypoints,
        inputs=[recording_id, mv_keypoint_dict, log_paths],
        outputs=[viewer, mv_keypoint_dict, centers_xyc_dict],
    )

    get_initial_mask_btn.click(
        fn=get_initial_mask,
        inputs=[recording_id, inference_state, mv_keypoint_dict, log_paths, rgb_list, centers_xyc_dict],
        outputs=[viewer, centers_xyc_dict],
    )

    triangulate_btn.click(
        fn=triangulate_centers,
        inputs=[recording_id, centers_xyc_dict, exo_cam_list, log_paths, rgb_list],
        outputs=[viewer, masks_list],
    )
    # TODO export masks + ply + camera poses for use with brush
    export_btn.click(
        fn=handle_export,
        inputs=[exo_cam_list, rgb_list, masks_list, pointcloud],
        outputs=[main_tabs, output_zip],
    )
annotation_example/gradio_ui/mv_sam_callbacks.py ADDED
@@ -0,0 +1,107 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import uuid
2
+ from dataclasses import dataclass
3
+ from pathlib import Path
4
+ from typing import Literal
5
+
6
+ import gradio as gr
7
+ import numpy as np
8
+ import rerun as rr
9
+ from gradio_rerun.events import SelectionChange
10
+ from typing_extensions import TypedDict
11
+
12
+
13
def get_recording(recording_id) -> rr.RecordingStream:
    """Return a recording stream of the "multiview_sam_annotate" application for ``recording_id``."""
    recording = rr.RecordingStream(
        application_id="multiview_sam_annotate",
        recording_id=recording_id,
    )
    return recording
15
+
16
+
17
class RerunLogPaths(TypedDict):
    """Names and entity paths shared by the callbacks that log to a recording."""

    # Name of the timeline the callbacks set before logging.
    timeline_name: str
    # Entity path under which per-camera entities are logged.
    parent_log_path: Path
    # One entity path per camera.
    cam_log_path_list: list[Path]
21
+
22
+
23
@dataclass
class KeypointsContainer:
    """Holds the include and exclude keypoints collected for one view."""

    include_points: np.ndarray  # shape (n,2)
    exclude_points: np.ndarray  # shape (m,2)

    @classmethod
    def empty(cls) -> "KeypointsContainer":
        """Return a container whose two point arrays both have shape (0, 2)."""
        no_include = np.zeros((0, 2), dtype=float)
        no_exclude = np.zeros((0, 2), dtype=float)
        return cls(include_points=no_include, exclude_points=no_exclude)

    def add_point(self, point: tuple[float, float], label: Literal["include", "exclude"]) -> None:
        """Append ``point`` to the include points, or to the exclude points for any other label."""
        target_attr = "include_points" if label == "include" else "exclude_points"
        existing = getattr(self, target_attr)

        updated = np.array([point], dtype=float)
        if existing.shape[0] > 0:
            # Keep the points already stored and put the new one last.
            updated = np.vstack([existing, updated])
        setattr(self, target_attr, updated)

    def clear(self) -> None:
        """Replace both point arrays with empty (0, 2) arrays."""
        for target_attr in ("include_points", "exclude_points"):
            setattr(self, target_attr, np.zeros((0, 2), dtype=float))
51
+
52
+
53
+ # In this function, the `request` and `change` parameters will be automatically injected by Gradio when this event listener is fired.
54
+ #
55
+ # `SelectionChange` is a subclass of `EventData`: https://www.gradio.app/docs/gradio/eventdata
56
+ # `gr.Request`: https://www.gradio.app/main/docs/gradio/request
57
def update_keypoints(
    active_recording_id: uuid.UUID,
    point_type: Literal["include", "exclude"],
    mv_keypoint_dict: dict[str, KeypointsContainer],
    log_paths: RerunLogPaths,
    request: gr.Request,
    change: SelectionChange,
):
    """Add the point selected in the viewer to the matching camera and re-log that camera's points.

    The function is a generator: it yields once when a point was recorded and
    finishes without yielding when the selection cannot be turned into a keypoint.

    Parameters:
        active_recording_id: Id of the recording to log into; an empty string means "no recording".
        point_type: Whether the selected point is an include or an exclude prompt.
        mv_keypoint_dict: Per-camera keypoints keyed by camera name; updated in place.
        log_paths: Supplies the timeline name.
        request: Not used by this function.
        change: Selection event; its payload carries the selected items.

    Yields:
        A tuple of (binary stream bytes, ``mv_keypoint_dict``).
    """
    if active_recording_id == "":
        return

    evt = change.payload

    # We can only log a keypoint if the user selected only a single item.
    if len(evt.items) != 1:
        return
    item = evt.items[0]

    # If the selected item isn't an entity, or we don't have its position, then bail out.
    if item.type != "entity" or item.position is None:
        return

    current_keypoint: tuple[int, int] = item.position[0:2]

    # A camera matches when its name occurs in the selected entity path.
    matched_container: KeypointsContainer | None = None
    for cam_name, keypoint_container in mv_keypoint_dict.items():
        if cam_name in item.entity_path:
            # Update the keypoints for the specific camera
            keypoint_container.add_point(current_keypoint, point_type)
            matched_container = keypoint_container

    # The selection did not belong to any known camera, so there is nothing to log.
    if matched_container is None:
        return

    # Now we can produce a valid keypoint.
    rec: rr.RecordingStream = get_recording(active_recording_id)
    stream: rr.BinaryStream = rec.binary_stream()

    rec.set_time_nanos(log_paths["timeline_name"], nanos=0)
    # Log include points if any exist
    if matched_container.include_points.shape[0] > 0:
        rec.log(
            f"{item.entity_path}/include",
            rr.Points2D(matched_container.include_points, colors=(0, 255, 0), radii=5),
        )

    # Log exclude points if any exist
    if matched_container.exclude_points.shape[0] > 0:
        rec.log(
            f"{item.entity_path}/exclude",
            rr.Points2D(matched_container.exclude_points, colors=(255, 0, 0), radii=5),
        )

    # Ensure we consume everything from the recording.
    stream.flush()
    yield stream.read(), mv_keypoint_dict
annotation_example/gradio_ui/sv_sam.py ADDED
@@ -0,0 +1,556 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ try:
2
+ import spaces # type: ignore
3
+
4
+ IN_SPACES = True
5
+ except ImportError:
6
+ print("Not running on Zero")
7
+ IN_SPACES = False
8
+
9
+
10
+ import tempfile
11
+ import uuid
12
+ from dataclasses import dataclass, fields
13
+ from pathlib import Path
14
+ from typing import Literal, TypedDict
15
+
16
+ import cv2
17
+ import gradio as gr
18
+ import numpy as np
19
+ import rerun as rr
20
+ import rerun.blueprint as rrb
21
+ import torch
22
+ from einops import rearrange
23
+ from gradio_rerun import Rerun
24
+ from gradio_rerun.events import (
25
+ SelectionChange,
26
+ )
27
+ from jaxtyping import Bool, Float, Float32, UInt8
28
+ from monopriors.depth_utils import clip_disparity, depth_edges_mask, depth_to_points
29
+ from monopriors.relative_depth_models.depth_anything_v2 import (
30
+ DepthAnythingV2Predictor,
31
+ RelativeDepthPrediction,
32
+ )
33
+ from sam2.sam2_video_predictor import SAM2VideoPredictor
34
+ from simplecv.video_io import VideoReader
35
+
36
+ from annotation_example.op import create_blueprint
37
+
38
# Module-level model setup, guarded by `gr.NO_RELOAD`.
if gr.NO_RELOAD:
    # SAM2 video predictor (tiny Hiera checkpoint) used for mask prompting and propagation.
    VIDEO_SAM_PREDICTOR: SAM2VideoPredictor = SAM2VideoPredictor.from_pretrained("facebook/sam2-hiera-tiny")
    # The depth predictor is constructed on the CPU and then switched to CUDA.
    DEPTH_PREDICTOR = DepthAnythingV2Predictor(device="cpu", encoder="vits")
    DEPTH_PREDICTOR.set_model_device("cuda")
42
+
43
+
44
class RerunLogPaths(TypedDict):
    """Timeline name and entity paths shared by the Rerun logging callbacks."""

    # Name of the timeline that time/sequence values are set on.
    timeline_name: str
    # Root entity path under which everything is logged.
    parent_log_path: Path
    # Entity path of the camera.
    camera_log_path: Path
    # Entity path of the pinhole.
    pinhole_path: Path
49
+
50
+
51
def log_relative_pred_rec(
    rec: rr.RecordingStream,
    parent_log_path: Path,
    relative_pred: RelativeDepthPrediction,
    rgb_hw3: UInt8[np.ndarray, "h w 3"],
    seg_mask_hw: UInt8[np.ndarray, "h w"] | None = None,
    remove_flying_pixels: bool = True,
    jpeg_quality: int = 90,
    depth_edge_threshold: float = 1.1,
) -> None:
    """Log a relative depth prediction (camera, image, depth, point cloud) to ``rec``.

    The camera is logged at the origin under ``parent_log_path / "camera"`` and
    the pinhole under ``.../camera/pinhole``. The input image is never modified.

    Args:
        rec: Recording stream that receives every log call.
        parent_log_path: Root entity path for everything logged here.
        relative_pred: Prediction providing ``K_33``, ``depth`` and ``disparity``.
        rgb_hw3: Image logged next to the pinhole and used for point colors.
        seg_mask_hw: Optional mask. When given it is logged as a segmentation
            image, and points whose mask value equals 1 are colored blue.
        remove_flying_pixels: When true, pixels flagged by ``depth_edges_mask``
            are zeroed in both depth and disparity before logging.
        jpeg_quality: JPEG quality used to compress the logged image.
        depth_edge_threshold: Threshold forwarded to ``depth_edges_mask``.
    """
    cam_log_path: Path = parent_log_path / "camera"
    pinhole_path: Path = cam_log_path / "pinhole"

    # assume camera is at the origin
    cam_T_world_44: Float[np.ndarray, "4 4"] = np.eye(4)

    rec.log(
        f"{cam_log_path}",
        rr.Transform3D(
            translation=cam_T_world_44[:3, 3],
            mat3x3=cam_T_world_44[:3, :3],
            from_parent=True,
        ),
    )
    rec.log(
        f"{pinhole_path}",
        rr.Pinhole(
            image_from_camera=relative_pred.K_33,
            width=rgb_hw3.shape[1],
            height=rgb_hw3.shape[0],
            image_plane_distance=1.5,
            camera_xyz=rr.ViewCoordinates.RDF,
        ),
    )
    rec.log(f"{pinhole_path}/image", rr.Image(rgb_hw3).compress(jpeg_quality=jpeg_quality))

    depth_hw: Float32[np.ndarray, "h w"] = relative_pred.depth
    # removes outliers from disparity (sometimes we can get weirdly large values)
    # NOTE(review): the original annotated this as UInt8 and later as Float32;
    # the actual dtype depends on `clip_disparity` -- confirm.
    clipped_disparity: np.ndarray = clip_disparity(relative_pred.disparity)

    if remove_flying_pixels:
        edges_mask: Bool[np.ndarray, "h w"] = depth_edges_mask(depth_hw, threshold=depth_edge_threshold)
        rec.log(
            f"{pinhole_path}/edge_mask",
            rr.SegmentationImage(edges_mask.astype(np.uint8)),
        )
        # Zero out every pixel that lies on a depth edge.
        keep_mask = ~edges_mask
        depth_hw = depth_hw * keep_mask
        clipped_disparity = clipped_disparity * keep_mask

    if seg_mask_hw is not None:
        rec.log(
            f"{pinhole_path}/segmentation",
            rr.SegmentationImage(seg_mask_hw),
        )

    rec.log(f"{pinhole_path}/depth", rr.DepthImage(depth_hw))

    # log to cam_log_path to avoid backprojecting disparity
    rec.log(f"{cam_log_path}/disparity", rr.DepthImage(clipped_disparity))

    depth_1hw: Float32[np.ndarray, "1 h w"] = rearrange(depth_hw, "h w -> 1 h w")
    pts_3d: Float32[np.ndarray, "h w 3"] = depth_to_points(depth_1hw, relative_pred.K_33)

    colors = rgb_hw3.reshape(-1, 3)

    # If we have a segmentation mask, make those pixels blue
    if seg_mask_hw is not None:
        # `reshape` can return a view of the caller's image; copy before
        # painting so `rgb_hw3` is left untouched.
        colors = colors.copy()
        flat_mask = seg_mask_hw.reshape(-1)
        colors[flat_mask == 1, :] = [0, 0, 255]  # blue, in the image's RGB channel order

    rec.log(
        f"{parent_log_path}/point_cloud",
        rr.Points3D(
            positions=pts_3d.reshape(-1, 3),
            colors=colors,
        ),
    )
134
+
135
+
136
@dataclass
class KeypointsContainer:
    """Mutable holder for the include and exclude keypoints of one view."""

    # Points marked "include", one (x, y) row per point: shape (n, 2).
    include_points: np.ndarray
    # Points marked "exclude", one (x, y) row per point: shape (m, 2).
    exclude_points: np.ndarray

    @classmethod
    def empty(cls) -> "KeypointsContainer":
        """Return a container holding two empty ``(0, 2)`` float arrays."""
        no_include = np.zeros((0, 2), dtype=float)
        no_exclude = np.zeros((0, 2), dtype=float)
        return cls(include_points=no_include, exclude_points=no_exclude)

    def add_point(self, point: tuple[float, float], label: Literal["include", "exclude"]) -> None:
        """Append ``point`` to the include list, or to the exclude list for any other label."""
        new_rows = np.array([point], dtype=float)
        target = "include_points" if label == "include" else "exclude_points"
        existing = getattr(self, target)
        # Only stack when there is something to stack onto; otherwise the new
        # single-row array becomes the stored array.
        if existing.shape[0] > 0:
            new_rows = np.vstack([existing, new_rows])
        setattr(self, target, new_rows)

    def clear(self) -> None:
        """Reset both point arrays to empty ``(0, 2)`` float arrays."""
        self.include_points = np.zeros((0, 2), dtype=float)
        self.exclude_points = np.zeros((0, 2), dtype=float)
164
+
165
+
166
# In this function, the `request` and `change` parameters will be automatically injected by Gradio when this event listener is fired.
#
# `SelectionChange` is a subclass of `EventData`: https://www.gradio.app/docs/gradio/eventdata
# `gr.Request`: https://www.gradio.app/main/docs/gradio/request
170
def single_view_update_keypoints(
    active_recording_id: uuid.UUID,
    point_type: Literal["include", "exclude"],
    keypoints_container: KeypointsContainer,
    log_paths: RerunLogPaths,
    request: gr.Request,
    change: SelectionChange,
):
    """Add the clicked position as a keypoint and re-log all keypoints.

    This is a generator used as a Gradio event callback. It finishes without
    yielding unless exactly one entity with a known position is selected.

    Args:
        active_recording_id: Id of the recording to log into.
        point_type: Whether the clicked point is an include or exclude point.
        keypoints_container: Current keypoints; updated in place.
        log_paths: Logging paths; only ``timeline_name`` is read here.
        request: Not used in the body.
        change: Selection event; its ``payload.items`` are inspected.

    Yields:
        A tuple of the data read from the recording's binary stream and the
        (mutated) ``keypoints_container``.
    """
    selection = change.payload

    # We can only log a keypoint if the user selected only a single item.
    if len(selection.items) != 1:
        return
    selected = selection.items[0]

    # If the selected item isn't an entity, or we don't have its position, then bail out.
    if selected.type != "entity" or selected.position is None:
        return

    # Now we can produce a valid keypoint.
    rec: rr.RecordingStream = get_recording(active_recording_id)
    stream: rr.BinaryStream = rec.binary_stream()
    current_keypoint: tuple[int, int] = selected.position[0:2]
    keypoints_container.add_point(current_keypoint, point_type)

    rec.set_time_sequence(log_paths["timeline_name"], sequence=0)

    # Log include points (green) if any exist
    include_points = keypoints_container.include_points
    if include_points.shape[0] > 0:
        rec.log(f"{selected.entity_path}/include", rr.Points2D(include_points, colors=(0, 255, 0), radii=5))

    # Log exclude points (red) if any exist
    exclude_points = keypoints_container.exclude_points
    if exclude_points.shape[0] > 0:
        rec.log(f"{selected.entity_path}/exclude", rr.Points2D(exclude_points, colors=(255, 0, 0), radii=5))

    # Ensure we consume everything from the recording.
    stream.flush()
    yield stream.read(), keypoints_container
212
+
213
+
214
def get_recording(recording_id) -> rr.RecordingStream:
    """Build a recording stream for the "Single View Annotation" app with the given id."""
    recording = rr.RecordingStream(
        application_id="Single View Annotation",
        recording_id=recording_id,
    )
    return recording
216
+
217
+
218
def rescale_img(img_hw3: UInt8[np.ndarray, "h w 3"], max_dim: int) -> UInt8[np.ndarray, "... 3"]:
    """Downscale an image so that its larger side is at most ``max_dim``.

    The aspect ratio is kept (up to integer truncation). Images that already
    fit are returned unchanged, as the very same array object.

    Args:
        img_hw3: Image whose first two axes are height and width.
        max_dim: Maximum allowed size of the larger image side.

    Returns:
        The resized image, or ``img_hw3`` itself when no resize is needed.
    """
    height, width, _ = img_hw3.shape
    current_dim = max(height, width)

    # Return original image if no resize needed
    if current_dim <= max_dim:
        return img_hw3

    scale_factor = max_dim / current_dim
    # Truncation can reach 0 on the short side of a very elongated image, and
    # cv2.resize rejects a zero-sized target, so clamp both sides to 1 pixel.
    new_height = max(1, int(height * scale_factor))
    new_width = max(1, int(width * scale_factor))

    # Resize image maintaining aspect ratio
    resized_img: UInt8[np.ndarray, "... 3"] = cv2.resize(
        img_hw3, (new_width, new_height), interpolation=cv2.INTER_AREA
    )
    return resized_img
237
+
238
+
239
# Allow using keyword args in gradio to avoid mixing up the order of inputs
# a bit of an antipattern that is required to make things work with beartype + keyword args
241
@dataclass
class PreprocessVideoComponents:
    """Gradio input components for video preprocessing, addressable by name."""

    # Component that supplies the uploaded video.
    video_file: gr.Video

    def to_list(self) -> list:
        """Return the components in field-declaration order."""
        components: list = []
        for spec in fields(self):
            components.append(getattr(self, spec.name))
        return components
247
+
248
+
249
@dataclass
class PreprocessVideoValues:
    """Runtime values matching the fields of ``PreprocessVideoComponents``."""

    # Path of the uploaded video file, as passed in by the callback.
    video_file: str
252
+
253
+
254
def preprocess_video(*input_params):
    """Forward all positional inputs to ``_preprocess_video`` and relay what it yields."""
    yield from _preprocess_video(*input_params)
258
+
259
+
260
def _preprocess_video(
    *input_params,
    progress=gr.Progress(track_tqdm=True),  # noqa B008
):
    """Extract frames from the uploaded video, initialise SAM and log the first frame.

    Frames are subsampled towards roughly 10 fps, capped at 100 frames, resized
    so the larger side is at most 640 pixels, and written as JPEGs into a fresh
    temporary directory. That directory is not removed here; it is handed to
    the caller in the yielded tuple.

    Args:
        *input_params: Positional values matching the fields of
            ``PreprocessVideoValues`` (currently only the video file path).
        progress: Gradio progress tracker.

    Yields:
        A single tuple: a closed ``gr.Accordion``, the data read from the
        binary stream, the SAM inference state, the frames directory, the new
        recording id and the log paths.

    Raises:
        gr.Error: If no frame could be read from the video.
    """
    input_values: PreprocessVideoValues = PreprocessVideoValues(*input_params)
    # create a new recording id, and store it in a Gradio's session state.
    recording_id: uuid.UUID = uuid.uuid4()
    rec: rr.RecordingStream = get_recording(recording_id)
    stream: rr.BinaryStream = rec.binary_stream()

    log_paths = RerunLogPaths(
        timeline_name="frame_idx",
        parent_log_path=Path("world"),
        camera_log_path=Path("world") / "camera",
        pinhole_path=Path("world") / "camera" / "pinhole",
    )

    video_path: Path = Path(input_values.video_file)

    initial_blueprint = rrb.Blueprint(
        rrb.Horizontal(
            rrb.Spatial2DView(origin=f"{log_paths['pinhole_path']}"),
        ),
        collapse_panels=True,
    )

    rec.send_blueprint(initial_blueprint)

    video_reader: VideoReader = VideoReader(video_path)
    tmp_frames_dir: str = tempfile.mkdtemp()

    target_fps: int = 10
    # Videos slower than `target_fps` would give an interval of 0 and make the
    # modulo below raise ZeroDivisionError; keep every frame in that case.
    frame_interval: int = max(1, int(video_reader.fps // target_fps))
    max_frames: int = 100
    total_saved_frames: int = 0
    max_size: int = 640

    progress(0, desc="Reading video frames")
    for idx, bgr in enumerate(video_reader):
        if idx % frame_interval != 0:
            continue
        if total_saved_frames >= max_frames:
            break
        bgr: np.ndarray = rescale_img(bgr, max_size)
        # Save frames to the temporary directory, named after the source frame index.
        cv2.imwrite(f"{tmp_frames_dir}/{idx:05d}.jpg", bgr)
        total_saved_frames += 1

    if total_saved_frames == 0:
        raise gr.Error("No frames could be read from the uploaded video.")

    # Frame index 0 is always the first frame written by the loop above.
    first_frame_path: Path = Path(tmp_frames_dir) / "00000.jpg"
    first_bgr: np.ndarray = cv2.imread(str(first_frame_path))

    progress(0.5, desc="Initializing SAM")
    with torch.inference_mode():
        inference_state = VIDEO_SAM_PREDICTOR.init_state(video_path=tmp_frames_dir)
        VIDEO_SAM_PREDICTOR.reset_state(inference_state)

    rec.set_time_sequence(log_paths["timeline_name"], sequence=0)
    rec.log(
        f"{log_paths['pinhole_path']}/image",
        rr.Image(first_bgr, color_model=rr.ColorModel.BGR).compress(jpeg_quality=90),
    )

    # Ensure we consume everything from the recording.
    stream.flush()

    yield gr.Accordion(open=False), stream.read(), inference_state, Path(tmp_frames_dir), recording_id, log_paths
326
+
327
+
328
def reset_keypoints(active_recording_id: uuid.UUID, keypoints_container: KeypointsContainer, log_paths: RerunLogPaths):
    """Drop all keypoints and clear the keypoint, segmentation and depth entities.

    Args:
        active_recording_id: Id of the recording to log the clears into.
        keypoints_container: Keypoints to empty; cleared in place.
        log_paths: Logging paths; ``timeline_name`` and ``pinhole_path`` are read.

    Yields:
        A tuple of the data read from the recording's binary stream and the
        emptied ``keypoints_container``.
    """
    rec: rr.RecordingStream = get_recording(active_recording_id)
    stream: rr.BinaryStream = rec.binary_stream()

    keypoints_container.clear()

    rec.set_time_sequence(log_paths["timeline_name"], sequence=0)

    # Entities to wipe, in the order they are cleared.
    pinhole = log_paths["pinhole_path"]
    entities_to_clear = (
        f"{pinhole}/image/include",
        f"{pinhole}/image/exclude",
        f"{pinhole}/segmentation",
        f"{pinhole}/depth",
    )
    for entity_path in entities_to_clear:
        rec.log(entity_path, rr.Clear(recursive=True))

    # Ensure we consume everything from the recording.
    stream.flush()
    yield stream.read(), keypoints_container
356
+
357
+
358
def get_initial_mask(
    recording_id: uuid.UUID,
    inference_state: dict,
    keypoint_container: KeypointsContainer,
    log_paths: RerunLogPaths,
):
    """Prompt SAM with the current keypoints on frame 0 and log the resulting mask.

    Args:
        recording_id: Id of the recording to log into.
        inference_state: SAM inference state passed to the predictor.
        keypoint_container: Include/exclude points used as the prompt.
        log_paths: Logging paths; ``timeline_name`` and ``pinhole_path`` are read.

    Yields:
        The data read from the recording's binary stream.

    Raises:
        gr.Error: If there are neither include nor exclude points.
    """
    rec = get_recording(recording_id)
    stream = rec.binary_stream()

    rec.set_time_sequence(log_paths["timeline_name"], 0)

    include_points = keypoint_container.include_points
    exclude_points = keypoint_container.exclude_points

    points = np.vstack([include_points, exclude_points]).astype(np.float32)
    if len(points) == 0:
        raise gr.Error("No points selected. Please add include or exclude points.")

    # Create labels array: 1 for include points, 0 for exclude points
    labels = np.concatenate(
        [
            np.ones(len(include_points), dtype=np.int32),
            np.zeros(len(exclude_points), dtype=np.int32),
        ]
    )

    with torch.inference_mode():
        _, _, mask_logits = VIDEO_SAM_PREDICTOR.add_new_points_or_box(
            inference_state=inference_state,
            frame_idx=0,
            obj_id=0,
            points=points,
            labels=labels,
        )

        # Positive values of the first returned mask are treated as foreground.
        masks: Bool[np.ndarray, "1 h w"] = (mask_logits[0] > 0.0).numpy(force=True)

        rec.log(
            f"{log_paths['pinhole_path']}/segmentation",
            rr.SegmentationImage(masks[0].astype(np.uint8)),
        )

    # Ensure we consume everything from the recording, as the other callbacks
    # do; without this the mask logged above may not be in the data read below.
    stream.flush()
    yield stream.read()
405
+
406
+
407
def propagate_mask(
    recording_id: uuid.UUID,
    inference_state: dict,
    keypoint_container: KeypointsContainer,
    frames_dir: Path,
    log_paths: RerunLogPaths,
):
    """Propagate the prompted mask through all frames, logging depth and masks per frame.

    For every ``*.jpg`` in ``frames_dir`` (sorted by name) the propagated mask
    is combined with a relative depth prediction and logged through
    ``log_relative_pred_rec``.

    Args:
        recording_id: Id of the recording to log into.
        inference_state: SAM inference state passed to the predictor.
        keypoint_container: Include/exclude points used as the prompt on frame 0.
        frames_dir: Directory holding the extracted JPEG frames.
        log_paths: Logging paths; ``timeline_name``, ``parent_log_path`` and
            ``pinhole_path`` are read.

    Yields:
        The data read from the recording's binary stream, once per frame and
        once more after the final flush.

    Raises:
        gr.Error: If there are neither include nor exclude points.
    """
    rec = get_recording(recording_id)
    stream = rec.binary_stream()

    blueprint = create_blueprint(parent_log_path=log_paths["parent_log_path"])
    rec.send_blueprint(blueprint)

    rec.log(f"{log_paths['parent_log_path']}", rr.ViewCoordinates.RDF)

    include_points = keypoint_container.include_points
    exclude_points = keypoint_container.exclude_points

    points = np.vstack([include_points, exclude_points]).astype(np.float32)
    if len(points) == 0:
        raise gr.Error("No points selected. Please add include or exclude points.")

    # Create labels array: 1 for include points, 0 for exclude points
    labels = np.concatenate(
        [
            np.ones(len(include_points), dtype=np.int32),
            np.zeros(len(exclude_points), dtype=np.int32),
        ]
    )

    frames_paths: list[Path] = sorted(frames_dir.glob("*.jpg"))

    # remove the keypoints as they're in the way during propagation.
    # The clears are put on the same timeline/sequence that `reset_keypoints`
    # uses, and they also cover the `<pinhole>/image/...` paths that
    # `reset_keypoints` clears.
    rec.set_time_sequence(log_paths["timeline_name"], 0)
    pinhole = log_paths["pinhole_path"]
    for entity_path in (
        f"{pinhole}/include",
        f"{pinhole}/exclude",
        f"{pinhole}/image/include",
        f"{pinhole}/image/exclude",
    ):
        rec.log(entity_path, rr.Clear(recursive=True))

    with torch.inference_mode():
        # Re-apply the prompt on frame 0; only the state update matters here.
        VIDEO_SAM_PREDICTOR.add_new_points_or_box(
            inference_state, frame_idx=0, obj_id=0, points=points, labels=labels
        )

        # propagate the prompts to get masklets throughout the video
        for frames_path, (frame_idx, _object_ids, mask_logits) in zip(
            frames_paths, VIDEO_SAM_PREDICTOR.propagate_in_video(inference_state), strict=True
        ):
            rec.set_time_sequence(log_paths["timeline_name"], frame_idx)
            # Positive values of the first returned mask are treated as foreground.
            masks: Bool[np.ndarray, "1 h w"] = (mask_logits[0] > 0.0).numpy(force=True)
            bgr = cv2.imread(str(frames_path))
            rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
            depth_pred: RelativeDepthPrediction = DEPTH_PREDICTOR(rgb=rgb, K_33=None)

            log_relative_pred_rec(
                rec=rec,
                parent_log_path=log_paths["parent_log_path"],
                relative_pred=depth_pred,
                rgb_hw3=rgb,
                seg_mask_hw=masks[0].astype(np.uint8),
                remove_flying_pixels=True,
                jpeg_quality=90,
                depth_edge_threshold=0.1,
            )

            yield stream.read()

    # Ensure we consume everything from the recording, so data logged for the
    # last frames is not left behind in the stream.
    stream.flush()
    yield stream.read()
474
+
475
+
476
# Gradio UI for single-view annotation: controls in a narrow left column,
# the Rerun viewer in a wide right column.
with gr.Blocks() as single_view_block:
    # State shared between the callbacks below.
    keypoints = gr.State(KeypointsContainer.empty())
    inference_state = gr.State({})
    frames_dir = gr.State(Path())

    with gr.Row():
        with gr.Column(scale=1):
            # `preprocess_video` yields a closed accordion for this drawer.
            with gr.Accordion("Your video IN", open=True) as video_in_drawer:
                video_in = gr.Video(label="Video IN", format=None)

            point_type = gr.Radio(
                label="point type",
                choices=["include", "exclude"],
                value="include",
                scale=1,
            )
            clear_points_btn = gr.Button("Clear Points", scale=1)
            get_initial_mask_btn = gr.Button("Get Initial Mask", scale=1)
            propagate_mask_btn = gr.Button("Propagate Mask", scale=1)
            stop_propagation_btn = gr.Button("Stop Propagation", scale=1)

        with gr.Column(scale=4):
            viewer = Rerun(
                streaming=True,
                panel_states={
                    "time": "collapsed",
                    "blueprint": "hidden",
                    "selection": "hidden",
                },
                height=700,
            )

    # The recording id and log paths are filled in by `preprocess_video`
    # and kept in Gradio state.
    recording_id = gr.State()
    log_paths = gr.State({})

    input_components = PreprocessVideoComponents(video_file=video_in)

    # triggered on video upload
    video_in.upload(
        fn=preprocess_video,
        inputs=input_components.to_list(),
        outputs=[video_in_drawer, viewer, inference_state, frames_dir, recording_id, log_paths],
    )

    # A selection change in the viewer adds a keypoint of the chosen type.
    viewer.selection_change(
        single_view_update_keypoints,
        inputs=[recording_id, point_type, keypoints, log_paths],
        outputs=[viewer, keypoints],
    )

    clear_points_btn.click(
        fn=reset_keypoints,
        inputs=[recording_id, keypoints, log_paths],
        outputs=[viewer, keypoints],
    )

    get_initial_mask_btn.click(
        fn=get_initial_mask,
        inputs=[recording_id, inference_state, keypoints, log_paths],
        outputs=[viewer],
    )

    propagate_event = propagate_mask_btn.click(
        fn=propagate_mask,
        inputs=[recording_id, inference_state, keypoints, frames_dir, log_paths],
        outputs=[viewer],
    )

    # The stop button does no work of its own; it only cancels the running propagation.
    stop_propagation_btn.click(
        fn=lambda: None,
        inputs=[],
        outputs=[],
        cancels=[propagate_event],
    )
annotation_example/op.py ADDED
@@ -0,0 +1,106 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pathlib import Path
2
+
3
+ import numpy as np
4
+ import rerun as rr
5
+ import rerun.blueprint as rrb
6
+ from einops import rearrange
7
+ from jaxtyping import Bool, Float32, Float64, UInt8
8
+ from monopriors.depth_utils import clip_disparity, depth_edges_mask, depth_to_points
9
+ from monopriors.relative_depth_models.depth_anything_v2 import (
10
+ RelativeDepthPrediction,
11
+ )
12
+
13
+
14
def log_relative_pred(
    parent_log_path: Path,
    relative_pred: RelativeDepthPrediction,
    rgb_hw3: UInt8[np.ndarray, "h w 3"],
    seg_mask_hw: UInt8[np.ndarray, "h w"] | None = None,
    remove_flying_pixels: bool = True,
    jpeg_quality: int = 90,
    depth_edge_threshold: float = 1.1,
) -> None:
    """Log a relative depth prediction through the module-level ``rr.log``.

    Logs, in order: an identity camera transform, the pinhole, the compressed
    image, optionally an edge mask and a segmentation image, the depth image,
    the clipped disparity and a colored point cloud.

    Args:
        parent_log_path: Root entity path for everything logged here.
        relative_pred: Prediction providing ``K_33``, ``depth`` and ``disparity``.
        rgb_hw3: Image logged next to the pinhole and used for point colors.
        seg_mask_hw: Optional mask, logged as a segmentation image when given.
        remove_flying_pixels: When true, pixels flagged by ``depth_edges_mask``
            are zeroed in the depth before logging.
        jpeg_quality: JPEG quality used to compress the logged image.
        depth_edge_threshold: Threshold forwarded to ``depth_edges_mask``.
    """
    camera_entity: Path = parent_log_path / "camera"
    pinhole_entity: Path = camera_entity / "pinhole"

    # assume camera is at the origin
    cam_T_world_44: Float64[np.ndarray, "4 4"] = np.eye(4)
    camera_pose = rr.Transform3D(
        translation=cam_T_world_44[:3, 3],
        mat3x3=cam_T_world_44[:3, :3],
        from_parent=True,
    )
    rr.log(f"{camera_entity}", camera_pose)

    image_height, image_width = rgb_hw3.shape[0], rgb_hw3.shape[1]
    pinhole = rr.Pinhole(
        image_from_camera=relative_pred.K_33,
        width=image_width,
        height=image_height,
        camera_xyz=rr.ViewCoordinates.RDF,
    )
    rr.log(f"{pinhole_entity}", pinhole)
    rr.log(f"{pinhole_entity}/image", rr.Image(rgb_hw3).compress(jpeg_quality=jpeg_quality))

    depth_hw: Float32[np.ndarray, "h w"] = relative_pred.depth
    if remove_flying_pixels:
        edges_mask: Bool[np.ndarray, "h w"] = depth_edges_mask(depth_hw, threshold=depth_edge_threshold)
        rr.log(
            f"{pinhole_entity}/edge_mask",
            rr.SegmentationImage(edges_mask.astype(np.uint8)),
        )
        # Zero out every pixel that lies on a depth edge.
        depth_hw: Float32[np.ndarray, "h w"] = depth_hw * ~edges_mask

    # The segmentation mask is only logged; the depth is not masked by it.
    if seg_mask_hw is not None:
        rr.log(
            f"{pinhole_entity}/segmentation",
            rr.SegmentationImage(seg_mask_hw),
        )

    rr.log(f"{pinhole_entity}/depth", rr.DepthImage(depth_hw))

    # removes outliers from disparity (sometimes we can get weirdly large values)
    clipped_disparity: UInt8[np.ndarray, "h w"] = clip_disparity(relative_pred.disparity)

    # log to the camera entity to avoid backprojecting disparity
    rr.log(f"{camera_entity}/disparity", rr.DepthImage(clipped_disparity))

    depth_1hw: Float32[np.ndarray, "1 h w"] = rearrange(depth_hw, "h w -> 1 h w")
    pts_3d: Float32[np.ndarray, "h w 3"] = depth_to_points(depth_1hw, relative_pred.K_33)

    point_cloud = rr.Points3D(
        positions=pts_3d.reshape(-1, 3),
        colors=rgb_hw3.reshape(-1, 3),
    )
    rr.log(f"{parent_log_path}/point_cloud", point_cloud)
82
+
83
+
84
def create_blueprint(parent_log_path: Path) -> rrb.Blueprint:
    """Build the viewer layout: a column of three 2D views beside one 3D view.

    The 2D views show the image, the segmentation and the disparity; the 3D
    view is rooted at ``parent_log_path``. The two columns share the width 1:3.

    Args:
        parent_log_path: Root entity path; the camera and pinhole paths are
            derived from it as ``camera`` and ``camera/pinhole``.

    Returns:
        The blueprint, with panels collapsed.
    """
    cam_log_path: Path = parent_log_path / "camera"
    pinhole_path: Path = cam_log_path / "pinhole"

    image_view = rrb.Spatial2DView(origin=f"{pinhole_path}/image")
    segmentation_view = rrb.Spatial2DView(origin=f"{pinhole_path}/segmentation")
    disparity_view = rrb.Spatial2DView(origin=f"{cam_log_path}/disparity")
    views_2d = rrb.Vertical(image_view, segmentation_view, disparity_view)

    view_3d = rrb.Spatial3DView(origin=f"{parent_log_path}")

    contents: list = [views_2d, view_3d]
    layout = rrb.Horizontal(contents=contents, column_shares=[1, 3])
    return rrb.Blueprint(layout, collapse_panels=True)