# -*- coding: utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.
import cv2
import numpy as np

from .utils import convert_to_numpy, single_rle_to_mask, get_mask_box, read_video_one_frame


class InpaintingAnnotator:
    def __init__(self, cfg, device=None):
        self.use_aug = cfg.get('USE_AUG', True)
        self.return_mask = cfg.get('RETURN_MASK', True)
        self.return_source = cfg.get('RETURN_SOURCE', True)
        self.mask_color = cfg.get('MASK_COLOR', 128)
        self.mode = cfg.get('MODE', "mask")
        assert self.mode in ["salient", "mask", "bbox", "salientmasktrack", "salientbboxtrack",
                             "maskpointtrack", "maskbboxtrack", "masktrack", "bboxtrack",
                             "label", "caption", "all"]
        # Lazily import and build only the sub-annotators the configured mode needs.
        if self.mode in ["salient", "salientmasktrack", "salientbboxtrack"]:
            from .salient import SalientAnnotator
            self.salient_model = SalientAnnotator(cfg['SALIENT'], device=device)
        if self.mode in ["salientmasktrack", "salientbboxtrack", "maskpointtrack",
                         "maskbboxtrack", "masktrack", "bboxtrack"]:
            from .sam2 import SAM2ImageAnnotator
            self.sam2_model = SAM2ImageAnnotator(cfg['SAM2'], device=device)
        if self.mode in ['label', 'caption']:
            from .gdino import GDINOAnnotator
            from .sam2 import SAM2ImageAnnotator
            self.gdino_model = GDINOAnnotator(cfg['GDINO'], device=device)
            self.sam2_model = SAM2ImageAnnotator(cfg['SAM2'], device=device)
        if self.mode in ['all']:
            from .salient import SalientAnnotator
            from .gdino import GDINOAnnotator
            from .sam2 import SAM2ImageAnnotator
            self.salient_model = SalientAnnotator(cfg['SALIENT'], device=device)
            self.gdino_model = GDINOAnnotator(cfg['GDINO'], device=device)
            self.sam2_model = SAM2ImageAnnotator(cfg['SAM2'], device=device)
        if self.use_aug:
            from .maskaug import MaskAugAnnotator
            self.maskaug_anno = MaskAugAnnotator(cfg={})

    def apply_plain_mask(self, image, mask, mask_color):
        # Paint masked pixels with a flat color and binarize the mask to {0, 255}.
        bool_mask = mask > 0
        out_image = image.copy()
        out_image[bool_mask] = mask_color
        out_mask = np.where(bool_mask, 255, 0).astype(np.uint8)
        return out_image, out_mask

    def apply_seg_mask(self, image, mask, mask_color, mask_cfg=None):
        # Scale a [0, 1] segmentation mask to uint8, optionally augment it, then paint it in.
        out_mask = (mask * 255).astype('uint8')
        if self.use_aug and mask_cfg is not None:
            out_mask = self.maskaug_anno.forward(out_mask, mask_cfg)
        bool_mask = out_mask > 0
        out_image = image.copy()
        out_image[bool_mask] = mask_color
        return out_image, out_mask

    def forward(self, image=None, mask=None, bbox=None, label=None, caption=None, mode=None,
                return_mask=None, return_source=None, mask_color=None, mask_cfg=None):
        mode = mode if mode is not None else self.mode
        return_mask = return_mask if return_mask is not None else self.return_mask
        return_source = return_source if return_source is not None else self.return_source
        mask_color = mask_color if mask_color is not None else self.mask_color
        image = convert_to_numpy(image)
        out_image, out_mask = None, None
        if mode in ['salient']:
            mask = self.salient_model.forward(image)
            out_image, out_mask = self.apply_plain_mask(image, mask, mask_color)
        elif mode in ['mask']:
            mask_h, mask_w = mask.shape[:2]
            h, w = image.shape[:2]
            # Resize the mask only when its size differs from the image size.
            if (mask_h != h) or (mask_w != w):
                mask = cv2.resize(mask, (w, h), interpolation=cv2.INTER_NEAREST)
            out_image, out_mask = self.apply_plain_mask(image, mask, mask_color)
        elif mode in ['bbox']:
            x1, y1, x2, y2 = bbox
            h, w = image.shape[:2]
            # Clamp the box to the image bounds before painting.
            x1, y1 = int(max(0, x1)), int(max(0, y1))
            x2, y2 = int(min(w, x2)), int(min(h, y2))
            out_image = image.copy()
            out_image[y1:y2, x1:x2] = mask_color
            out_mask = np.zeros((h, w), dtype=np.uint8)
            out_mask[y1:y2, x1:x2] = 255
        elif mode in ['salientmasktrack']:
            # Saliency proposes a region; SAM2 refines it into a segmentation mask.
            mask = self.salient_model.forward(image)
            resize_mask = cv2.resize(mask, (256, 256), interpolation=cv2.INTER_NEAREST)
            out_mask = self.sam2_model.forward(image=image, mask=resize_mask, task_type='mask', return_mask=True)
            out_image, out_mask = self.apply_seg_mask(image, out_mask, mask_color, mask_cfg)
        elif mode in ['salientbboxtrack']:
            # Saliency proposes a region; its bounding box prompts SAM2.
            mask = self.salient_model.forward(image)
            bbox = get_mask_box(np.array(mask), threshold=1)
            out_mask = self.sam2_model.forward(image=image, input_box=bbox, task_type='input_box', return_mask=True)
            out_image, out_mask = self.apply_seg_mask(image, out_mask, mask_color, mask_cfg)
        elif mode in ['maskpointtrack']:
            out_mask = self.sam2_model.forward(image=image, mask=mask, task_type='mask_point', return_mask=True)
            out_image, out_mask = self.apply_seg_mask(image, out_mask, mask_color, mask_cfg)
        elif mode in ['maskbboxtrack']:
            out_mask = self.sam2_model.forward(image=image, mask=mask, task_type='mask_box', return_mask=True)
            out_image, out_mask = self.apply_seg_mask(image, out_mask, mask_color, mask_cfg)
        elif mode in ['masktrack']:
            resize_mask = cv2.resize(mask, (256, 256), interpolation=cv2.INTER_NEAREST)
            out_mask = self.sam2_model.forward(image=image, mask=resize_mask, task_type='mask', return_mask=True)
            out_image, out_mask = self.apply_seg_mask(image, out_mask, mask_color, mask_cfg)
        elif mode in ['bboxtrack']:
            out_mask = self.sam2_model.forward(image=image, input_box=bbox, task_type='input_box', return_mask=True)
            out_image, out_mask = self.apply_seg_mask(image, out_mask, mask_color, mask_cfg)
        elif mode in ['label']:
            # Ground the class label with GDINO, then segment the top box with SAM2.
            gdino_res = self.gdino_model.forward(image, classes=label)
            if 'boxes' in gdino_res and len(gdino_res['boxes']) > 0:
                bboxes = gdino_res['boxes'][0]
            else:
                raise ValueError(f"Unable to find the corresponding boxes of label: {label}")
            out_mask = self.sam2_model.forward(image=image, input_box=bboxes, task_type='input_box', return_mask=True)
            out_image, out_mask = self.apply_seg_mask(image, out_mask, mask_color, mask_cfg)
        elif mode in ['caption']:
            # Ground the free-form caption with GDINO, then segment the top box with SAM2.
            gdino_res = self.gdino_model.forward(image, caption=caption)
            if 'boxes' in gdino_res and len(gdino_res['boxes']) > 0:
                bboxes = gdino_res['boxes'][0]
            else:
                raise ValueError(f"Unable to find the corresponding boxes of caption: {caption}")
            out_mask = self.sam2_model.forward(image=image, input_box=bboxes, task_type='input_box', return_mask=True)
            out_image, out_mask = self.apply_seg_mask(image, out_mask, mask_color, mask_cfg)
        ret_data = {"image": out_image}
        if return_mask:
            ret_data["mask"] = out_mask
        if return_source:
            ret_data["src_image"] = image
        return ret_data
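
# A minimal usage sketch (illustrative, not part of the module). The cfg keys mirror
# the cfg.get(...) calls in __init__; 'bbox' is the simplest mode to try since it
# needs no sub-annotator weights. The input path is a placeholder, and we assume
# convert_to_numpy accepts a PIL image, as its name suggests.
#
#   from PIL import Image
#   anno = InpaintingAnnotator({'MODE': 'bbox', 'USE_AUG': False})
#   res = anno.forward(image=Image.open('input.jpg'), bbox=[50, 50, 200, 200])
#   # res['image']: the image with the box region filled with MASK_COLOR (default 128)
#   # res['mask']:  a uint8 {0, 255} mask of the same region
#   # res['src_image']: the untouched source image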

class InpaintingVideoAnnotator:
    def __init__(self, cfg, device=None):
        self.use_aug = cfg.get('USE_AUG', True)
        self.return_frame = cfg.get('RETURN_FRAME', True)
        self.return_mask = cfg.get('RETURN_MASK', True)
        self.return_source = cfg.get('RETURN_SOURCE', True)
        self.mask_color = cfg.get('MASK_COLOR', 128)
        self.mode = cfg.get('MODE', "mask")
        assert self.mode in ["salient", "mask", "bbox", "salientmasktrack", "salientbboxtrack",
                             "maskpointtrack", "maskbboxtrack", "masktrack", "bboxtrack",
                             "label", "caption", "all"]
        # Lazily import and build only the sub-annotators the configured mode needs.
        if self.mode in ["salient", "salientmasktrack", "salientbboxtrack"]:
            from .salient import SalientAnnotator
            self.salient_model = SalientAnnotator(cfg['SALIENT'], device=device)
        if self.mode in ["salientmasktrack", "salientbboxtrack", "maskpointtrack",
                         "maskbboxtrack", "masktrack", "bboxtrack"]:
            from .sam2 import SAM2VideoAnnotator
            self.sam2_model = SAM2VideoAnnotator(cfg['SAM2'], device=device)
        if self.mode in ['label', 'caption']:
            from .gdino import GDINOAnnotator
            from .sam2 import SAM2VideoAnnotator
            self.gdino_model = GDINOAnnotator(cfg['GDINO'], device=device)
            self.sam2_model = SAM2VideoAnnotator(cfg['SAM2'], device=device)
        if self.mode in ['all']:
            from .salient import SalientAnnotator
            from .gdino import GDINOAnnotator
            from .sam2 import SAM2VideoAnnotator
            self.salient_model = SalientAnnotator(cfg['SALIENT'], device=device)
            self.gdino_model = GDINOAnnotator(cfg['GDINO'], device=device)
            self.sam2_model = SAM2VideoAnnotator(cfg['SAM2'], device=device)
        if self.use_aug:
            from .maskaug import MaskAugAnnotator
            self.maskaug_anno = MaskAugAnnotator(cfg={})

    def apply_plain_mask(self, frames, mask, mask_color, return_frame=True):
        # A single static mask is shared across all frames.
        out_frames = []
        num_frames = len(frames)
        bool_mask = mask > 0
        out_masks = [np.where(bool_mask, 255, 0).astype(np.uint8)] * num_frames
        if not return_frame:
            return None, out_masks
        for i in range(num_frames):
            masked_frame = frames[i].copy()
            masked_frame[bool_mask] = mask_color
            out_frames.append(masked_frame)
        return out_frames, out_masks

    def apply_seg_mask(self, mask_data, frames, mask_color, mask_cfg=None, return_frame=True):
        # Decode one RLE-encoded mask per tracked frame, optionally augment it,
        # then paint the masked region of each frame.
        out_frames = []
        out_masks = [(single_rle_to_mask(val[0]["mask"]) * 255).astype('uint8')
                     for key, val in mask_data['annotations'].items()]
        if not return_frame:
            return None, out_masks
        num_frames = min(len(out_masks), len(frames))
        for i in range(num_frames):
            sub_mask = out_masks[i]
            if self.use_aug and mask_cfg is not None:
                sub_mask = self.maskaug_anno.forward(sub_mask, mask_cfg)
                out_masks[i] = sub_mask
            bool_mask = sub_mask > 0
            masked_frame = frames[i].copy()
            masked_frame[bool_mask] = mask_color
            out_frames.append(masked_frame)
        out_masks = out_masks[:num_frames]
        return out_frames, out_masks

    def forward(self, frames=None, video=None, mask=None, bbox=None, label=None, caption=None, mode=None,
                return_frame=None, return_mask=None, return_source=None, mask_color=None, mask_cfg=None):
        mode = mode if mode is not None else self.mode
        return_frame = return_frame if return_frame is not None else self.return_frame
        return_mask = return_mask if return_mask is not None else self.return_mask
        return_source = return_source if return_source is not None else self.return_source
        mask_color = mask_color if mask_color is not None else self.mask_color
        out_frames, out_masks = [], []
        if mode in ['salient']:
            first_frame = frames[0] if frames is not None else read_video_one_frame(video)
            mask = self.salient_model.forward(first_frame)
            out_frames, out_masks = self.apply_plain_mask(frames, mask, mask_color, return_frame)
        elif mode in ['mask']:
            first_frame = frames[0] if frames is not None else read_video_one_frame(video)
            mask_h, mask_w = mask.shape[:2]
            h, w = first_frame.shape[:2]
            # Resize the mask only when its size differs from the frame size.
            if (mask_h != h) or (mask_w != w):
                mask = cv2.resize(mask, (w, h), interpolation=cv2.INTER_NEAREST)
            out_frames, out_masks = self.apply_plain_mask(frames, mask, mask_color, return_frame)
        elif mode in ['bbox']:
            first_frame = frames[0] if frames is not None else read_video_one_frame(video)
            num_frames = len(frames)
            x1, y1, x2, y2 = bbox
            h, w = first_frame.shape[:2]
            # Clamp the box to the frame bounds before painting.
            x1, y1 = int(max(0, x1)), int(max(0, y1))
            x2, y2 = int(min(w, x2)), int(min(h, y2))
            mask = np.zeros((h, w), dtype=np.uint8)
            mask[y1:y2, x1:x2] = 255
            out_masks = [mask] * num_frames
            if not return_frame:
                out_frames = None
            else:
                for i in range(num_frames):
                    masked_frame = frames[i].copy()
                    masked_frame[y1:y2, x1:x2] = mask_color
                    out_frames.append(masked_frame)
        elif mode in ['salientmasktrack']:
            # Saliency on the first frame seeds SAM2 video tracking.
            first_frame = frames[0] if frames is not None else read_video_one_frame(video)
            salient_mask = self.salient_model.forward(first_frame)
            mask_data = self.sam2_model.forward(video=video, mask=salient_mask, task_type='mask')
            out_frames, out_masks = self.apply_seg_mask(mask_data, frames, mask_color, mask_cfg, return_frame)
        elif mode in ['salientbboxtrack']:
            # The bounding box of the first-frame saliency mask seeds SAM2 video tracking.
            first_frame = frames[0] if frames is not None else read_video_one_frame(video)
            salient_mask = self.salient_model.forward(first_frame)
            bbox = get_mask_box(np.array(salient_mask), threshold=1)
            mask_data = self.sam2_model.forward(video=video, input_box=bbox, task_type='input_box')
            out_frames, out_masks = self.apply_seg_mask(mask_data, frames, mask_color, mask_cfg, return_frame)
        elif mode in ['maskpointtrack']:
            mask_data = self.sam2_model.forward(video=video, mask=mask, task_type='mask_point')
            out_frames, out_masks = self.apply_seg_mask(mask_data, frames, mask_color, mask_cfg, return_frame)
        elif mode in ['maskbboxtrack']:
            mask_data = self.sam2_model.forward(video=video, mask=mask, task_type='mask_box')
            out_frames, out_masks = self.apply_seg_mask(mask_data, frames, mask_color, mask_cfg, return_frame)
        elif mode in ['masktrack']:
            mask_data = self.sam2_model.forward(video=video, mask=mask, task_type='mask')
            out_frames, out_masks = self.apply_seg_mask(mask_data, frames, mask_color, mask_cfg, return_frame)
        elif mode in ['bboxtrack']:
            mask_data = self.sam2_model.forward(video=video, input_box=bbox, task_type='input_box')
            out_frames, out_masks = self.apply_seg_mask(mask_data, frames, mask_color, mask_cfg, return_frame)
        elif mode in ['label']:
            # Ground the class label with GDINO on the first frame, then track the top box with SAM2.
            first_frame = frames[0] if frames is not None else read_video_one_frame(video)
            gdino_res = self.gdino_model.forward(first_frame, classes=label)
            if 'boxes' in gdino_res and len(gdino_res['boxes']) > 0:
                bboxes = gdino_res['boxes'][0]
            else:
                raise ValueError(f"Unable to find the corresponding boxes of label: {label}")
            mask_data = self.sam2_model.forward(video=video, input_box=bboxes, task_type='input_box')
            out_frames, out_masks = self.apply_seg_mask(mask_data, frames, mask_color, mask_cfg, return_frame)
        elif mode in ['caption']:
            # Ground the free-form caption with GDINO on the first frame, then track with SAM2.
            first_frame = frames[0] if frames is not None else read_video_one_frame(video)
            gdino_res = self.gdino_model.forward(first_frame, caption=caption)
            if 'boxes' in gdino_res and len(gdino_res['boxes']) > 0:
                bboxes = gdino_res['boxes'][0]
            else:
                raise ValueError(f"Unable to find the corresponding boxes of caption: {caption}")
            mask_data = self.sam2_model.forward(video=video, input_box=bboxes, task_type='input_box')
            out_frames, out_masks = self.apply_seg_mask(mask_data, frames, mask_color, mask_cfg, return_frame)
        ret_data = {}
        if return_frame:
            ret_data["frames"] = out_frames
        if return_mask:
            ret_data["masks"] = out_masks
        return ret_data
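
# A minimal usage sketch (illustrative, not part of the module). 'bbox' is again the
# simplest mode since it needs no sub-annotator weights; frames is assumed to be a
# list of HxWx3 uint8 arrays decoded elsewhere.
#
#   anno = InpaintingVideoAnnotator({'MODE': 'bbox', 'USE_AUG': False})
#   res = anno.forward(frames=frames, bbox=[50, 50, 200, 200])
#   # res['frames']: frames with the box region filled with MASK_COLOR (default 128)
#   # res['masks']:  one uint8 {0, 255} mask per frame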