|
|
|
|
|
|
|
|
|
|
|
import json |
|
import os |
|
from typing import Dict, List, Optional, Tuple |
|
|
|
import cv2 |
|
import matplotlib.pyplot as plt |
|
import numpy as np |
|
import pycocotools.mask as mask_util |
|
|
|
|
|
def decode_video(video_path: str) -> List[np.ndarray]:
    """
    Decode a video file and return its frames in RGB order.

    Args:
        video_path: path to a video file readable by OpenCV.

    Returns:
        List of H x W x 3 uint8 frames converted to RGB. Empty if the
        video cannot be opened or contains no decodable frames.
    """
    video = cv2.VideoCapture(video_path)
    video_frames = []
    try:
        while video.isOpened():
            ret, frame = video.read()
            if not ret:
                break
            # OpenCV decodes to BGR; convert to RGB for downstream consumers.
            video_frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
    finally:
        # Always release the capture handle; the original never did,
        # leaking the underlying file/codec resources.
        video.release()
    return video_frames
|
|
|
|
|
def show_anns(masks, colors: List, borders=True) -> None:
    """
    Overlay segmentation masks (with optional contour borders) on the
    current matplotlib axes.

    Args:
        masks: iterable of boolean HxW arrays.
        colors: one RGB color (3 floats in [0, 1]) per mask.
        borders: if True, also draw dark contour outlines around each mask.
    """
    if len(masks) == 0:
        return

    # Paint larger masks first so smaller ones remain visible on top.
    ordered = sorted(zip(masks, colors), key=lambda pair: pair[0].sum(), reverse=True)

    first_mask = ordered[0][0]
    height, width = first_mask.shape[0], first_mask.shape[1]

    # RGBA overlay, fully transparent where no mask is painted.
    overlay = np.ones((height, width, 4))
    overlay[:, :, 3] = 0
    thickness = max(1, int(min(5, 0.01 * min(height, width))))

    for seg, rgb in ordered:
        # Fill the mask region with its color at fixed 0.55 alpha.
        overlay[seg] = np.concatenate([rgb, [0.55]])
        if borders:
            found_contours, _ = cv2.findContours(
                np.array(seg, dtype=np.uint8), cv2.RETR_TREE, cv2.CHAIN_APPROX_NONE
            )
            cv2.drawContours(
                overlay,
                found_contours,
                -1,
                (0.05, 0.05, 0.05, 1),
                thickness=thickness,
            )

    plt.gca().imshow(overlay)
|
|
|
|
|
class SAVDataset: |
|
""" |
|
SAVDataset is a class to load the SAV dataset and visualize the annotations. |
|
""" |
|
|
|
def __init__(self, sav_dir, annot_sample_rate=4): |
|
""" |
|
Args: |
|
sav_dir: the directory of the SAV dataset |
|
annot_sample_rate: the sampling rate of the annotations. |
|
The annotations are aligned with the videos at 6 fps. |
|
""" |
|
self.sav_dir = sav_dir |
|
self.annot_sample_rate = annot_sample_rate |
|
self.manual_mask_colors = np.random.random((256, 3)) |
|
self.auto_mask_colors = np.random.random((256, 3)) |
|
|
|
def read_frames(self, mp4_path: str) -> None: |
|
""" |
|
Read the frames and downsample them to align with the annotations. |
|
""" |
|
if not os.path.exists(mp4_path): |
|
print(f"{mp4_path} doesn't exist.") |
|
return None |
|
else: |
|
|
|
frames = decode_video(mp4_path) |
|
print(f"There are {len(frames)} frames decoded from {mp4_path} (24fps).") |
|
|
|
|
|
frames = frames[:: self.annot_sample_rate] |
|
print( |
|
f"Videos are annotated every {self.annot_sample_rate} frames. " |
|
"To align with the annotations, " |
|
f"downsample the video to {len(frames)} frames." |
|
) |
|
return frames |
|
|
|
def get_frames_and_annotations( |
|
self, video_id: str |
|
) -> Tuple[List | None, Dict | None, Dict | None]: |
|
""" |
|
Get the frames and annotations for video. |
|
""" |
|
|
|
mp4_path = os.path.join(self.sav_dir, video_id + ".mp4") |
|
frames = self.read_frames(mp4_path) |
|
if frames is None: |
|
return None, None, None |
|
|
|
|
|
manual_annot_path = os.path.join(self.sav_dir, video_id + "_manual.json") |
|
if not os.path.exists(manual_annot_path): |
|
print(f"{manual_annot_path} doesn't exist. Something might be wrong.") |
|
manual_annot = None |
|
else: |
|
manual_annot = json.load(open(manual_annot_path)) |
|
|
|
|
|
auto_annot_path = os.path.join(self.sav_dir, video_id + "_auto.json") |
|
if not os.path.exists(auto_annot_path): |
|
print(f"{auto_annot_path} doesn't exist.") |
|
auto_annot = None |
|
else: |
|
auto_annot = json.load(open(auto_annot_path)) |
|
|
|
return frames, manual_annot, auto_annot |
|
|
|
def visualize_annotation( |
|
self, |
|
frames: List[np.ndarray], |
|
auto_annot: Optional[Dict], |
|
manual_annot: Optional[Dict], |
|
annotated_frame_id: int, |
|
show_auto=True, |
|
show_manual=True, |
|
) -> None: |
|
""" |
|
Visualize the annotations on the annotated_frame_id. |
|
If show_manual is True, show the manual annotations. |
|
If show_auto is True, show the auto annotations. |
|
By default, show both auto and manual annotations. |
|
""" |
|
|
|
if annotated_frame_id >= len(frames): |
|
print("invalid annotated_frame_id") |
|
return |
|
|
|
rles = [] |
|
colors = [] |
|
if show_manual and manual_annot is not None: |
|
rles.extend(manual_annot["masklet"][annotated_frame_id]) |
|
colors.extend( |
|
self.manual_mask_colors[ |
|
: len(manual_annot["masklet"][annotated_frame_id]) |
|
] |
|
) |
|
if show_auto and auto_annot is not None: |
|
rles.extend(auto_annot["masklet"][annotated_frame_id]) |
|
colors.extend( |
|
self.auto_mask_colors[: len(auto_annot["masklet"][annotated_frame_id])] |
|
) |
|
|
|
plt.imshow(frames[annotated_frame_id]) |
|
|
|
if len(rles) > 0: |
|
masks = [mask_util.decode(rle) > 0 for rle in rles] |
|
show_anns(masks, colors) |
|
else: |
|
print("No annotation will be shown") |
|
|
|
plt.axis("off") |
|
plt.show() |
|
|