|
|
|
|
|
import math |
|
import random |
|
from abc import ABCMeta |
|
|
|
import numpy as np |
|
import torch |
|
from PIL import Image, ImageDraw |
|
from .utils import convert_to_pil, get_mask_box |
|
|
|
|
|
class OutpaintingAnnotator: |
|
def __init__(self, cfg, device=None): |
|
self.mask_blur = cfg.get('MASK_BLUR', 0) |
|
self.random_cfg = cfg.get('RANDOM_CFG', None) |
|
self.return_mask = cfg.get('RETURN_MASK', False) |
|
self.return_source = cfg.get('RETURN_SOURCE', True) |
|
self.keep_padding_ratio = cfg.get('KEEP_PADDING_RATIO', 8) |
|
self.mask_color = cfg.get('MASK_COLOR', 0) |
|
|
|
def forward(self, |
|
image, |
|
expand_ratio=0.3, |
|
mask=None, |
|
direction=['left', 'right', 'up', 'down'], |
|
return_mask=None, |
|
return_source=None, |
|
mask_color=None): |
|
return_mask = return_mask if return_mask is not None else self.return_mask |
|
return_source = return_source if return_source is not None else self.return_source |
|
mask_color = mask_color if mask_color is not None else self.mask_color |
|
image = convert_to_pil(image) |
|
if self.random_cfg: |
|
direction_range = self.random_cfg.get( |
|
'DIRECTION_RANGE', ['left', 'right', 'up', 'down']) |
|
ratio_range = self.random_cfg.get('RATIO_RANGE', [0.0, 1.0]) |
|
direction = random.sample( |
|
direction_range, |
|
random.choice(list(range(1, |
|
len(direction_range) + 1)))) |
|
expand_ratio = random.uniform(ratio_range[0], ratio_range[1]) |
|
|
|
if mask is None: |
|
init_image = image |
|
src_width, src_height = init_image.width, init_image.height |
|
left = int(expand_ratio * src_width) if 'left' in direction else 0 |
|
right = int(expand_ratio * src_width) if 'right' in direction else 0 |
|
up = int(expand_ratio * src_height) if 'up' in direction else 0 |
|
down = int(expand_ratio * src_height) if 'down' in direction else 0 |
|
tar_width = math.ceil( |
|
(src_width + left + right) / |
|
self.keep_padding_ratio) * self.keep_padding_ratio |
|
tar_height = math.ceil( |
|
(src_height + up + down) / |
|
self.keep_padding_ratio) * self.keep_padding_ratio |
|
if left > 0: |
|
left = left * (tar_width - src_width) // (left + right) |
|
if right > 0: |
|
right = tar_width - src_width - left |
|
if up > 0: |
|
up = up * (tar_height - src_height) // (up + down) |
|
if down > 0: |
|
down = tar_height - src_height - up |
|
if mask_color is not None: |
|
img = Image.new('RGB', (tar_width, tar_height), |
|
color=mask_color) |
|
else: |
|
img = Image.new('RGB', (tar_width, tar_height)) |
|
img.paste(init_image, (left, up)) |
|
mask = Image.new('L', (img.width, img.height), 'white') |
|
draw = ImageDraw.Draw(mask) |
|
|
|
draw.rectangle( |
|
(left + (self.mask_blur * 2 if left > 0 else 0), up + |
|
(self.mask_blur * 2 if up > 0 else 0), mask.width - right - |
|
(self.mask_blur * 2 if right > 0 else 0) - 1, mask.height - down - |
|
(self.mask_blur * 2 if down > 0 else 0) - 1), |
|
fill='black') |
|
else: |
|
bbox = get_mask_box(np.array(mask)) |
|
if bbox is None: |
|
img = image |
|
mask = mask |
|
init_image = image |
|
else: |
|
mask = Image.new('L', (image.width, image.height), 'white') |
|
mask_zero = Image.new('L', |
|
(bbox[2] - bbox[0], bbox[3] - bbox[1]), |
|
'black') |
|
mask.paste(mask_zero, (bbox[0], bbox[1])) |
|
crop_image = image.crop(bbox) |
|
init_image = Image.new('RGB', (image.width, image.height), |
|
'black') |
|
init_image.paste(crop_image, (bbox[0], bbox[1])) |
|
img = image |
|
if return_mask: |
|
if return_source: |
|
ret_data = { |
|
'src_image': np.array(init_image), |
|
'image': np.array(img), |
|
'mask': np.array(mask) |
|
} |
|
else: |
|
ret_data = {'image': np.array(img), 'mask': np.array(mask)} |
|
else: |
|
if return_source: |
|
ret_data = { |
|
'src_image': np.array(init_image), |
|
'image': np.array(img) |
|
} |
|
else: |
|
ret_data = np.array(img) |
|
return ret_data |
|
|
|
|
|
|
|
class OutpaintingInnerAnnotator: |
|
def __init__(self, cfg, device=None): |
|
self.mask_blur = cfg.get('MASK_BLUR', 0) |
|
self.random_cfg = cfg.get('RANDOM_CFG', None) |
|
self.return_mask = cfg.get('RETURN_MASK', False) |
|
self.return_source = cfg.get('RETURN_SOURCE', True) |
|
self.keep_padding_ratio = cfg.get('KEEP_PADDING_RATIO', 8) |
|
self.mask_color = cfg.get('MASK_COLOR', 0) |
|
|
|
def forward(self, |
|
image, |
|
expand_ratio=0.3, |
|
direction=['left', 'right', 'up', 'down'], |
|
return_mask=None, |
|
return_source=None, |
|
mask_color=None): |
|
return_mask = return_mask if return_mask is not None else self.return_mask |
|
return_source = return_source if return_source is not None else self.return_source |
|
mask_color = mask_color if mask_color is not None else self.mask_color |
|
image = convert_to_pil(image) |
|
if self.random_cfg: |
|
direction_range = self.random_cfg.get( |
|
'DIRECTION_RANGE', ['left', 'right', 'up', 'down']) |
|
ratio_range = self.random_cfg.get('RATIO_RANGE', [0.0, 1.0]) |
|
direction = random.sample( |
|
direction_range, |
|
random.choice(list(range(1, |
|
len(direction_range) + 1)))) |
|
expand_ratio = random.uniform(ratio_range[0], ratio_range[1]) |
|
|
|
init_image = image |
|
src_width, src_height = init_image.width, init_image.height |
|
left = int(expand_ratio * src_width) if 'left' in direction else 0 |
|
right = int(expand_ratio * src_width) if 'right' in direction else 0 |
|
up = int(expand_ratio * src_height) if 'up' in direction else 0 |
|
down = int(expand_ratio * src_height) if 'down' in direction else 0 |
|
|
|
crop_left = left |
|
crop_right = src_width - right |
|
crop_up = up |
|
crop_down = src_height - down |
|
crop_box = (crop_left, crop_up, crop_right, crop_down) |
|
cropped_image = init_image.crop(crop_box) |
|
if mask_color is not None: |
|
img = Image.new('RGB', (src_width, src_height), color=mask_color) |
|
else: |
|
img = Image.new('RGB', (src_width, src_height)) |
|
|
|
paste_x = left |
|
paste_y = up |
|
img.paste(cropped_image, (paste_x, paste_y)) |
|
|
|
mask = Image.new('L', (img.width, img.height), 'white') |
|
draw = ImageDraw.Draw(mask) |
|
|
|
x0 = paste_x + (self.mask_blur * 2 if left > 0 else 0) |
|
y0 = paste_y + (self.mask_blur * 2 if up > 0 else 0) |
|
x1 = paste_x + cropped_image.width - (self.mask_blur * 2 if right > 0 else 0) |
|
y1 = paste_y + cropped_image.height - (self.mask_blur * 2 if down > 0 else 0) |
|
draw.rectangle((x0, y0, x1, y1), fill='black') |
|
|
|
if return_mask: |
|
if return_source: |
|
ret_data = { |
|
'src_image': np.array(init_image), |
|
'image': np.array(img), |
|
'mask': np.array(mask) |
|
} |
|
else: |
|
ret_data = {'image': np.array(img), 'mask': np.array(mask)} |
|
else: |
|
if return_source: |
|
ret_data = { |
|
'src_image': np.array(init_image), |
|
'image': np.array(img) |
|
} |
|
else: |
|
ret_data = np.array(img) |
|
return ret_data |
|
|
|
|
|
|
|
|
|
|
|
class OutpaintingVideoAnnotator(OutpaintingAnnotator): |
|
|
|
def __init__(self, cfg, device=None): |
|
super().__init__(cfg, device) |
|
self.key_map = { |
|
"src_image": "src_images", |
|
"image" : "frames", |
|
"mask": "masks" |
|
} |
|
|
|
def forward(self, frames, |
|
expand_ratio=0.3, |
|
mask=None, |
|
direction=['left', 'right', 'up', 'down'], |
|
return_mask=None, |
|
return_source=None, |
|
mask_color=None): |
|
ret_frames = None |
|
for frame in frames: |
|
anno_frame = super().forward(frame, expand_ratio=expand_ratio, mask=mask, direction=direction, return_mask=return_mask, return_source=return_source, mask_color=mask_color) |
|
if isinstance(anno_frame, dict): |
|
ret_frames = {} if ret_frames is None else ret_frames |
|
for key, val in anno_frame.items(): |
|
new_key = self.key_map[key] |
|
if new_key in ret_frames: |
|
ret_frames[new_key].append(val) |
|
else: |
|
ret_frames[new_key] = [val] |
|
else: |
|
ret_frames = [] if ret_frames is None else ret_frames |
|
ret_frames.append(anno_frame) |
|
return ret_frames |
|
|
|
|
|
class OutpaintingInnerVideoAnnotator(OutpaintingInnerAnnotator): |
|
|
|
def __init__(self, cfg, device=None): |
|
super().__init__(cfg, device) |
|
self.key_map = { |
|
"src_image": "src_images", |
|
"image" : "frames", |
|
"mask": "masks" |
|
} |
|
|
|
def forward(self, frames, |
|
expand_ratio=0.3, |
|
direction=['left', 'right', 'up', 'down'], |
|
return_mask=None, |
|
return_source=None, |
|
mask_color=None): |
|
ret_frames = None |
|
for frame in frames: |
|
anno_frame = super().forward(frame, expand_ratio=expand_ratio, direction=direction, return_mask=return_mask, return_source=return_source, mask_color=mask_color) |
|
if isinstance(anno_frame, dict): |
|
ret_frames = {} if ret_frames is None else ret_frames |
|
for key, val in anno_frame.items(): |
|
new_key = self.key_map[key] |
|
if new_key in ret_frames: |
|
ret_frames[new_key].append(val) |
|
else: |
|
ret_frames[new_key] = [val] |
|
else: |
|
ret_frames = [] if ret_frames is None else ret_frames |
|
ret_frames.append(anno_frame) |
|
return ret_frames |
|
|