# -*- coding: utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.
import cv2
import numpy as np
from .utils import convert_to_numpy, single_rle_to_mask, get_mask_box, read_video_one_frame


class InpaintingAnnotator:
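    """Build inpainting inputs (masked image + binary mask) from a single image.

    The masked region is chosen according to ``MODE``: salient-object detection
    ('salient'), a user-supplied mask or box ('mask', 'bbox'), SAM2 refinement of
    a mask/box/point prompt ('*track' modes), or GDINO detection ('label', 'caption').
    """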
def __init__(self, cfg, device=None):
self.use_aug = cfg.get('USE_AUG', True)
self.return_mask = cfg.get('RETURN_MASK', True)
self.return_source = cfg.get('RETURN_SOURCE', True)
self.mask_color = cfg.get('MASK_COLOR', 128)
self.mode = cfg.get('MODE', "mask")
        assert self.mode in ["salient", "salienttrack", "mask", "bbox", "salientmasktrack", "salientbboxtrack", "maskpointtrack", "maskbboxtrack", "masktrack", "bboxtrack", "label", "caption", "all"]
if self.mode in ["salient", "salienttrack"]:
from .salient import SalientAnnotator
self.salient_model = SalientAnnotator(cfg['SALIENT'], device=device)
if self.mode in ['masktrack', 'bboxtrack', 'salienttrack']:
from .sam2 import SAM2ImageAnnotator
self.sam2_model = SAM2ImageAnnotator(cfg['SAM2'], device=device)
if self.mode in ['label', 'caption']:
from .gdino import GDINOAnnotator
from .sam2 import SAM2ImageAnnotator
self.gdino_model = GDINOAnnotator(cfg['GDINO'], device=device)
self.sam2_model = SAM2ImageAnnotator(cfg['SAM2'], device=device)
if self.mode in ['all']:
from .salient import SalientAnnotator
from .gdino import GDINOAnnotator
from .sam2 import SAM2ImageAnnotator
self.salient_model = SalientAnnotator(cfg['SALIENT'], device=device)
self.gdino_model = GDINOAnnotator(cfg['GDINO'], device=device)
self.sam2_model = SAM2ImageAnnotator(cfg['SAM2'], device=device)
if self.use_aug:
from .maskaug import MaskAugAnnotator
self.maskaug_anno = MaskAugAnnotator(cfg={})
def apply_plain_mask(self, image, mask, mask_color):
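        """Paint ``mask > 0`` pixels with a flat ``mask_color`` and binarize the mask to {0, 255}."""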
bool_mask = mask > 0
out_image = image.copy()
out_image[bool_mask] = mask_color
out_mask = np.where(bool_mask, 255, 0).astype(np.uint8)
return out_image, out_mask
def apply_seg_mask(self, image, mask, mask_color, mask_cfg=None):
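        """Scale a soft segmentation mask to uint8, optionally augment it, then paint it onto the image."""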
out_mask = (mask * 255).astype('uint8')
if self.use_aug and mask_cfg is not None:
out_mask = self.maskaug_anno.forward(out_mask, mask_cfg)
bool_mask = out_mask > 0
out_image = image.copy()
out_image[bool_mask] = mask_color
return out_image, out_mask
def forward(self, image=None, mask=None, bbox=None, label=None, caption=None, mode=None, return_mask=None, return_source=None, mask_color=None, mask_cfg=None):
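        """Return ``{"image": masked_image}`` plus optional ``"mask"``/``"src_image"`` entries."""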
mode = mode if mode is not None else self.mode
return_mask = return_mask if return_mask is not None else self.return_mask
return_source = return_source if return_source is not None else self.return_source
mask_color = mask_color if mask_color is not None else self.mask_color
image = convert_to_numpy(image)
out_image, out_mask = None, None
if mode in ['salient']:
mask = self.salient_model.forward(image)
out_image, out_mask = self.apply_plain_mask(image, mask, mask_color)
elif mode in ['mask']:
mask_h, mask_w = mask.shape[:2]
h, w = image.shape[:2]
            if (mask_h != h) or (mask_w != w):
                mask = cv2.resize(mask, (w, h), interpolation=cv2.INTER_NEAREST)
out_image, out_mask = self.apply_plain_mask(image, mask, mask_color)
elif mode in ['bbox']:
x1, y1, x2, y2 = bbox
h, w = image.shape[:2]
x1, y1 = int(max(0, x1)), int(max(0, y1))
x2, y2 = int(min(w, x2)), int(min(h, y2))
out_image = image.copy()
out_image[y1:y2, x1:x2] = mask_color
out_mask = np.zeros((h, w), dtype=np.uint8)
out_mask[y1:y2, x1:x2] = 255
elif mode in ['salientmasktrack']:
mask = self.salient_model.forward(image)
resize_mask = cv2.resize(mask, (256, 256), interpolation=cv2.INTER_NEAREST)
out_mask = self.sam2_model.forward(image=image, mask=resize_mask, task_type='mask', return_mask=True)
out_image, out_mask = self.apply_seg_mask(image, out_mask, mask_color, mask_cfg)
elif mode in ['salientbboxtrack']:
mask = self.salient_model.forward(image)
bbox = get_mask_box(np.array(mask), threshold=1)
out_mask = self.sam2_model.forward(image=image, input_box=bbox, task_type='input_box', return_mask=True)
out_image, out_mask = self.apply_seg_mask(image, out_mask, mask_color, mask_cfg)
elif mode in ['maskpointtrack']:
out_mask = self.sam2_model.forward(image=image, mask=mask, task_type='mask_point', return_mask=True)
out_image, out_mask = self.apply_seg_mask(image, out_mask, mask_color, mask_cfg)
elif mode in ['maskbboxtrack']:
out_mask = self.sam2_model.forward(image=image, mask=mask, task_type='mask_box', return_mask=True)
out_image, out_mask = self.apply_seg_mask(image, out_mask, mask_color, mask_cfg)
elif mode in ['masktrack']:
resize_mask = cv2.resize(mask, (256, 256), interpolation=cv2.INTER_NEAREST)
out_mask = self.sam2_model.forward(image=image, mask=resize_mask, task_type='mask', return_mask=True)
out_image, out_mask = self.apply_seg_mask(image, out_mask, mask_color, mask_cfg)
elif mode in ['bboxtrack']:
out_mask = self.sam2_model.forward(image=image, input_box=bbox, task_type='input_box', return_mask=True)
out_image, out_mask = self.apply_seg_mask(image, out_mask, mask_color, mask_cfg)
elif mode in ['label']:
gdino_res = self.gdino_model.forward(image, classes=label)
if 'boxes' in gdino_res and len(gdino_res['boxes']) > 0:
bboxes = gdino_res['boxes'][0]
else:
raise ValueError(f"Unable to find the corresponding boxes of label: {label}")
out_mask = self.sam2_model.forward(image=image, input_box=bboxes, task_type='input_box', return_mask=True)
out_image, out_mask = self.apply_seg_mask(image, out_mask, mask_color, mask_cfg)
elif mode in ['caption']:
gdino_res = self.gdino_model.forward(image, caption=caption)
if 'boxes' in gdino_res and len(gdino_res['boxes']) > 0:
bboxes = gdino_res['boxes'][0]
else:
raise ValueError(f"Unable to find the corresponding boxes of caption: {caption}")
out_mask = self.sam2_model.forward(image=image, input_box=bboxes, task_type='input_box', return_mask=True)
out_image, out_mask = self.apply_seg_mask(image, out_mask, mask_color, mask_cfg)
        else:
            raise NotImplementedError(f"Unsupported inpainting mode: {mode}")
        ret_data = {"image": out_image}
if return_mask:
ret_data["mask"] = out_mask
if return_source:
ret_data["src_image"] = image
return ret_data
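

# A minimal usage sketch for InpaintingAnnotator (hypothetical values; model
# configs such as 'SALIENT', 'SAM2' and 'GDINO' are only required by the modes
# that load those models):
#
#   anno = InpaintingAnnotator({'MODE': 'bbox', 'USE_AUG': False})
#   res = anno.forward(image=img, bbox=[100, 80, 400, 360])  # img: HxWx3 uint8 array
#   masked_image, mask = res['image'], res['mask']
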
class InpaintingVideoAnnotator:
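    """Video counterpart of ``InpaintingAnnotator``: produces per-frame masked frames and masks.

    Tracking modes propagate a first-frame prompt through the clip via SAM2VideoAnnotator.
    """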
def __init__(self, cfg, device=None):
self.use_aug = cfg.get('USE_AUG', True)
self.return_frame = cfg.get('RETURN_FRAME', True)
self.return_mask = cfg.get('RETURN_MASK', True)
self.return_source = cfg.get('RETURN_SOURCE', True)
self.mask_color = cfg.get('MASK_COLOR', 128)
self.mode = cfg.get('MODE', "mask")
        assert self.mode in ["salient", "salienttrack", "mask", "bbox", "salientmasktrack", "salientbboxtrack", "maskpointtrack", "maskbboxtrack", "masktrack", "bboxtrack", "label", "caption", "all"]
if self.mode in ["salient", "salienttrack"]:
from .salient import SalientAnnotator
self.salient_model = SalientAnnotator(cfg['SALIENT'], device=device)
if self.mode in ['masktrack', 'bboxtrack', 'salienttrack']:
from .sam2 import SAM2VideoAnnotator
self.sam2_model = SAM2VideoAnnotator(cfg['SAM2'], device=device)
if self.mode in ['label', 'caption']:
from .gdino import GDINOAnnotator
from .sam2 import SAM2VideoAnnotator
self.gdino_model = GDINOAnnotator(cfg['GDINO'], device=device)
self.sam2_model = SAM2VideoAnnotator(cfg['SAM2'], device=device)
if self.mode in ['all']:
from .salient import SalientAnnotator
from .gdino import GDINOAnnotator
from .sam2 import SAM2VideoAnnotator
self.salient_model = SalientAnnotator(cfg['SALIENT'], device=device)
self.gdino_model = GDINOAnnotator(cfg['GDINO'], device=device)
self.sam2_model = SAM2VideoAnnotator(cfg['SAM2'], device=device)
if self.use_aug:
from .maskaug import MaskAugAnnotator
self.maskaug_anno = MaskAugAnnotator(cfg={})
def apply_plain_mask(self, frames, mask, mask_color, return_frame=True):
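        """Replicate one static mask across all frames; returns (masked frames or None, per-frame masks)."""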
out_frames = []
num_frames = len(frames)
bool_mask = mask > 0
out_masks = [np.where(bool_mask, 255, 0).astype(np.uint8)] * num_frames
if not return_frame:
return None, out_masks
for i in range(num_frames):
masked_frame = frames[i].copy()
masked_frame[bool_mask] = mask_color
out_frames.append(masked_frame)
return out_frames, out_masks
def apply_seg_mask(self, mask_data, frames, mask_color, mask_cfg=None, return_frame=True):
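        """Decode SAM2 per-frame RLE masks, optionally augment them, and paint each frame."""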
out_frames = []
out_masks = [(single_rle_to_mask(val[0]["mask"]) * 255).astype('uint8') for key, val in mask_data['annotations'].items()]
if not return_frame:
return None, out_masks
num_frames = min(len(out_masks), len(frames))
for i in range(num_frames):
sub_mask = out_masks[i]
if self.use_aug and mask_cfg is not None:
sub_mask = self.maskaug_anno.forward(sub_mask, mask_cfg)
out_masks[i] = sub_mask
bool_mask = sub_mask > 0
masked_frame = frames[i].copy()
masked_frame[bool_mask] = mask_color
out_frames.append(masked_frame)
out_masks = out_masks[:num_frames]
return out_frames, out_masks
def forward(self, frames=None, video=None, mask=None, bbox=None, label=None, caption=None, mode=None, return_frame=None, return_mask=None, return_source=None, mask_color=None, mask_cfg=None):
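        """Return ``{"frames": ..., "masks": ...}``; which keys appear follows return_frame/return_mask."""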
mode = mode if mode is not None else self.mode
return_frame = return_frame if return_frame is not None else self.return_frame
return_mask = return_mask if return_mask is not None else self.return_mask
return_source = return_source if return_source is not None else self.return_source
mask_color = mask_color if mask_color is not None else self.mask_color
out_frames, out_masks = [], []
if mode in ['salient']:
first_frame = frames[0] if frames is not None else read_video_one_frame(video)
mask = self.salient_model.forward(first_frame)
out_frames, out_masks = self.apply_plain_mask(frames, mask, mask_color, return_frame)
elif mode in ['mask']:
first_frame = frames[0] if frames is not None else read_video_one_frame(video)
mask_h, mask_w = mask.shape[:2]
h, w = first_frame.shape[:2]
            if (mask_h != h) or (mask_w != w):
                mask = cv2.resize(mask, (w, h), interpolation=cv2.INTER_NEAREST)
out_frames, out_masks = self.apply_plain_mask(frames, mask, mask_color, return_frame)
elif mode in ['bbox']:
first_frame = frames[0] if frames is not None else read_video_one_frame(video)
num_frames = len(frames)
x1, y1, x2, y2 = bbox
h, w = first_frame.shape[:2]
x1, y1 = int(max(0, x1)), int(max(0, y1))
x2, y2 = int(min(w, x2)), int(min(h, y2))
mask = np.zeros((h, w), dtype=np.uint8)
mask[y1:y2, x1:x2] = 255
out_masks = [mask] * num_frames
if not return_frame:
out_frames = None
else:
for i in range(num_frames):
masked_frame = frames[i].copy()
masked_frame[y1:y2, x1:x2] = mask_color
out_frames.append(masked_frame)
elif mode in ['salientmasktrack']:
first_frame = frames[0] if frames is not None else read_video_one_frame(video)
salient_mask = self.salient_model.forward(first_frame)
mask_data = self.sam2_model.forward(video=video, mask=salient_mask, task_type='mask')
out_frames, out_masks = self.apply_seg_mask(mask_data, frames, mask_color, mask_cfg, return_frame)
elif mode in ['salientbboxtrack']:
first_frame = frames[0] if frames is not None else read_video_one_frame(video)
salient_mask = self.salient_model.forward(first_frame)
bbox = get_mask_box(np.array(salient_mask), threshold=1)
mask_data = self.sam2_model.forward(video=video, input_box=bbox, task_type='input_box')
out_frames, out_masks = self.apply_seg_mask(mask_data, frames, mask_color, mask_cfg, return_frame)
elif mode in ['maskpointtrack']:
mask_data = self.sam2_model.forward(video=video, mask=mask, task_type='mask_point')
out_frames, out_masks = self.apply_seg_mask(mask_data, frames, mask_color, mask_cfg, return_frame)
elif mode in ['maskbboxtrack']:
mask_data = self.sam2_model.forward(video=video, mask=mask, task_type='mask_box')
out_frames, out_masks = self.apply_seg_mask(mask_data, frames, mask_color, mask_cfg, return_frame)
elif mode in ['masktrack']:
mask_data = self.sam2_model.forward(video=video, mask=mask, task_type='mask')
out_frames, out_masks = self.apply_seg_mask(mask_data, frames, mask_color, mask_cfg, return_frame)
elif mode in ['bboxtrack']:
mask_data = self.sam2_model.forward(video=video, input_box=bbox, task_type='input_box')
out_frames, out_masks = self.apply_seg_mask(mask_data, frames, mask_color, mask_cfg, return_frame)
elif mode in ['label']:
first_frame = frames[0] if frames is not None else read_video_one_frame(video)
gdino_res = self.gdino_model.forward(first_frame, classes=label)
if 'boxes' in gdino_res and len(gdino_res['boxes']) > 0:
bboxes = gdino_res['boxes'][0]
else:
raise ValueError(f"Unable to find the corresponding boxes of label: {label}")
mask_data = self.sam2_model.forward(video=video, input_box=bboxes, task_type='input_box')
out_frames, out_masks = self.apply_seg_mask(mask_data, frames, mask_color, mask_cfg, return_frame)
elif mode in ['caption']:
first_frame = frames[0] if frames is not None else read_video_one_frame(video)
gdino_res = self.gdino_model.forward(first_frame, caption=caption)
if 'boxes' in gdino_res and len(gdino_res['boxes']) > 0:
bboxes = gdino_res['boxes'][0]
else:
raise ValueError(f"Unable to find the corresponding boxes of caption: {caption}")
mask_data = self.sam2_model.forward(video=video, input_box=bboxes, task_type='input_box')
out_frames, out_masks = self.apply_seg_mask(mask_data, frames, mask_color, mask_cfg, return_frame)
        else:
            raise NotImplementedError(f"Unsupported inpainting mode: {mode}")
        ret_data = {}
if return_frame:
ret_data["frames"] = out_frames
if return_mask:
ret_data["masks"] = out_masks
return ret_data
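

# A minimal usage sketch for InpaintingVideoAnnotator (hypothetical values;
# 'masktrack' loads SAM2VideoAnnotator from cfg['SAM2'] and tracks the
# first-frame mask prompt through the decoded ``video``):
#
#   anno = InpaintingVideoAnnotator({'MODE': 'masktrack', 'SAM2': sam2_cfg})
#   res = anno.forward(frames=frame_list, video='clip.mp4', mask=first_frame_mask)
#   masked_frames, masks = res['frames'], res['masks']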