# -*- coding: utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.
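"""SAM2-based annotators: single-image prompting, video mask propagation, and
pipelines that seed SAM2 with a saliency mask or GroundingDINO boxes."""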
import os
import shutil
import numpy as np
import torch
from scipy import ndimage
from .utils import convert_to_numpy, read_video_one_frame, single_mask_to_rle, single_rle_to_mask, single_mask_to_xyxy


class SAM2ImageAnnotator:
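    """Single-image SAM2 predictor wrapper that builds prompts from boxes, masks, or scribbles."""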
def __init__(self, cfg, device=None):
self.task_type = cfg.get('TASK_TYPE', 'input_box')
self.return_mask = cfg.get('RETURN_MASK', False)
        try:
            from sam2.build_sam import build_sam2
            from sam2.sam2_image_predictor import SAM2ImagePredictor
        except ImportError:
            import warnings
            warnings.warn("please install the sam2 package via pip, or refer to models/VACE-Annotators/sam2/SAM_2-1.0-cp310-cp310-linux_x86_64.whl")
config_path = cfg['CONFIG_PATH']
local_config_path = os.path.join(*config_path.rsplit('/')[-3:])
if not os.path.exists(local_config_path): # TODO
os.makedirs(os.path.dirname(local_config_path), exist_ok=True)
shutil.copy(config_path, local_config_path)
pretrained_model = cfg['PRETRAINED_MODEL']
sam2_model = build_sam2(local_config_path, pretrained_model)
self.predictor = SAM2ImagePredictor(sam2_model)
self.predictor.fill_hole_area = 0

    def forward(self,
image,
input_box=None,
mask=None,
task_type=None,
return_mask=None):
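        """Segment `image` using a prompt built according to `task_type`:

        - 'mask_point': centroids of the scribble's connected components become positive point prompts;
        - 'mask_box': the box spanning those centroids becomes a box prompt;
        - 'input_box': `input_box` (x1, y1, x2, y2) is used directly;
        - 'mask': `mask` is passed through as a mask prompt.

        Returns the top-scoring mask alone if `return_mask`, otherwise a dict of
        masks, scores, and logits sorted by score in descending order.
        """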
task_type = task_type if task_type is not None else self.task_type
return_mask = return_mask if return_mask is not None else self.return_mask
mask = convert_to_numpy(mask) if mask is not None else None
if task_type == 'mask_point':
if len(mask.shape) == 3:
scribble = mask.transpose(2, 1, 0)[0]
else:
scribble = mask.transpose(1, 0) # (H, W) -> (W, H)
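            # The transpose puts width first, so each component centroid below comes
            # out in (x, y) order, matching SAM2's point-prompt convention.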
labeled_array, num_features = ndimage.label(scribble >= 255)
centers = ndimage.center_of_mass(scribble, labeled_array,
range(1, num_features + 1))
point_coords = np.array(centers)
point_labels = np.array([1] * len(centers))
sample = {
'point_coords': point_coords,
'point_labels': point_labels
}
elif task_type == 'mask_box':
if len(mask.shape) == 3:
scribble = mask.transpose(2, 1, 0)[0]
else:
scribble = mask.transpose(1, 0) # (H, W) -> (W, H)
labeled_array, num_features = ndimage.label(scribble >= 255)
centers = ndimage.center_of_mass(scribble, labeled_array,
range(1, num_features + 1))
centers = np.array(centers)
# (x1, y1, x2, y2)
x_min = centers[:, 0].min()
x_max = centers[:, 0].max()
y_min = centers[:, 1].min()
y_max = centers[:, 1].max()
bbox = np.array([x_min, y_min, x_max, y_max])
sample = {'box': bbox}
elif task_type == 'input_box':
if isinstance(input_box, list):
input_box = np.array(input_box)
sample = {'box': input_box}
elif task_type == 'mask':
sample = {'mask_input': mask[None, :, :]}
else:
raise NotImplementedError
self.predictor.set_image(image)
masks, scores, logits = self.predictor.predict(
multimask_output=False,
**sample
)
sorted_ind = np.argsort(scores)[::-1]
masks = masks[sorted_ind]
scores = scores[sorted_ind]
logits = logits[sorted_ind]
if return_mask:
return masks[0]
else:
ret_data = {
"masks": masks,
"scores": scores,
"logits": logits
}
return ret_data


class SAM2VideoAnnotator:
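    """Video SAM2 predictor wrapper: prompt the first frame, then propagate masks through the clip."""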
def __init__(self, cfg, device=None):
self.task_type = cfg.get('TASK_TYPE', 'input_box')
        try:
            from sam2.build_sam import build_sam2_video_predictor
        except ImportError:
            import warnings
            warnings.warn("please install the sam2 package via pip, or refer to models/VACE-Annotators/sam2/SAM_2-1.0-cp310-cp310-linux_x86_64.whl")
config_path = cfg['CONFIG_PATH']
local_config_path = os.path.join(*config_path.rsplit('/')[-3:])
if not os.path.exists(local_config_path): # TODO
os.makedirs(os.path.dirname(local_config_path), exist_ok=True)
shutil.copy(config_path, local_config_path)
pretrained_model = cfg['PRETRAINED_MODEL']
self.video_predictor = build_sam2_video_predictor(local_config_path, pretrained_model)
self.video_predictor.fill_hole_area = 0

    def forward(self,
video,
input_box=None,
mask=None,
task_type=None):
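        """Build a first-frame prompt the same way as SAM2ImageAnnotator.forward, add it
        at frame 0, and propagate it through `video`. Returns a dict of per-frame,
        per-object annotations (RLE mask, mask area, and xyxy box)."""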
task_type = task_type if task_type is not None else self.task_type
mask = convert_to_numpy(mask) if mask is not None else None
if task_type == 'mask_point':
if len(mask.shape) == 3:
scribble = mask.transpose(2, 1, 0)[0]
else:
scribble = mask.transpose(1, 0) # (H, W) -> (W, H)
labeled_array, num_features = ndimage.label(scribble >= 255)
centers = ndimage.center_of_mass(scribble, labeled_array,
range(1, num_features + 1))
point_coords = np.array(centers)
point_labels = np.array([1] * len(centers))
sample = {
'points': point_coords,
'labels': point_labels
}
elif task_type == 'mask_box':
if len(mask.shape) == 3:
scribble = mask.transpose(2, 1, 0)[0]
else:
scribble = mask.transpose(1, 0) # (H, W) -> (W, H)
labeled_array, num_features = ndimage.label(scribble >= 255)
centers = ndimage.center_of_mass(scribble, labeled_array,
range(1, num_features + 1))
centers = np.array(centers)
# (x1, y1, x2, y2)
x_min = centers[:, 0].min()
x_max = centers[:, 0].max()
y_min = centers[:, 1].min()
y_max = centers[:, 1].max()
bbox = np.array([x_min, y_min, x_max, y_max])
sample = {'box': bbox}
elif task_type == 'input_box':
if isinstance(input_box, list):
input_box = np.array(input_box)
sample = {'box': input_box}
elif task_type == 'mask':
sample = {'mask': mask}
else:
raise NotImplementedError
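        # All prompts attach to the first frame (index 0) and track a single object id.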
ann_frame_idx = 0
object_id = 0
        with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
inference_state = self.video_predictor.init_state(video_path=video)
if task_type in ['mask_point', 'mask_box', 'input_box']:
_, out_obj_ids, out_mask_logits = self.video_predictor.add_new_points_or_box(
inference_state=inference_state,
frame_idx=ann_frame_idx,
obj_id=object_id,
**sample
)
elif task_type in ['mask']:
_, out_obj_ids, out_mask_logits = self.video_predictor.add_new_mask(
inference_state=inference_state,
frame_idx=ann_frame_idx,
obj_id=object_id,
**sample
)
else:
raise NotImplementedError
video_segments = {} # video_segments contains the per-frame segmentation results
for out_frame_idx, out_obj_ids, out_mask_logits in self.video_predictor.propagate_in_video(inference_state):
frame_segments = {}
for i, out_obj_id in enumerate(out_obj_ids):
                    obj_mask = (out_mask_logits[i] > 0.0).cpu().numpy().squeeze(0)
                    frame_segments[out_obj_id] = {
                        "mask": single_mask_to_rle(obj_mask),
                        "mask_area": int(obj_mask.sum()),
                        "mask_box": single_mask_to_xyxy(obj_mask),
                    }
video_segments[out_frame_idx] = frame_segments
ret_data = {
"annotations": video_segments
}
return ret_data


class SAM2SalientVideoAnnotator:
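    """Salient-object tracking pipeline: a saliency model masks the first frame, then SAM2 propagates it."""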
def __init__(self, cfg, device=None):
        from .salient import SalientAnnotator
self.salient_model = SalientAnnotator(cfg['SALIENT'], device=device)
self.sam2_model = SAM2VideoAnnotator(cfg['SAM2'], device=device)

    def forward(self, video, image=None):
if image is None:
image = read_video_one_frame(video)
else:
image = convert_to_numpy(image)
salient_res = self.salient_model.forward(image)
sam2_res = self.sam2_model.forward(video=video, mask=salient_res, task_type='mask')
return sam2_res


class SAM2GDINOVideoAnnotator:
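    """Open-vocabulary tracking pipeline: GroundingDINO boxes an object on the first frame, then SAM2 tracks it."""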
def __init__(self, cfg, device=None):
        from .gdino import GDINOAnnotator
self.gdino_model = GDINOAnnotator(cfg['GDINO'], device=device)
self.sam2_model = SAM2VideoAnnotator(cfg['SAM2'], device=device)

    def forward(self, video, image=None, classes=None, caption=None):
if image is None:
image = read_video_one_frame(video)
else:
image = convert_to_numpy(image)
if classes is not None:
gdino_res = self.gdino_model.forward(image, classes=classes)
else:
gdino_res = self.gdino_model.forward(image, caption=caption)
        if 'boxes' in gdino_res and len(gdino_res['boxes']) > 0:
            bboxes = gdino_res['boxes'][0]  # track only the first detected box
        else:
            raise ValueError("GDINO returned no boxes for the given classes/caption")
sam2_res = self.sam2_model.forward(video=video, input_box=bboxes, task_type='input_box')
return sam2_res
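

if __name__ == '__main__':
    # Minimal usage sketch of SAM2ImageAnnotator with an 'input_box' prompt.
    # All paths below are placeholders, not shipped defaults; point them at your
    # local SAM2 config/checkpoint and an example video before running.
    cfg = {
        'TASK_TYPE': 'input_box',
        'CONFIG_PATH': 'models/VACE-Annotators/sam2/configs/sam2.1_hiera_l.yaml',  # placeholder
        'PRETRAINED_MODEL': 'models/VACE-Annotators/sam2/sam2.1_hiera_large.pt',   # placeholder
    }
    annotator = SAM2ImageAnnotator(cfg)
    frame = read_video_one_frame('assets/example.mp4')  # placeholder
    out = annotator.forward(frame, input_box=[100, 100, 400, 400])
    print(out['scores'])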