File size: 5,127 Bytes
690f890
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
# -*- coding: utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.
import random
import numpy as np
from .utils import align_frames


class FrameRefExtractAnnotator:
    """Keep a subset of reference frames and blank out all the others.

    Selected frames pass through unchanged with an all-0 mask; every other
    frame is replaced by a constant ``ref_color`` image with an all-255 mask.
    """
    para_dict = {}

    def __init__(self, cfg, device=None):
        # Selection modes: first / last / firstlast / random. Each entry's
        # ``proba`` is used as a relative weight by ``random.choices`` in
        # ``forward``.
        self.ref_cfg = cfg.get('REF_CFG', [{"mode": "first", "proba": 0.1},
                                           {"mode": "last", "proba": 0.1},
                                           {"mode": "firstlast", "proba": 0.1},
                                           {"mode": "random", "proba": 0.1}])
        self.ref_num = cfg.get('REF_NUM', 1)
        self.ref_color = cfg.get('REF_COLOR', 127.5)
        self.return_dict = cfg.get('RETURN_DICT', True)
        self.return_mask = cfg.get('RETURN_MASK', True)

    def forward(self, frames, ref_cfg=None, ref_num=None, return_mask=None, return_dict=None):
        """Select reference frames and blank out all others.

        Args:
            frames: Sequence of frames; each is indexed as ``frame[:, :, 0]``
                to build its mask, so at least 3-D arrays are expected.
            ref_cfg: A selection-config dict, or a list of them, with a
                ``mode`` key and an optional ``proba`` key. Defaults to
                ``self.ref_cfg``. One entry is drawn at random, weighted by
                ``proba`` (entries without ``proba`` get ``1 / len(ref_cfg)``).
            ref_num: Frames to keep at each end (``first``/``last``/
                ``firstlast``) or in total (``random``). Clamped to
                ``[0, len(frames)]``. Defaults to ``self.ref_num``.
            return_mask: Also return masks; defaults to ``self.return_mask``.
            return_dict: Return a dict instead of a tuple/list; defaults to
                ``self.return_dict``.

        Returns:
            If ``return_dict``: ``{"frames": [...]}`` plus ``"masks"`` when
            ``return_mask``. Otherwise ``(frames, masks)`` or just ``frames``.
            Kept frames get an all-0 mask; blanked frames are filled with
            ``ref_color`` and get an all-255 mask.

        Raises:
            NotImplementedError: if the drawn config's ``mode`` is unknown
                (a config without ``mode`` counts as ``"original"``).
        """
        return_mask = return_mask if return_mask is not None else self.return_mask
        return_dict = return_dict if return_dict is not None else self.return_dict
        ref_cfg = ref_cfg if ref_cfg is not None else self.ref_cfg
        ref_cfg = [ref_cfg] if not isinstance(ref_cfg, list) else ref_cfg
        probas = [item['proba'] if 'proba' in item else 1.0 / len(ref_cfg) for item in ref_cfg]
        sel_ref_cfg = random.choices(ref_cfg, weights=probas, k=1)[0]
        mode = sel_ref_cfg['mode'] if 'mode' in sel_ref_cfg else 'original'
        ref_num = ref_num if ref_num is not None else self.ref_num

        frame_num = len(frames)
        # Clamp to [0, frame_num]: a "[-0:]" slice would otherwise select
        # *every* frame when ref_num == 0, and random.sample raises
        # ValueError when asked for more items than exist.
        ref_num = max(0, min(int(ref_num), frame_num))
        frame_num_range = list(range(frame_num))
        if mode == "first":
            sel_idx = frame_num_range[:ref_num]
        elif mode == "last":
            sel_idx = frame_num_range[frame_num - ref_num:]
        elif mode == "firstlast":
            sel_idx = frame_num_range[:ref_num] + frame_num_range[frame_num - ref_num:]
        elif mode == "random":
            sel_idx = random.sample(frame_num_range, ref_num)
        else:
            raise NotImplementedError(f"Unsupported reference mode: {mode!r}")
        # Set gives O(1) membership tests in the loop below.
        sel_idx = set(sel_idx)

        out_frames, out_masks = [], []
        for i in range(frame_num):
            if i in sel_idx:
                out_frame = frames[i]
                out_mask = np.zeros_like(frames[i][:, :, 0])
            else:
                # NOTE: multiplying by a float ref_color promotes the blank
                # frame to a float dtype even if the input is an integer type.
                out_frame = np.ones_like(frames[i]) * self.ref_color
                out_mask = np.ones_like(frames[i][:, :, 0]) * 255
            out_frames.append(out_frame)
            out_masks.append(out_mask)

        if return_dict:
            ret_data = {"frames": out_frames}
            if return_mask:
                ret_data['masks'] = out_masks
            return ret_data
        else:
            if return_mask:
                return out_frames, out_masks
            else:
                return out_frames



class FrameRefExpandAnnotator:
    """Pad source frame(s) with blank frames to build a longer sequence.

    Source frames keep an all-0 mask; the ``expand_num`` padding frames are
    filled with ``ref_color`` and carry an all-255 mask. ``mode`` decides
    where the padding goes relative to the source frames.
    """
    para_dict = {}

    def __init__(self, cfg, device=None):
        # Modes: firstframe / lastframe / firstlastframe take their source
        # from ``image``/``image_2``; firstclip / lastclip / firstlastclip
        # take it from ``frames``/``frames_2``. "all" is accepted here but
        # ``forward`` does not implement it.
        self.ref_color = cfg.get('REF_COLOR', 127.5)
        self.return_mask = cfg.get('RETURN_MASK', True)
        self.return_dict = cfg.get('RETURN_DICT', True)
        self.mode = cfg.get('MODE', "firstframe")
        assert self.mode in ["firstframe", "lastframe", "firstlastframe", "firstclip", "lastclip", "firstlastclip", "all"]

    def forward(self, image=None, image_2=None, frames=None, frames_2=None, mode=None, expand_num=None, return_mask=None, return_dict=None):
        """Build the expanded frame sequence and its masks.

        Args:
            image: Source image, or list of images, for ``*frame`` modes.
                When None, ``frames`` is used as-is.
            image_2: Second source image (or list) for ``firstlastframe``.
                When None, ``frames_2`` is used as-is.
            frames: Source frame list for ``*clip`` modes.
            frames_2: Second source frame list for ``firstlastclip``.
            mode: Overrides ``self.mode``.
            expand_num: Number of blank padding frames to insert (required).
            return_mask: Also return masks; defaults to ``self.return_mask``.
            return_dict: Return a dict instead of a tuple/list; defaults to
                ``self.return_dict``.

        Returns:
            If ``return_dict``: ``{"frames": [...]}`` plus ``"masks"`` when
            ``return_mask``. Otherwise ``(frames, masks)`` or just ``frames``.

        Raises:
            TypeError: if ``expand_num`` is not provided.
            NotImplementedError: for a mode without an implementation
                (including ``"all"``).
        """
        mode = mode if mode is not None else self.mode
        return_mask = return_mask if return_mask is not None else self.return_mask
        return_dict = return_dict if return_dict is not None else self.return_dict
        if expand_num is None:
            raise TypeError("expand_num is required")
        expand_num = int(expand_num)

        if 'frame' in mode:
            # Wrap a single image into a list; a list of images is used as-is.
            if image is not None:
                frames = image if isinstance(image, list) else [image]
            if image_2 is not None:
                frames_2 = image_2 if isinstance(image_2, list) else [image_2]

        # NOTE: list repetition makes every padding frame/mask (and every
        # source mask) the *same* array object; copy before mutating in place.
        expand_frames = [np.ones_like(frames[0]) * self.ref_color] * expand_num
        expand_masks = [np.ones_like(frames[0][:, :, 0]) * 255] * expand_num
        source_frames = frames
        source_masks = [np.zeros_like(frames[0][:, :, 0])] * len(frames)

        if mode in ["firstframe", "firstclip"]:
            out_frames = source_frames + expand_frames
            out_masks = source_masks + expand_masks
        elif mode in ["lastframe", "lastclip"]:
            out_frames = expand_frames + source_frames
            out_masks = expand_masks + source_masks
        elif mode in ["firstlastframe", "firstlastclip"]:
            # align_frames (from .utils) is applied to each second-source
            # frame against the first source frame -- presumably to match
            # its size/format; confirm against utils.
            source_frames_2 = [align_frames(source_frames[0], f2) for f2 in frames_2]
            source_masks_2 = [np.zeros_like(source_frames_2[0][:, :, 0])] * len(frames_2)
            out_frames = source_frames + expand_frames + source_frames_2
            out_masks = source_masks + expand_masks + source_masks_2
        else:
            raise NotImplementedError(f"Unsupported expand mode: {mode!r}")

        if return_dict:
            ret_data = {"frames": out_frames}
            if return_mask:
                ret_data['masks'] = out_masks
            return ret_data
        else:
            if return_mask:
                return out_frames, out_masks
            else:
                return out_frames