|
|
|
|
|
import random |
|
import numpy as np |
|
from .utils import align_frames |
|
|
|
|
|
class FrameRefExtractAnnotator:
    """Keep a few reference frames of a clip and blank out all the others.

    Frames chosen as references are passed through untouched and get an
    all-zero mask; every other frame is replaced by a constant-colour frame
    and gets an all-255 mask.
    """

    para_dict = {}

    def __init__(self, cfg, device=None):
        """Read the default settings from ``cfg``.

        Args:
            cfg: Object exposing ``get(key, default)``. Recognised keys are
                ``REF_CFG`` (list of ``{"mode": ..., "proba": ...}`` entries),
                ``REF_NUM``, ``REF_COLOR``, ``RETURN_DICT`` and ``RETURN_MASK``.
            device: Accepted but not used by this annotator.
        """
        self.ref_cfg = cfg.get('REF_CFG', [{"mode": "first", "proba": 0.1},
                                           {"mode": "last", "proba": 0.1},
                                           {"mode": "firstlast", "proba": 0.1},
                                           {"mode": "random", "proba": 0.1}])
        self.ref_num = cfg.get('REF_NUM', 1)
        self.ref_color = cfg.get('REF_COLOR', 127.5)
        self.return_dict = cfg.get('RETURN_DICT', True)
        self.return_mask = cfg.get('RETURN_MASK', True)

    def forward(self, frames, ref_cfg=None, ref_num=None, return_mask=None, return_dict=None):
        """Select reference frames and blank out the remaining ones.

        Args:
            frames: Sequence of NumPy arrays indexed as ``frame[:, :, 0]``
                (so each frame needs at least three axes).
            ref_cfg: One config entry or a list of entries. One entry is drawn
                at random, weighted by its ``"proba"`` (entries without one
                weigh ``1 / len(ref_cfg)``). Its ``"mode"`` must be one of
                ``"first"``, ``"last"``, ``"firstlast"`` or ``"random"``.
                Defaults to the configured ``REF_CFG``.
            ref_num: Number of frames to keep (per end for ``"firstlast"``).
                Values outside ``[0, len(frames)]`` are clamped to that range.
                Defaults to the configured ``REF_NUM``.
            return_mask: Whether to return masks; defaults to ``RETURN_MASK``.
            return_dict: Whether to return a dict; defaults to ``RETURN_DICT``.

        Returns:
            With ``return_dict``: ``{"frames": [...]}`` plus ``"masks"`` when
            ``return_mask`` is set. Otherwise ``(frames, masks)`` when
            ``return_mask`` is set, or just the list of frames.

        Raises:
            NotImplementedError: If the drawn entry has no supported mode.
        """
        return_mask = return_mask if return_mask is not None else self.return_mask
        return_dict = return_dict if return_dict is not None else self.return_dict
        ref_cfg = ref_cfg if ref_cfg is not None else self.ref_cfg
        ref_cfg = [ref_cfg] if not isinstance(ref_cfg, list) else ref_cfg
        probas = [item['proba'] if 'proba' in item else 1.0 / len(ref_cfg) for item in ref_cfg]
        sel_ref_cfg = random.choices(ref_cfg, weights=probas, k=1)[0]
        mode = sel_ref_cfg['mode'] if 'mode' in sel_ref_cfg else 'original'
        ref_num = int(ref_num) if ref_num is not None else self.ref_num

        frame_num = len(frames)
        # Clamp so every mode behaves the same at the boundaries: a plain
        # ``[-ref_num:]`` slice with ref_num == 0 would select *all* frames,
        # and random.sample() rejects a sample larger than the population.
        ref_num = max(0, min(ref_num, frame_num))
        head = range(ref_num)
        tail = range(frame_num - ref_num, frame_num)
        if mode == "first":
            sel_idx = set(head)
        elif mode == "last":
            sel_idx = set(tail)
        elif mode == "firstlast":
            sel_idx = set(head) | set(tail)
        elif mode == "random":
            sel_idx = set(random.sample(range(frame_num), ref_num))
        else:
            raise NotImplementedError(f"Unsupported reference mode: {mode!r}")

        out_frames, out_masks = [], []
        for idx, frame in enumerate(frames):
            if idx in sel_idx:
                # Reference frame: pass through, mask value 0.
                out_frame = frame
                out_mask = np.zeros_like(frame[:, :, 0])
            else:
                # Non-reference frame: constant colour, mask value 255.
                # ones_like(...) * colour is kept (rather than full_like) so a
                # fractional colour is not truncated to the frame's dtype.
                out_frame = np.ones_like(frame) * self.ref_color
                out_mask = np.ones_like(frame[:, :, 0]) * 255
            out_frames.append(out_frame)
            out_masks.append(out_mask)

        if return_dict:
            ret_data = {"frames": out_frames}
            if return_mask:
                ret_data['masks'] = out_masks
            return ret_data
        if return_mask:
            return out_frames, out_masks
        return out_frames
|
|
|
|
|
|
|
class FrameRefExpandAnnotator:
    """Extend reference frame(s) with constant-colour placeholder frames.

    The given source frames are kept (mask value 0) and ``expand_num``
    placeholder frames (mask value 255) are added after them, before them,
    or between two sets of source frames, depending on the mode.
    """

    para_dict = {}

    def __init__(self, cfg, device=None):
        """Read the default settings from ``cfg``.

        Args:
            cfg: Object exposing ``get(key, default)``. Recognised keys are
                ``REF_COLOR``, ``RETURN_MASK``, ``RETURN_DICT`` and ``MODE``.
            device: Accepted but not used by this annotator.
        """
        self.ref_color = cfg.get('REF_COLOR', 127.5)
        self.return_mask = cfg.get('RETURN_MASK', True)
        self.return_dict = cfg.get('RETURN_DICT', True)
        self.mode = cfg.get('MODE', "firstframe")
        # NOTE: "all" is accepted here but forward() has no branch for it and
        # raises NotImplementedError when it is used.
        assert self.mode in ["firstframe", "lastframe", "firstlastframe", "firstclip", "lastclip", "firstlastclip", "all"]

    def forward(self, image=None, image_2=None, frames=None, frames_2=None, mode=None, expand_num=None, return_mask=None, return_dict=None):
        """Build the expanded frame list and its masks.

        Args:
            image: Source image, or list of images, used by the ``*frame``
                modes. In those modes it replaces whatever ``frames`` holds.
            image_2: Second image (or list), used by ``"firstlastframe"``.
            frames: Source clip (list of arrays) used by the ``*clip`` modes.
            frames_2: Second clip, used by ``"firstlastclip"``.
            mode: Overrides the configured ``MODE``.
            expand_num: Number of placeholder frames to insert; required.
            return_mask: Whether to return masks; defaults to ``RETURN_MASK``.
            return_dict: Whether to return a dict; defaults to ``RETURN_DICT``.

        Returns:
            With ``return_dict``: ``{"frames": [...]}`` plus ``"masks"`` when
            ``return_mask`` is set. Otherwise ``(frames, masks)`` when
            ``return_mask`` is set, or just the list of frames.

        Raises:
            TypeError: If the source frames, ``expand_num`` or (for the
                first-last modes) the second source are missing.
            NotImplementedError: If ``mode`` has no handling branch.
        """
        mode = mode if mode is not None else self.mode
        return_mask = return_mask if return_mask is not None else self.return_mask
        return_dict = return_dict if return_dict is not None else self.return_dict

        if 'frame' in mode:
            # Wrap a single image into a one-element list. The wrap decision
            # must look at ``image`` itself (as the image_2 line does), not at
            # the unrelated ``frames`` argument.
            frames = [image] if image is not None and not isinstance(image, list) else image
            frames_2 = [image_2] if image_2 is not None and not isinstance(image_2, list) else image_2

        if frames is None:
            raise TypeError(f"mode {mode!r} requires source frames, got None")
        if expand_num is None:
            raise TypeError("expand_num must be an integer, got None")

        # The placeholder lists repeat one array object ``expand_num`` times
        # (and likewise for the masks); callers must not mutate them in place.
        expand_frames = [np.ones_like(frames[0]) * self.ref_color] * expand_num
        expand_masks = [np.ones_like(frames[0][:, :, 0]) * 255] * expand_num
        source_frames = frames
        source_masks = [np.zeros_like(frames[0][:, :, 0])] * len(frames)

        if mode in ["firstframe", "firstclip"]:
            out_frames = source_frames + expand_frames
            out_masks = source_masks + expand_masks
        elif mode in ["lastframe", "lastclip"]:
            out_frames = expand_frames + source_frames
            out_masks = expand_masks + source_masks
        elif mode in ["firstlastframe", "firstlastclip"]:
            if frames_2 is None:
                raise TypeError(f"mode {mode!r} requires a second image/clip, got None")
            # Each trailing frame is passed through align_frames together with
            # the first source frame before being appended.
            source_frames_2 = [align_frames(source_frames[0], f2) for f2 in frames_2]
            source_masks_2 = [np.zeros_like(source_frames_2[0][:, :, 0])] * len(frames_2)
            out_frames = source_frames + expand_frames + source_frames_2
            out_masks = source_masks + expand_masks + source_masks_2
        else:
            raise NotImplementedError(f"Unsupported expand mode: {mode!r}")

        if return_dict:
            ret_data = {"frames": out_frames}
            if return_mask:
                ret_data['masks'] = out_masks
            return ret_data
        if return_mask:
            return out_frames, out_masks
        return out_frames
|
|