# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES.  All rights reserved.
#
# This work is made available under the Nvidia Source Code License-NC.
# To view a copy of this license, check out LICENSE.md
import torch


def make_one_hot(cfg, is_inference, data):
    r"""Convert appropriate image data types to one-hot representation.

    Args:
        cfg (obj): Dataset configuration. Must define ``one_hot_num_classes``
            (dict mapping data type to its number of classes) and may define
            ``use_dont_care`` (bool).
        is_inference (bool): Unused; kept so all ops share the same signature.
        data (dict): Dict containing data_type as key, with each value as a
            torch.Tensor of label indices scaled to [0, 1].
    Returns:
        data (dict): Same as the input data, but with one-hot encoding for
            the selected data types.
    """
    assert hasattr(cfg, 'one_hot_num_classes')
    num_classes = getattr(cfg, 'one_hot_num_classes')
    use_dont_care = getattr(cfg, 'use_dont_care', False)
    for data_type, data_type_num_classes in num_classes.items():
        if data_type in data.keys():
            data[data_type] = _encode_onehot(
                data[data_type] * 255.0, data_type_num_classes,
                use_dont_care).float()
    return data
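
# A minimal sketch of the transform above. The 35-class, (1, H, W)
# segmentation map and the cfg attributes are illustrative assumptions,
# not values this module enforces:
#
#   cfg.one_hot_num_classes = {'seg_maps': 35}; cfg.use_dont_care = False
#   data = {'seg_maps': label_indices / 255.0}  # (1, H, W), indices in [0, 34]
#   data = make_one_hot(cfg, False, data)
#   data['seg_maps']                            # becomes a (35, H, W) one-hot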


def concat_labels(cfg, is_inference, data):
    r"""Concatenate the requested label maps into a single 'label' tensor.

    Args:
        cfg (obj): Dataset configuration. Must define ``input_labels`` (list
            of data types to concatenate) and ``type`` (dataset type string).
        is_inference (bool): Unused; kept so all ops share the same signature.
        data (dict): Dict containing data_type as key, with each value as a
            torch.Tensor.
    Returns:
        data (dict): Same as the input data, with the listed data types
            popped and concatenated into data['label'].
    """
    assert hasattr(cfg, 'input_labels')
    input_labels = getattr(cfg, 'input_labels')
    dataset_type = getattr(cfg, 'type')

    # Package output. Image tensors are channels-first (concatenate along
    # dim 0); video tensors carry an extra leading dim, so channels are dim 1.
    labels = []
    for data_type in input_labels:
        label = data.pop(data_type)
        labels.append(label)
    if 'video' not in dataset_type:
        data['label'] = torch.cat(labels, dim=0)
    else:
        data['label'] = torch.cat(labels, dim=1)
    return data


def concat_few_shot_labels(cfg, is_inference, data):
    r"""Concatenate the requested few-shot label maps into 'few_shot_label'.

    Args:
        cfg (obj): Dataset configuration. Must define
            ``input_few_shot_labels`` and ``type`` (dataset type string).
        is_inference (bool): Unused; kept so all ops share the same signature.
        data (dict): Dict containing data_type as key, with each value as a
            torch.Tensor.
    Returns:
        data (dict): Same as the input data, with the listed data types
            popped and concatenated into data['few_shot_label'].
    """
    assert hasattr(cfg, 'input_few_shot_labels')
    input_labels = getattr(cfg, 'input_few_shot_labels')
    dataset_type = getattr(cfg, 'type')

    # Package output. Same dim convention as concat_labels above.
    labels = []
    for data_type in input_labels:
        label = data.pop(data_type)
        labels.append(label)
    if 'video' not in dataset_type:
        data['few_shot_label'] = torch.cat(labels, dim=0)
    else:
        data['few_shot_label'] = torch.cat(labels, dim=1)
    return data


def move_dont_care(cfg, is_inference, data):
    r"""Remap out-of-range label values to the dont care index.

    Labels are stored scaled to [0, 1] (pixel value / 255). Any index that
    falls outside [0, num_classes - 1] is moved to num_classes, the index
    reserved for the dont care class, then scaled back to [0, 1].
    """
    assert hasattr(cfg, 'move_dont_care')
    move_dont_care = getattr(cfg, 'move_dont_care')
    for data_type, data_type_num_classes in move_dont_care.items():
        label_map = data[data_type] * 255.0
        label_map[label_map < 0] = data_type_num_classes
        label_map[label_map >= data_type_num_classes] = data_type_num_classes
        data[data_type] = label_map / 255.0
    return data


def _encode_onehot(label_map, num_classes, use_dont_care):
    r"""Make input one-hot.

    Args:
        label_map (torch.Tensor): (C, H, W) or (N, C, H, W) tensor containing
            label indices. Modified in place.
        num_classes (int): Number of valid labels to expand the tensor to.
        use_dont_care (bool): Keep the extra dont care channel or not?
    Returns:
        output (torch.Tensor): (num_classes, H, W) one-hot tensor, or
            (num_classes + 1, H, W) when use_dont_care is True; batched
            inputs gain a leading batch dimension.
    """
    # All valid labels lie in [0, num_classes - 1].
    # Encode dont care as num_classes.
    label_map[label_map < 0] = num_classes
    label_map[label_map >= num_classes] = num_classes

    size = label_map.size()
    # Reserve one extra channel for the dont care class.
    output_size = (num_classes + 1, size[-2], size[-1])
    output = torch.zeros(*output_size)
    if label_map.dim() == 4:
        # Batched input: tile the zero tensor and scatter over the channel dim.
        output = output.unsqueeze(0).repeat(label_map.size(0), 1, 1, 1)
        output = output.scatter_(1, label_map.data.long(), 1.0)
        if not use_dont_care:
            # Drop the dont care channel.
            output = output[:, :num_classes, ...]
    else:
        output = output.scatter_(0, label_map.data.long(), 1.0)
        if not use_dont_care:
            output = output[:num_classes, ...]
    return output
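

if __name__ == '__main__':
    # Quick, self-contained usage sketch. The real pipeline passes a config
    # node parsed from the dataset YAML; SimpleNamespace and the tiny 4-class
    # 2x2 map below are stand-ins so this sketch runs on its own.
    from types import SimpleNamespace

    # (1, H, W) map of integer class indices stored as floats.
    seg = torch.tensor([[[0., 1.], [2., 3.]]])
    one_hot = _encode_onehot(seg.clone(), num_classes=4, use_dont_care=False)
    print(one_hot.shape)  # torch.Size([4, 2, 2])

    # Concatenate the requested label maps into data['label'].
    cfg = SimpleNamespace(type='image', input_labels=['seg_maps'])
    data = concat_labels(cfg, False, {'seg_maps': one_hot})
    print(data['label'].shape)  # torch.Size([4, 2, 2])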