|
|
|
|
|
|
|
import numpy as np |
|
from scipy.spatial import ConvexHull |
|
from skimage.draw import polygon |
|
from scipy import ndimage |
|
|
|
from .utils import convert_to_numpy |
|
|
|
|
|
class MaskDrawAnnotator: |
|
def __init__(self, cfg, device=None): |
|
self.mode = cfg.get('MODE', 'maskpoint') |
|
self.return_dict = cfg.get('RETURN_DICT', True) |
|
assert self.mode in ['maskpoint', 'maskbbox', 'mask', 'bbox'] |
|
|
|
def forward(self, |
|
mask=None, |
|
image=None, |
|
bbox=None, |
|
mode=None, |
|
return_dict=None): |
|
mode = mode if mode is not None else self.mode |
|
return_dict = return_dict if return_dict is not None else self.return_dict |
|
|
|
mask = convert_to_numpy(mask) if mask is not None else None |
|
image = convert_to_numpy(image) if image is not None else None |
|
|
|
mask_shape = mask.shape |
|
if mode == 'maskpoint': |
|
scribble = mask.transpose(1, 0) |
|
labeled_array, num_features = ndimage.label(scribble >= 255) |
|
centers = ndimage.center_of_mass(scribble, labeled_array, |
|
range(1, num_features + 1)) |
|
centers = np.array(centers) |
|
out_mask = np.zeros(mask_shape, dtype=np.uint8) |
|
hull = ConvexHull(centers) |
|
hull_vertices = centers[hull.vertices] |
|
rr, cc = polygon(hull_vertices[:, 1], hull_vertices[:, 0], mask_shape) |
|
out_mask[rr, cc] = 255 |
|
elif mode == 'maskbbox': |
|
scribble = mask.transpose(1, 0) |
|
labeled_array, num_features = ndimage.label(scribble >= 255) |
|
centers = ndimage.center_of_mass(scribble, labeled_array, |
|
range(1, num_features + 1)) |
|
centers = np.array(centers) |
|
|
|
x_min = centers[:, 0].min() |
|
x_max = centers[:, 0].max() |
|
y_min = centers[:, 1].min() |
|
y_max = centers[:, 1].max() |
|
out_mask = np.zeros(mask_shape, dtype=np.uint8) |
|
out_mask[int(y_min) : int(y_max) + 1, int(x_min) : int(x_max) + 1] = 255 |
|
if image is not None: |
|
out_image = image[int(y_min) : int(y_max) + 1, int(x_min) : int(x_max) + 1] |
|
elif mode == 'bbox': |
|
if isinstance(bbox, list): |
|
bbox = np.array(bbox) |
|
x_min, y_min, x_max, y_max = bbox |
|
out_mask = np.zeros(mask_shape, dtype=np.uint8) |
|
out_mask[int(y_min) : int(y_max) + 1, int(x_min) : int(x_max) + 1] = 255 |
|
if image is not None: |
|
out_image = image[int(y_min) : int(y_max) + 1, int(x_min) : int(x_max) + 1] |
|
elif mode == 'mask': |
|
out_mask = mask |
|
else: |
|
raise NotImplementedError |
|
|
|
if return_dict: |
|
if image is not None: |
|
return {"image": out_image, "mask": out_mask} |
|
else: |
|
return {"mask": out_mask} |
|
else: |
|
if image is not None: |
|
return out_image, out_mask |
|
else: |
|
return out_mask |