vace-demo / vace /annotators /outpainting.py
maffia's picture
Upload 94 files
690f890 verified
# -*- coding: utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.
import math
import random
from abc import ABCMeta
import numpy as np
import torch
from PIL import Image, ImageDraw
from .utils import convert_to_pil, get_mask_box
class OutpaintingAnnotator:
def __init__(self, cfg, device=None):
self.mask_blur = cfg.get('MASK_BLUR', 0)
self.random_cfg = cfg.get('RANDOM_CFG', None)
self.return_mask = cfg.get('RETURN_MASK', False)
self.return_source = cfg.get('RETURN_SOURCE', True)
self.keep_padding_ratio = cfg.get('KEEP_PADDING_RATIO', 8)
self.mask_color = cfg.get('MASK_COLOR', 0)
def forward(self,
image,
expand_ratio=0.3,
mask=None,
direction=['left', 'right', 'up', 'down'],
return_mask=None,
return_source=None,
mask_color=None):
return_mask = return_mask if return_mask is not None else self.return_mask
return_source = return_source if return_source is not None else self.return_source
mask_color = mask_color if mask_color is not None else self.mask_color
image = convert_to_pil(image)
if self.random_cfg:
direction_range = self.random_cfg.get(
'DIRECTION_RANGE', ['left', 'right', 'up', 'down'])
ratio_range = self.random_cfg.get('RATIO_RANGE', [0.0, 1.0])
direction = random.sample(
direction_range,
random.choice(list(range(1,
len(direction_range) + 1))))
expand_ratio = random.uniform(ratio_range[0], ratio_range[1])
if mask is None:
init_image = image
src_width, src_height = init_image.width, init_image.height
left = int(expand_ratio * src_width) if 'left' in direction else 0
right = int(expand_ratio * src_width) if 'right' in direction else 0
up = int(expand_ratio * src_height) if 'up' in direction else 0
down = int(expand_ratio * src_height) if 'down' in direction else 0
tar_width = math.ceil(
(src_width + left + right) /
self.keep_padding_ratio) * self.keep_padding_ratio
tar_height = math.ceil(
(src_height + up + down) /
self.keep_padding_ratio) * self.keep_padding_ratio
if left > 0:
left = left * (tar_width - src_width) // (left + right)
if right > 0:
right = tar_width - src_width - left
if up > 0:
up = up * (tar_height - src_height) // (up + down)
if down > 0:
down = tar_height - src_height - up
if mask_color is not None:
img = Image.new('RGB', (tar_width, tar_height),
color=mask_color)
else:
img = Image.new('RGB', (tar_width, tar_height))
img.paste(init_image, (left, up))
mask = Image.new('L', (img.width, img.height), 'white')
draw = ImageDraw.Draw(mask)
draw.rectangle(
(left + (self.mask_blur * 2 if left > 0 else 0), up +
(self.mask_blur * 2 if up > 0 else 0), mask.width - right -
(self.mask_blur * 2 if right > 0 else 0) - 1, mask.height - down -
(self.mask_blur * 2 if down > 0 else 0) - 1),
fill='black')
else:
bbox = get_mask_box(np.array(mask))
if bbox is None:
img = image
mask = mask
init_image = image
else:
mask = Image.new('L', (image.width, image.height), 'white')
mask_zero = Image.new('L',
(bbox[2] - bbox[0], bbox[3] - bbox[1]),
'black')
mask.paste(mask_zero, (bbox[0], bbox[1]))
crop_image = image.crop(bbox)
init_image = Image.new('RGB', (image.width, image.height),
'black')
init_image.paste(crop_image, (bbox[0], bbox[1]))
img = image
if return_mask:
if return_source:
ret_data = {
'src_image': np.array(init_image),
'image': np.array(img),
'mask': np.array(mask)
}
else:
ret_data = {'image': np.array(img), 'mask': np.array(mask)}
else:
if return_source:
ret_data = {
'src_image': np.array(init_image),
'image': np.array(img)
}
else:
ret_data = np.array(img)
return ret_data
class OutpaintingInnerAnnotator:
def __init__(self, cfg, device=None):
self.mask_blur = cfg.get('MASK_BLUR', 0)
self.random_cfg = cfg.get('RANDOM_CFG', None)
self.return_mask = cfg.get('RETURN_MASK', False)
self.return_source = cfg.get('RETURN_SOURCE', True)
self.keep_padding_ratio = cfg.get('KEEP_PADDING_RATIO', 8)
self.mask_color = cfg.get('MASK_COLOR', 0)
def forward(self,
image,
expand_ratio=0.3,
direction=['left', 'right', 'up', 'down'],
return_mask=None,
return_source=None,
mask_color=None):
return_mask = return_mask if return_mask is not None else self.return_mask
return_source = return_source if return_source is not None else self.return_source
mask_color = mask_color if mask_color is not None else self.mask_color
image = convert_to_pil(image)
if self.random_cfg:
direction_range = self.random_cfg.get(
'DIRECTION_RANGE', ['left', 'right', 'up', 'down'])
ratio_range = self.random_cfg.get('RATIO_RANGE', [0.0, 1.0])
direction = random.sample(
direction_range,
random.choice(list(range(1,
len(direction_range) + 1))))
expand_ratio = random.uniform(ratio_range[0], ratio_range[1])
init_image = image
src_width, src_height = init_image.width, init_image.height
left = int(expand_ratio * src_width) if 'left' in direction else 0
right = int(expand_ratio * src_width) if 'right' in direction else 0
up = int(expand_ratio * src_height) if 'up' in direction else 0
down = int(expand_ratio * src_height) if 'down' in direction else 0
crop_left = left
crop_right = src_width - right
crop_up = up
crop_down = src_height - down
crop_box = (crop_left, crop_up, crop_right, crop_down)
cropped_image = init_image.crop(crop_box)
if mask_color is not None:
img = Image.new('RGB', (src_width, src_height), color=mask_color)
else:
img = Image.new('RGB', (src_width, src_height))
paste_x = left
paste_y = up
img.paste(cropped_image, (paste_x, paste_y))
mask = Image.new('L', (img.width, img.height), 'white')
draw = ImageDraw.Draw(mask)
x0 = paste_x + (self.mask_blur * 2 if left > 0 else 0)
y0 = paste_y + (self.mask_blur * 2 if up > 0 else 0)
x1 = paste_x + cropped_image.width - (self.mask_blur * 2 if right > 0 else 0)
y1 = paste_y + cropped_image.height - (self.mask_blur * 2 if down > 0 else 0)
draw.rectangle((x0, y0, x1, y1), fill='black')
if return_mask:
if return_source:
ret_data = {
'src_image': np.array(init_image),
'image': np.array(img),
'mask': np.array(mask)
}
else:
ret_data = {'image': np.array(img), 'mask': np.array(mask)}
else:
if return_source:
ret_data = {
'src_image': np.array(init_image),
'image': np.array(img)
}
else:
ret_data = np.array(img)
return ret_data
class OutpaintingVideoAnnotator(OutpaintingAnnotator):
def __init__(self, cfg, device=None):
super().__init__(cfg, device)
self.key_map = {
"src_image": "src_images",
"image" : "frames",
"mask": "masks"
}
def forward(self, frames,
expand_ratio=0.3,
mask=None,
direction=['left', 'right', 'up', 'down'],
return_mask=None,
return_source=None,
mask_color=None):
ret_frames = None
for frame in frames:
anno_frame = super().forward(frame, expand_ratio=expand_ratio, mask=mask, direction=direction, return_mask=return_mask, return_source=return_source, mask_color=mask_color)
if isinstance(anno_frame, dict):
ret_frames = {} if ret_frames is None else ret_frames
for key, val in anno_frame.items():
new_key = self.key_map[key]
if new_key in ret_frames:
ret_frames[new_key].append(val)
else:
ret_frames[new_key] = [val]
else:
ret_frames = [] if ret_frames is None else ret_frames
ret_frames.append(anno_frame)
return ret_frames
class OutpaintingInnerVideoAnnotator(OutpaintingInnerAnnotator):
def __init__(self, cfg, device=None):
super().__init__(cfg, device)
self.key_map = {
"src_image": "src_images",
"image" : "frames",
"mask": "masks"
}
def forward(self, frames,
expand_ratio=0.3,
direction=['left', 'right', 'up', 'down'],
return_mask=None,
return_source=None,
mask_color=None):
ret_frames = None
for frame in frames:
anno_frame = super().forward(frame, expand_ratio=expand_ratio, direction=direction, return_mask=return_mask, return_source=return_source, mask_color=mask_color)
if isinstance(anno_frame, dict):
ret_frames = {} if ret_frames is None else ret_frames
for key, val in anno_frame.items():
new_key = self.key_map[key]
if new_key in ret_frames:
ret_frames[new_key].append(val)
else:
ret_frames[new_key] = [val]
else:
ret_frames = [] if ret_frames is None else ret_frames
ret_frames.append(anno_frame)
return ret_frames