# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# This work is made available under the Nvidia Source Code License-NC.
# To view a copy of this license, check out LICENSE.md
import torch


def make_one_hot(cfg, is_inference, data):
    r"""Convert appropriate image data types to one-hot representation.

    Args:
        cfg (obj): Data configuration; must provide ``one_hot_num_classes``
            (dict mapping data types to class counts) and may provide
            ``use_dont_care`` (bool, default False).
        is_inference (bool): Unused by this transform.
        data (dict): Dict containing data_type as key, with each value
            as a torch.Tensor.

    Returns:
        data (dict): Same as input data, but with one-hot for selected
            types.
    """
    assert hasattr(cfg, 'one_hot_num_classes')
    num_classes = getattr(cfg, 'one_hot_num_classes')
    use_dont_care = getattr(cfg, 'use_dont_care', False)
    for data_type, data_type_num_classes in num_classes.items():
        if data_type in data:
            # Labels are stored normalized to [0, 1]; scale by 255 to
            # recover integer class indices before encoding.
            data[data_type] = _encode_onehot(
                data[data_type] * 255.0, data_type_num_classes,
                use_dont_care).float()
    return data
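

# Illustrative usage sketch (not part of the original file); the cfg fields
# and the 'seg_maps' key are hypothetical:
#
#   cfg.one_hot_num_classes = {'seg_maps': 35}
#   data = {'seg_maps': torch.randint(0, 35, (1, 8, 8)).float() / 255.0}
#   data = make_one_hot(cfg, is_inference=False, data=data)
#   assert data['seg_maps'].shape == (35, 8, 8)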


def concat_labels(cfg, is_inference, data):
    r"""Concatenate all ``cfg.input_labels`` maps into one 'label' tensor.

    Consumed types are popped from ``data``; channels are concatenated at
    dim 0 for image datasets and at dim 1 for video datasets, which carry
    a leading sequence dimension.
    """
    assert hasattr(cfg, 'input_labels')
    input_labels = getattr(cfg, 'input_labels')
    dataset_type = getattr(cfg, 'type')
    # Package output.
    labels = []
    for data_type in input_labels:
        label = data.pop(data_type)
        labels.append(label)
    if 'video' not in dataset_type:
        data['label'] = torch.cat(labels, dim=0)
    else:
        data['label'] = torch.cat(labels, dim=1)
    return data
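

# Illustrative usage sketch (not part of the original file); the cfg fields,
# dataset type string, and data keys are hypothetical:
#
#   cfg.input_labels = ['seg_maps', 'edge_maps']
#   cfg.type = 'imaginaire.datasets.paired_images'
#   data = {'seg_maps': torch.zeros(35, 8, 8),
#           'edge_maps': torch.zeros(1, 8, 8)}
#   data = concat_labels(cfg, is_inference=False, data=data)
#   assert data['label'].shape == (36, 8, 8)  # dim=0 for non-video types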


def concat_few_shot_labels(cfg, is_inference, data):
    r"""Concatenate few-shot label maps into one 'few_shot_label' tensor.

    Mirrors :func:`concat_labels`, but reads ``input_few_shot_labels``
    and writes the result to the 'few_shot_label' key.
    """
    assert hasattr(cfg, 'input_few_shot_labels')
    input_labels = getattr(cfg, 'input_few_shot_labels')
    dataset_type = getattr(cfg, 'type')
    # Package output.
    labels = []
    for data_type in input_labels:
        label = data.pop(data_type)
        labels.append(label)
    if 'video' not in dataset_type:
        data['few_shot_label'] = torch.cat(labels, dim=0)
    else:
        data['few_shot_label'] = torch.cat(labels, dim=1)
    return data


def move_dont_care(cfg, is_inference, data):
    r"""Remap out-of-range label values to the don't-care index.

    For each type in ``cfg.move_dont_care``, values below 0 or at/above
    num_classes are set to num_classes, keeping the [0, 1] normalization.
    """
    assert hasattr(cfg, 'move_dont_care')
    move_dont_care = getattr(cfg, 'move_dont_care')
    for data_type, data_type_num_classes in move_dont_care.items():
        # Recover integer indices, clamp invalid values to num_classes,
        # then restore the [0, 1] normalization.
        label_map = data[data_type] * 255.0
        label_map[label_map < 0] = data_type_num_classes
        label_map[label_map >= data_type_num_classes] = data_type_num_classes
        data[data_type] = label_map / 255.0
    return data
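

# Illustrative sketch (not part of the original file); the cfg field and
# 'seg_maps' key are hypothetical:
#
#   cfg.move_dont_care = {'seg_maps': 35}
#   data = {'seg_maps': torch.tensor([[[-1.0, 40.0, 10.0]]]) / 255.0}
#   data = move_dont_care(cfg, is_inference=False, data=data)
#   # data['seg_maps'] * 255 -> approximately [[[35., 35., 10.]]]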


def _encode_onehot(label_map, num_classes, use_dont_care):
    r"""Make input one-hot.

    Args:
        label_map (torch.Tensor): (C, H, W) or (N, C, H, W) tensor of
            class indices, with C = 1.
        num_classes (int): Number of labels to expand tensor to.
        use_dont_care (bool): Keep the don't-care channel or not?

    Returns:
        output (torch.Tensor): (num_classes, H, W) or
            (N, num_classes, H, W) one-hot tensor; one extra don't-care
            channel is kept when use_dont_care is True.
    """
    # All valid labels lie in [0, num_classes - 1].
    # Encode don't care as num_classes.
    label_map[label_map < 0] = num_classes
    label_map[label_map >= num_classes] = num_classes
    size = label_map.size()
    output_size = (num_classes + 1, size[-2], size[-1])
    output = torch.zeros(*output_size)
    if label_map.dim() == 4:
        # Batched input: replicate along the batch dimension and scatter
        # over the channel dimension.
        output = output.unsqueeze(0).repeat(label_map.size(0), 1, 1, 1)
        output = output.scatter_(1, label_map.data.long(), 1.0)
        if not use_dont_care:
            output = output[:, :num_classes, ...]
    else:
        output = output.scatter_(0, label_map.data.long(), 1.0)
        if not use_dont_care:
            output = output[:num_classes, ...]
    return output
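

if __name__ == '__main__':
    # Minimal smoke test (not part of the original file): encode a 2x2 label
    # map with 3 classes and one don't-care pixel (-1).
    label_map = torch.tensor([[[0., 1.], [2., -1.]]])  # (C=1, H=2, W=2)
    onehot = _encode_onehot(label_map.clone(), num_classes=3,
                            use_dont_care=False)
    # Shape is (3, 2, 2); the don't-care pixel is all-zero because its
    # channel (index 3) is sliced away when use_dont_care is False.
    print(onehot.shape, onehot[:, 1, 1])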