# File: vace-demo/vace/annotators/frameref.py
# Repository page header captured with this file:
#   "maffia's picture" / "Upload 94 files" / "690f890 verified"
# -*- coding: utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.
import random
import numpy as np
from .utils import align_frames
class FrameRefExtractAnnotator:
    """Keep a few reference frames of a clip and blank out all the others.

    Selected frames are passed through unchanged and get an all-zero mask.
    Every other frame is replaced by a constant image filled with
    ``REF_COLOR`` and gets a mask filled with 255.
    """
    para_dict = {}

    def __init__(self, cfg, device=None):
        """Read the annotator defaults from ``cfg``.

        Args:
            cfg: Mapping-like object providing ``get``. Recognised keys are
                ``REF_CFG``, ``REF_NUM``, ``REF_COLOR``, ``RETURN_DICT`` and
                ``RETURN_MASK``.
            device: Accepted for interface compatibility; not used here.
        """
        # Supported selection modes: first / last / firstlast / random.
        # Each entry may carry a sampling weight under "proba".
        self.ref_cfg = cfg.get('REF_CFG', [{"mode": "first", "proba": 0.1},
                                           {"mode": "last", "proba": 0.1},
                                           {"mode": "firstlast", "proba": 0.1},
                                           {"mode": "random", "proba": 0.1}])
        self.ref_num = cfg.get('REF_NUM', 1)
        self.ref_color = cfg.get('REF_COLOR', 127.5)
        self.return_dict = cfg.get('RETURN_DICT', True)
        self.return_mask = cfg.get('RETURN_MASK', True)

    def forward(self, frames, ref_cfg=None, ref_num=None, return_mask=None, return_dict=None):
        """Pick reference frames and build the masked frame sequence.

        Args:
            frames: Sequence of array frames. Each frame is indexed as
                ``frame[:, :, 0]``, so it needs at least three dimensions
                (presumably H x W x C -- confirm against callers).
            ref_cfg: One mode config (e.g. ``{"mode": "first"}``) or a list
                of them; one entry is drawn at random, weighted by its
                ``"proba"`` (entries without it weigh ``1 / len(ref_cfg)``).
                Defaults to the configured ``REF_CFG``.
            ref_num: Number of frames to keep per side / in total for the
                chosen mode. Defaults to the configured ``REF_NUM``. Values
                larger than ``len(frames)`` are clamped to ``len(frames)``.
            return_mask: Whether to return masks as well; defaults to the
                configured ``RETURN_MASK``.
            return_dict: Whether to return a dict instead of a tuple/list;
                defaults to the configured ``RETURN_DICT``.

        Returns:
            ``{"frames": [...], "masks": [...]}`` (``"masks"`` only when
            ``return_mask``) if ``return_dict``; otherwise
            ``(frames, masks)`` or just ``frames``.

        Raises:
            ValueError: If ``ref_num`` is negative.
            NotImplementedError: If the drawn config has no known mode.
        """
        return_mask = return_mask if return_mask is not None else self.return_mask
        return_dict = return_dict if return_dict is not None else self.return_dict
        ref_cfg = ref_cfg if ref_cfg is not None else self.ref_cfg
        ref_cfg = [ref_cfg] if not isinstance(ref_cfg, list) else ref_cfg

        probas = [item['proba'] if 'proba' in item else 1.0 / len(ref_cfg) for item in ref_cfg]
        sel_ref_cfg = random.choices(ref_cfg, weights=probas, k=1)[0]
        mode = sel_ref_cfg['mode'] if 'mode' in sel_ref_cfg else 'original'

        ref_num = int(ref_num) if ref_num is not None else self.ref_num
        if ref_num < 0:
            raise ValueError(f"ref_num must be non-negative, got {ref_num}")
        frame_num = len(frames)
        # Clamp so that "random" behaves like the slice-based modes instead of
        # crashing in random.sample when more frames are requested than exist.
        ref_num = min(ref_num, frame_num)

        # Index ranges are built from explicit bounds rather than negative
        # slices: seq[-0:] is the whole sequence, which would wrongly select
        # every frame for ref_num == 0.
        head = range(ref_num)
        tail = range(frame_num - ref_num, frame_num)
        if mode == "first":
            sel_idx = set(head)
        elif mode == "last":
            sel_idx = set(tail)
        elif mode == "firstlast":
            sel_idx = set(head) | set(tail)
        elif mode == "random":
            sel_idx = set(random.sample(range(frame_num), ref_num))
        else:
            raise NotImplementedError

        out_frames, out_masks = [], []
        for i, frame in enumerate(frames):
            if i in sel_idx:
                # Reference frame: keep the pixels, mask value 0.
                out_frame = frame
                out_mask = np.zeros_like(frame[:, :, 0])
            else:
                # Non-reference frame: constant fill, mask value 255.
                out_frame = np.ones_like(frame) * self.ref_color
                out_mask = np.ones_like(frame[:, :, 0]) * 255
            out_frames.append(out_frame)
            out_masks.append(out_mask)

        if return_dict:
            ret_data = {"frames": out_frames}
            if return_mask:
                ret_data['masks'] = out_masks
            return ret_data
        if return_mask:
            return out_frames, out_masks
        return out_frames
class FrameRefExpandAnnotator:
    """Extend reference image(s) or clip(s) with blank placeholder frames.

    The given source frames are kept (mask value 0) and ``expand_num``
    constant frames filled with ``REF_COLOR`` (mask value 255) are placed
    after them, before them, or between two sources, depending on the mode.
    """
    para_dict = {}

    def __init__(self, cfg, device=None):
        """Read the annotator defaults from ``cfg``.

        Args:
            cfg: Mapping-like object providing ``get``. Recognised keys are
                ``REF_COLOR``, ``RETURN_MASK``, ``RETURN_DICT`` and ``MODE``.
            device: Accepted for interface compatibility; not used here.
        """
        # Accepted modes: firstframe / lastframe / firstlastframe /
        # firstclip / lastclip / firstlastclip / all
        self.ref_color = cfg.get('REF_COLOR', 127.5)
        self.return_mask = cfg.get('RETURN_MASK', True)
        self.return_dict = cfg.get('RETURN_DICT', True)
        self.mode = cfg.get('MODE', "firstframe")
        assert self.mode in ["firstframe", "lastframe", "firstlastframe", "firstclip", "lastclip", "firstlastclip", "all"]

    def forward(self, image=None, image_2=None, frames=None, frames_2=None, mode=None, expand_num=None, return_mask=None, return_dict=None):
        """Build the expanded frame sequence and its masks.

        Args:
            image: Single array image or list of images; used as the source
                when the mode name contains ``"frame"``.
            image_2: Second image (or list), used by ``"firstlastframe"``.
            frames: List of source frames for the ``*clip`` modes. Ignored
                (overwritten from ``image``) in the ``*frame`` modes.
            frames_2: Second list of frames, used by ``"firstlastclip"``.
            mode: Overrides the configured ``MODE`` for this call.
            expand_num: Number of placeholder frames to insert. Required.
            return_mask: Whether to return masks as well; defaults to the
                configured ``RETURN_MASK``.
            return_dict: Whether to return a dict instead of a tuple/list;
                defaults to the configured ``RETURN_DICT``.

        Returns:
            ``{"frames": [...], "masks": [...]}`` (``"masks"`` only when
            ``return_mask``) if ``return_dict``; otherwise
            ``(frames, masks)`` or just ``frames``.

        Raises:
            TypeError: If ``expand_num`` is ``None``.
            NotImplementedError: For modes without a branch below, which
                includes ``"all"`` even though the constructor accepts it.
        """
        mode = mode if mode is not None else self.mode
        return_mask = return_mask if return_mask is not None else self.return_mask
        return_dict = return_dict if return_dict is not None else self.return_dict

        if 'frame' in mode:
            # Wrap a bare image into a one-element list; pass lists through.
            # The check must look at `image` itself (as the image_2 line does),
            # otherwise a list of images gets wrapped a second time.
            frames = [image] if image is not None and not isinstance(image, list) else image
            frames_2 = [image_2] if image_2 is not None and not isinstance(image_2, list) else image_2

        if expand_num is None:
            raise TypeError("expand_num must be an integer, got None")

        # NOTE: list multiplication repeats references, so all placeholder
        # entries (and all source-mask entries) share one underlying array.
        expand_frames = [np.ones_like(frames[0]) * self.ref_color] * expand_num
        expand_masks = [np.ones_like(frames[0][:, :, 0]) * 255] * expand_num
        source_frames = frames
        source_masks = [np.zeros_like(frames[0][:, :, 0])] * len(frames)

        if mode in ["firstframe", "firstclip"]:
            # Source first, placeholders afterwards.
            out_frames = source_frames + expand_frames
            out_masks = source_masks + expand_masks
        elif mode in ["lastframe", "lastclip"]:
            # Placeholders first, source at the end.
            out_frames = expand_frames + source_frames
            out_masks = expand_masks + source_masks
        elif mode in ["firstlastframe", "firstlastclip"]:
            # Second source is passed through align_frames together with the
            # first source frame (presumably to match its size -- confirm in
            # .utils), then appended after the placeholders.
            source_frames_2 = [align_frames(source_frames[0], f2) for f2 in frames_2]
            source_masks_2 = [np.zeros_like(source_frames_2[0][:, :, 0])] * len(frames_2)
            out_frames = source_frames + expand_frames + source_frames_2
            out_masks = source_masks + expand_masks + source_masks_2
        else:
            raise NotImplementedError

        if return_dict:
            ret_data = {"frames": out_frames}
            if return_mask:
                ret_data['masks'] = out_masks
            return ret_data
        if return_mask:
            return out_frames, out_masks
        return out_frames