maffia's picture
Upload 94 files
690f890 verified
# -*- coding: utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.
import numpy as np
from scipy.spatial import ConvexHull
from skimage.draw import polygon
from scipy import ndimage
from .utils import convert_to_numpy
class MaskDrawAnnotator:
def __init__(self, cfg, device=None):
self.mode = cfg.get('MODE', 'maskpoint')
self.return_dict = cfg.get('RETURN_DICT', True)
assert self.mode in ['maskpoint', 'maskbbox', 'mask', 'bbox']
def forward(self,
mask=None,
image=None,
bbox=None,
mode=None,
return_dict=None):
mode = mode if mode is not None else self.mode
return_dict = return_dict if return_dict is not None else self.return_dict
mask = convert_to_numpy(mask) if mask is not None else None
image = convert_to_numpy(image) if image is not None else None
mask_shape = mask.shape
if mode == 'maskpoint':
scribble = mask.transpose(1, 0)
labeled_array, num_features = ndimage.label(scribble >= 255)
centers = ndimage.center_of_mass(scribble, labeled_array,
range(1, num_features + 1))
centers = np.array(centers)
out_mask = np.zeros(mask_shape, dtype=np.uint8)
hull = ConvexHull(centers)
hull_vertices = centers[hull.vertices]
rr, cc = polygon(hull_vertices[:, 1], hull_vertices[:, 0], mask_shape)
out_mask[rr, cc] = 255
elif mode == 'maskbbox':
scribble = mask.transpose(1, 0)
labeled_array, num_features = ndimage.label(scribble >= 255)
centers = ndimage.center_of_mass(scribble, labeled_array,
range(1, num_features + 1))
centers = np.array(centers)
# (x1, y1, x2, y2)
x_min = centers[:, 0].min()
x_max = centers[:, 0].max()
y_min = centers[:, 1].min()
y_max = centers[:, 1].max()
out_mask = np.zeros(mask_shape, dtype=np.uint8)
out_mask[int(y_min) : int(y_max) + 1, int(x_min) : int(x_max) + 1] = 255
if image is not None:
out_image = image[int(y_min) : int(y_max) + 1, int(x_min) : int(x_max) + 1]
elif mode == 'bbox':
if isinstance(bbox, list):
bbox = np.array(bbox)
x_min, y_min, x_max, y_max = bbox
out_mask = np.zeros(mask_shape, dtype=np.uint8)
out_mask[int(y_min) : int(y_max) + 1, int(x_min) : int(x_max) + 1] = 255
if image is not None:
out_image = image[int(y_min) : int(y_max) + 1, int(x_min) : int(x_max) + 1]
elif mode == 'mask':
out_mask = mask
else:
raise NotImplementedError
if return_dict:
if image is not None:
return {"image": out_image, "mask": out_mask}
else:
return {"mask": out_mask}
else:
if image is not None:
return out_image, out_mask
else:
return out_mask