# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# This work is made available under the Nvidia Source Code License-NC.
# To view a copy of this license, check out LICENSE.md

import random
import tempfile
import warnings
from collections import OrderedDict

import cv2
import numpy as np
import torch
# import torchvision.io as io
from PIL import Image

from imaginaire.datasets.base import BaseDataset


class Dataset(BaseDataset):
    r"""Dataset for paired few-shot videos.

    Args:
        cfg (Config): Loaded config object.
        is_inference (bool): In train or inference mode?
    """

    def __init__(self, cfg, is_inference=False, is_test=False):
        self.paired = True
        super(Dataset, self).__init__(cfg, is_inference, is_test)
        self.is_video_dataset = True
        # Number of few-shot source frames per sample; one additional
        # frame is drawn as the driving frame.
        self.few_shot_K = 1
        self.first_last_only = getattr(cfg.data, 'first_last_only', False)
        self.sample_far_frames_more = getattr(
            cfg.data, 'sample_far_frames_more', False)
def get_label_lengths(self):
r"""Get num channels of all labels to be concated.
Returns:
label_lengths (OrderedDict): Dict mapping image data_type to num
channels.
"""
label_lengths = OrderedDict()
for data_type in self.input_labels:
data_cfg = self.cfgdata
if hasattr(data_cfg, 'one_hot_num_classes') and \
data_type in data_cfg.one_hot_num_classes:
label_lengths[data_type] = data_cfg.one_hot_num_classes[data_type]
if getattr(data_cfg, 'use_dont_care', False):
label_lengths[data_type] += 1
else:
label_lengths[data_type] = self.num_channels[data_type]
return label_lengths
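    # Illustrative example (hypothetical config values): with
    # input_labels = ['seg_maps'] and one_hot_num_classes.seg_maps = 35,
    # this returns OrderedDict([('seg_maps', 35)]), or 36 channels when
    # use_dont_care is set.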
def num_inference_sequences(self):
r"""Number of sequences available for inference.
Returns:
(int)
"""
assert self.is_inference
return len(self.mapping)
def _create_mapping(self):
r"""Creates mapping from idx to key in LMDB.
Returns:
(tuple):
- self.mapping (dict): Dict of seq_len to list of sequences.
- self.epoch_length (int): Number of samples in an epoch.
"""
# Create dict mapping length to sequence.
mapping = []
for lmdb_idx, sequence_list in enumerate(self.sequence_lists):
for sequence_name, filenames in sequence_list.items():
for filename in filenames:
# This file is corrupt.
if filename == 'z-KziTO_5so_0019_start0_end85_h596_w596':
continue
mapping.append({
'lmdb_root': self.lmdb_roots[lmdb_idx],
'lmdb_idx': lmdb_idx,
'sequence_name': sequence_name,
'filenames': [filename],
})
self.mapping = mapping
self.epoch_length = len(mapping)
return self.mapping, self.epoch_length
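    # Each mapping entry is a flat record for a single clip, e.g.
    # (illustrative values only):
    #   {'lmdb_root': 'datasets/train', 'lmdb_idx': 0,
    #    'sequence_name': 'person_001', 'filenames': ['clip_0000']}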
def _sample_keys(self, index):
r"""Gets files to load for this sample.
Args:
index (int): Index in [0, len(dataset)].
Returns:
(tuple):
- key (dict):
- lmdb_idx (int): Chosen LMDB dataset root.
- sequence_name (str): Chosen sequence in chosen dataset.
- filenames (list of str): Chosen filenames in chosen sequence.
"""
if self.is_inference:
assert index < self.epoch_length
raise NotImplementedError
else:
# Select a video at random.
key = random.choice(self.mapping)
return key
def _create_sequence_keys(self, sequence_name, filenames):
r"""Create the LMDB key for this piece of information.
Args:
sequence_name (str): Which sequence from the chosen dataset.
filenames (list of str): List of filenames in this sequence.
Returns:
keys (list): List of full keys.
"""
assert isinstance(filenames, list), 'Filenames should be a list.'
keys = []
for filename in filenames:
keys.append('%s/%s' % (sequence_name, filename))
return keys
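    # Illustrative example (hypothetical names):
    #   _create_sequence_keys('person_001', ['clip_0000'])
    #   -> ['person_001/clip_0000']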
def _getitem(self, index):
r"""Gets selected files.
Args:
index (int): Index into dataset.
concat (bool): Concatenate all items in labels?
Returns:
data (dict): Dict with all chosen data_types.
"""
# Select a sample from the available data.
keys = self._sample_keys(index)
# Unpack keys.
lmdb_idx = keys['lmdb_idx']
sequence_name = keys['sequence_name']
filenames = keys['filenames']
# Get key and lmdbs.
keys, lmdbs = {}, {}
for data_type in self.dataset_data_types:
keys[data_type] = self._create_sequence_keys(
sequence_name, filenames)
lmdbs[data_type] = self.lmdbs[data_type][lmdb_idx]
# Load all data for this index.
data = self.load_from_dataset(keys, lmdbs)
# Get frames from video.
try:
temp = tempfile.NamedTemporaryFile()
temp.write(data['videos'][0])
temp.seek(0)
with warnings.catch_warnings():
warnings.simplefilter("ignore")
# frames, _, info = io.read_video(temp)
# num_frames = frames.size(0)
cap = cv2.VideoCapture(temp.name)
num_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
if self.first_last_only:
chosen_idxs = [0, num_frames - 1]
else:
# chosen_idxs = random.sample(range(frames.size(0)), 2)
chosen_idx = random.sample(range(num_frames), 1)[0]
few_shot_choose_range = list(range(chosen_idx)) + list(range(chosen_idx + 1, num_frames))
if self.sample_far_frames_more:
choose_weight = list(reversed(range(chosen_idx))) + list(range(num_frames - chosen_idx - 1))
few_shot_idx = random.choices(few_shot_choose_range, choose_weight, k=self.few_shot_K)
else:
few_shot_idx = random.sample(few_shot_choose_range, k=self.few_shot_K)
chosen_idxs = few_shot_idx + [chosen_idx]
chosen_images = []
for idx in chosen_idxs:
# chosen_images.append(Image.fromarray(frames[idx].numpy()))
cap.set(1, idx)
_, frame = cap.read()
chosen_images.append(Image.fromarray(frame[:, :, ::-1]))
except Exception:
print('Issue with file:', sequence_name, filenames)
blank = np.zeros((512, 512, 3), dtype=np.uint8)
chosen_images = [Image.fromarray(blank), Image.fromarray(blank)]
data['videos'] = chosen_images
        # Apply ops pre augmentation.
        data = self.apply_ops(data, self.pre_aug_ops)

        # Do augmentations for images.
        data, is_flipped = self.perform_augmentation(
            data, paired=True, augment_ops=self.augmentor.augment_ops)

        # Individual video frame augmentation is used in face-vid2vid.
        data = self.perform_individual_video_frame(
            data, self.augmentor.individual_video_frame_augmentation_ops)

        # Apply ops post augmentation.
        data = self.apply_ops(data, self.post_aug_ops)

        # Convert images to tensor.
        data = self.to_tensor(data)

        # Pack the sequence of images into a single (T, C, H, W) tensor.
        for data_type in self.image_data_types:
            for idx in range(len(data[data_type])):
                data[data_type][idx] = data[data_type][idx].unsqueeze(0)
            data[data_type] = torch.cat(data[data_type], dim=0)
        if not self.is_video_dataset:
            # Remove any extra dimensions.
            for data_type in self.image_data_types:
                if data_type in data:
                    data[data_type] = data[data_type].squeeze(0)

        # Prepare output: the first few_shot_K frames are the source
        # frames, the remaining frame is the driving frame.
        data['driving_images'] = data['videos'][self.few_shot_K:]
        data['source_images'] = data['videos'][:self.few_shot_K]
        data.pop('videos')
        data['is_flipped'] = is_flipped
        data['key'] = keys
        data['original_h_w'] = torch.IntTensor([
            self.augmentor.original_h, self.augmentor.original_w])

        # Apply full data ops.
        data = self.apply_ops(data, self.full_data_ops, full_data=True)
        return data
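    # After batching with a DataLoader, 'source_images' should have shape
    # (B, few_shot_K, C, H, W) and 'driving_images' (B, 1, C, H, W); these
    # shapes are inferred from the packing above, not verified.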
def __getitem__(self, index):
return self._getitem(index)
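

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative, not part of the original module).
# Assumes a standard imaginaire YAML config loaded via
# imaginaire.config.Config; the config path below is hypothetical and must
# point at a real dataset config to run.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    from torch.utils.data import DataLoader
    from imaginaire.config import Config

    cfg = Config('configs/my_face_vid2vid.yaml')  # hypothetical path
    dataset = Dataset(cfg, is_inference=False)
    loader = DataLoader(dataset, batch_size=1, shuffle=True)
    batch = next(iter(loader))
    print(batch['source_images'].shape)   # expected (1, few_shot_K, C, H, W)
    print(batch['driving_images'].shape)  # expected (1, 1, C, H, W)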