# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# This work is made available under the Nvidia Source Code License-NC.
# To view a copy of this license, check out LICENSE.md
"""Utils for the few shot vid2vid model."""
import random
import numpy as np
import torch
import torch.nn.functional as F
def resample(image, flow):
r"""Resamples an image using the provided flow.
Args:
image (NxCxHxW tensor) : Image to resample.
flow (Nx2xHxW tensor) : Optical flow to resample the image.
Returns:
output (NxCxHxW tensor) : Resampled image.
"""
assert flow.shape[1] == 2
b, c, h, w = image.size()
grid = get_grid(b, (h, w))
flow = torch.cat([flow[:, 0:1, :, :] / ((w - 1.0) / 2.0),
flow[:, 1:2, :, :] / ((h - 1.0) / 2.0)], dim=1)
final_grid = (grid + flow).permute(0, 2, 3, 1)
try:
output = F.grid_sample(image, final_grid, mode='bilinear',
padding_mode='border', align_corners=True)
except Exception:
output = F.grid_sample(image, final_grid, mode='bilinear',
padding_mode='border')
return output
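
# A minimal usage sketch for `resample` (shapes and device are assumptions,
# not taken from this repo): a zero flow field leaves the image unchanged up
# to interpolation error, because the sampling grid is not displaced.
#
#     image = torch.randn(1, 3, 64, 64, device='cuda')
#     flow = torch.zeros(1, 2, 64, 64, device='cuda')
#     warped = resample(image, flow)  # approximately equal to `image`
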
def get_grid(batchsize, size, minval=-1.0, maxval=1.0):
r"""Get a grid ranging [-1, 1] of 2D/3D coordinates.
Args:
batchsize (int) : Batch size.
size (tuple) : (height, width) or (depth, height, width).
minval (float) : minimum value in returned grid.
maxval (float) : maximum value in returned grid.
Returns:
        t_grid (4D or 5D tensor) : Grid of coordinates.
"""
if len(size) == 2:
rows, cols = size
elif len(size) == 3:
deps, rows, cols = size
else:
raise ValueError('Dimension can only be 2 or 3.')
x = torch.linspace(minval, maxval, cols)
x = x.view(1, 1, 1, cols)
x = x.expand(batchsize, 1, rows, cols)
y = torch.linspace(minval, maxval, rows)
y = y.view(1, 1, rows, 1)
y = y.expand(batchsize, 1, rows, cols)
t_grid = torch.cat([x, y], dim=1)
if len(size) == 3:
z = torch.linspace(minval, maxval, deps)
z = z.view(1, 1, deps, 1, 1)
z = z.expand(batchsize, 1, deps, rows, cols)
t_grid = t_grid.unsqueeze(2).expand(batchsize, 2, deps, rows, cols)
t_grid = torch.cat([t_grid, z], dim=1)
t_grid.requires_grad = False
return t_grid.to('cuda')
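
# Illustrative shape check for `get_grid` (an assumption, not repo code):
# `get_grid(2, (4, 8))` returns a 2x2x4x8 CUDA tensor whose channel 0 holds x
# coordinates and channel 1 holds y coordinates, both spanning [-1, 1]. Note
# that the returned grid is always placed on the GPU by this function.
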
def pick_image(images, idx):
r"""Pick the image among images according to idx.
Args:
images (B x N x C x H x W tensor or list of tensors) : N images.
idx (B tensor) : indices to select.
Returns:
image (B x C x H x W) : Selected images.
"""
if type(images) == list:
return [pick_image(r, idx) for r in images]
if idx is None:
return images[:, 0]
elif type(idx) == int:
return images[:, idx]
idx = idx.long().view(-1, 1, 1, 1, 1)
image = images.gather(1, idx.expand_as(images)[:, 0:1])[:, 0]
return image
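
# Illustrative use of `pick_image` (shapes and indices are assumptions): given
# N candidate reference images per batch entry, gather one image per entry.
#
#     images = torch.randn(4, 3, 3, 32, 32)  # B=4 entries, N=3 candidates
#     idx = torch.tensor([0, 2, 1, 0])
#     picked = pick_image(images, idx)       # 4 x 3 x 32 x 32
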
def crop_face_from_data(cfg, is_inference, data):
r"""Crop the face regions in input data and resize to the target size.
This is for training face datasets.
Args:
cfg (obj): Data configuration.
is_inference (bool): Is doing inference or not.
data (dict): Input data.
Returns:
data (dict): Cropped data.
"""
label = data['label'] if 'label' in data else None
image = data['images']
landmarks = data['landmarks-dlib68_xy']
ref_labels = data['few_shot_label'] if 'few_shot_label' in data else None
ref_images = data['few_shot_images']
ref_landmarks = data['few_shot_landmarks-dlib68_xy']
img_size = image.shape[-2:]
h, w = cfg.output_h_w.split(',')
h, w = int(h), int(w)
    # When doing inference, we need to sync common attributes like the crop
    # coordinates between different workers, so all workers crop the same
    # region.
if 'common_attr' in data and 'crop_coords' in data['common_attr']:
# Has been computed before, reusing the previous one.
crop_coords, ref_crop_coords = data['common_attr']['crop_coords']
else:
        # This is the first frame, so we need to compute the bbox.
ref_crop_coords, scale = get_face_bbox_for_data(
ref_landmarks[0], img_size, None, is_inference)
crop_coords, _ = get_face_bbox_for_data(
landmarks[0], img_size, scale, is_inference)
# Crop the images according to the bbox and resize them to target size.
label, image = crop_and_resize([label, image], crop_coords, (h, w))
ref_labels, ref_images = crop_and_resize([ref_labels, ref_images],
ref_crop_coords, (h, w))
data['images'], data['few_shot_images'] = image, ref_images
if label is not None:
data['label'], data['few_shot_label'] = label, ref_labels
if is_inference:
if 'common_attr' not in data:
data['common_attr'] = dict()
data['common_attr']['crop_coords'] = crop_coords, ref_crop_coords
return data
def get_face_bbox_for_data(keypoints, orig_img_size, scale, is_inference):
r"""Get the bbox coordinates for face region.
Args:
keypoints (Nx2 tensor): Facial landmarks.
        orig_img_size (int tuple): Height and width of the input image.
scale (float): When training, randomly scale the crop size for
augmentation.
is_inference (bool): Is doing inference or not.
Returns:
crop_coords (list of int): bbox for face region.
        scale (float): Also returns the scale to ensure reference and target
            frames are cropped using the same scale.
"""
min_y, max_y = int(keypoints[:, 1].min()), int(keypoints[:, 1].max())
min_x, max_x = int(keypoints[:, 0].min()), int(keypoints[:, 0].max())
x_cen, y_cen = (min_x + max_x) // 2, (min_y + max_y) // 2
H, W = orig_img_size
w = h = (max_x - min_x)
if not is_inference:
# During training, randomly jitter the cropping position by offset
# amount for augmentation.
offset_max = 0.2
offset = [np.random.uniform(-offset_max, offset_max),
np.random.uniform(-offset_max, offset_max)]
# Also augment the crop size.
if scale is None:
scale_max = 0.2
scale = [np.random.uniform(1 - scale_max, 1 + scale_max),
np.random.uniform(1 - scale_max, 1 + scale_max)]
w *= scale[0]
h *= scale[1]
x_cen += int(offset[0] * w)
y_cen += int(offset[1] * h)
# Get the cropping coordinates.
x_cen = max(w, min(W - w, x_cen))
y_cen = max(h * 1.25, min(H - h * 0.75, y_cen))
min_x = x_cen - w
min_y = y_cen - h * 1.25
max_x = min_x + w * 2
max_y = min_y + h * 2
crop_coords = [min_y, max_y, min_x, max_x]
return [int(x) for x in crop_coords], scale
def crop_person_from_data(cfg, is_inference, data):
r"""Crop the person regions in data and resize to the target size.
This is for training full body datasets.
Args:
cfg (obj): Data configuration.
is_inference (bool): Is doing inference or not.
data (dict): Input data.
Returns:
data (dict): Cropped data.
"""
label = data['label']
image = data['images']
use_few_shot = 'few_shot_label' in data
if use_few_shot:
ref_labels = data['few_shot_label']
ref_images = data['few_shot_images']
img_size = image.shape[-2:]
output_h, output_w = cfg.output_h_w.split(',')
output_h, output_w = int(output_h), int(output_w)
output_aspect_ratio = output_w / output_h
if 'human_instance_maps' in data:
# Remove other people in the DensePose map except for the current
# target.
label = remove_other_ppl(label, data['human_instance_maps'])
if use_few_shot:
ref_labels = remove_other_ppl(ref_labels,
data['few_shot_human_instance_maps'])
# Randomly jitter the crop position by offset amount for augmentation.
offset = ref_offset = None
if not is_inference:
offset = np.random.randn(2) * 0.05
offset = np.minimum(1, np.maximum(-1, offset))
ref_offset = np.random.randn(2) * 0.02
ref_offset = np.minimum(1, np.maximum(-1, ref_offset))
# Randomly scale the crop size for augmentation.
# Final cropped size = person height * scale.
scale = ref_scale = 1.5
if not is_inference:
scale = min(2, max(1, scale + np.random.randn() * 0.05))
ref_scale = min(2, max(1, ref_scale + np.random.randn() * 0.02))
    # When doing inference, we need to sync common attributes like the crop
    # coordinates between different workers, so all workers crop the same
    # region.
if 'common_attr' in data:
# Has been computed before, reusing the previous one.
crop_coords, ref_crop_coords = data['common_attr']['crop_coords']
else:
        # This is the first frame, so we need to compute the bbox.
crop_coords = get_person_bbox_for_data(label, img_size, scale,
output_aspect_ratio, offset)
if use_few_shot:
ref_crop_coords = get_person_bbox_for_data(
ref_labels, img_size, ref_scale,
output_aspect_ratio, ref_offset)
else:
ref_crop_coords = None
# Crop the images according to the bbox and resize them to target size.
label = crop_and_resize(label, crop_coords, (output_h, output_w), 'nearest')
image = crop_and_resize(image, crop_coords, (output_h, output_w))
if use_few_shot:
ref_labels = crop_and_resize(ref_labels, ref_crop_coords,
(output_h, output_w), 'nearest')
ref_images = crop_and_resize(ref_images, ref_crop_coords,
(output_h, output_w))
data['label'], data['images'] = label, image
if use_few_shot:
data['few_shot_label'], data['few_shot_images'] = ref_labels, ref_images
if 'human_instance_maps' in data:
del data['human_instance_maps']
if 'few_shot_human_instance_maps' in data:
del data['few_shot_human_instance_maps']
if is_inference:
data['common_attr'] = dict()
data['common_attr']['crop_coords'] = crop_coords, ref_crop_coords
return data
def get_person_bbox_for_data(pose_map, orig_img_size, scale=1.5,
crop_aspect_ratio=1, offset=None):
r"""Get the bbox (pixel coordinates) to crop for person body region.
Args:
pose_map (NxCxHxW tensor): Input pose map.
        orig_img_size (int tuple): Height and width of the input image.
scale (float): When training, randomly scale the crop size for
augmentation.
        crop_aspect_ratio (float): Output aspect ratio.
offset (list of float): Offset for crop position.
Returns:
crop_coords (list of int): bbox for body region.
"""
H, W = orig_img_size
assert pose_map.dim() == 4
nonzero_indices = (pose_map[:, :3] > 0).nonzero(as_tuple=False)
if nonzero_indices.size(0) == 0:
bw = int(H * crop_aspect_ratio // 2)
return [0, H, W // 2 - bw, W // 2 + bw]
y_indices, x_indices = nonzero_indices[:, 2], nonzero_indices[:, 3]
y_min, y_max = y_indices.min().item(), y_indices.max().item()
x_min, x_max = x_indices.min().item(), x_indices.max().item()
y_cen = int(y_min + y_max) // 2
x_cen = int(x_min + x_max) // 2
y_len = y_max - y_min
x_len = x_max - x_min
# bh, bw: half of height / width of final cropped size.
bh = int(min(H, max(H // 2, y_len * scale))) // 2
bh = max(bh, int(x_len * scale / crop_aspect_ratio) // 2)
bw = int(bh * crop_aspect_ratio)
# Randomly offset the cropped position for augmentation.
if offset is not None:
x_cen += int(offset[0] * bw)
y_cen += int(offset[1] * bh)
x_cen = max(bw, min(W - bw, x_cen))
y_cen = max(bh, min(H - bh, y_cen))
return [(y_cen - bh), (y_cen + bh), (x_cen - bw), (x_cen + bw)]
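
# Illustrative fallback behavior of `get_person_bbox_for_data` (values are
# assumptions): a pose map with no positive pixels in its first 3 channels
# yields a centered crop spanning the full image height.
#
#     pose = torch.zeros(1, 6, 256, 192) - 1            # all "background"
#     box = get_person_bbox_for_data(pose, (256, 192), crop_aspect_ratio=0.75)
#     # box == [0, 256, 0, 192]  (y_min, y_max, x_min, x_max)
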
def crop_and_resize(img, coords, size=None, method='bilinear'):
r"""Crop the image using the given coordinates and resize to target size.
Args:
img (tensor or list of tensors): Input image.
coords (list of int): Pixel coordinates to crop.
size (list of int): Output size.
method (str): Interpolation method.
Returns:
img (tensor or list of tensors): Output image.
"""
if isinstance(img, list):
return [crop_and_resize(x, coords, size, method) for x in img]
if img is None:
return None
min_y, max_y, min_x, max_x = coords
img = img[:, :, min_y:max_y, min_x:max_x]
if size is not None:
if method == 'nearest':
img = F.interpolate(img, size=size, mode=method)
else:
img = F.interpolate(img, size=size, mode=method,
align_corners=False)
return img
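
# Illustrative use of `crop_and_resize` (shapes are assumptions): crop a
# 100 x 80 region and resize it to a 256 x 192 output resolution.
#
#     img = torch.randn(1, 3, 512, 512)
#     out = crop_and_resize(img, [10, 110, 20, 100], size=(256, 192))
#     # out.shape == (1, 3, 256, 192)
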
def remove_other_ppl(labels, densemasks):
r"""Remove other people in the label map except for the current target
by looking at the id in the densemask map.
Args:
labels (NxCxHxW tensor): Input labels.
densemasks (Nx1xHxW tensor): Densemask maps.
Returns:
labels (NxCxHxW tensor): Output labels.
"""
densemasks = densemasks[:, 0:1] * 255
for idx in range(labels.shape[0]):
label, densemask = labels[idx], densemasks[idx]
        # Use the OpenPose result to find the person id in the densemask map
        # that has the most overlap with the target person.
openpose = label[3:]
valid = (openpose[0] > 0) | (openpose[1] > 0) | (openpose[2] > 0)
dp_valid = densemask[valid.unsqueeze(0)]
if dp_valid.shape[0]:
ind = np.bincount(dp_valid).argmax()
# Remove all other people that have different indices.
label = label * (densemask == ind).float()
labels[idx] = label
return labels
def select_object(data, obj_indices=None):
r"""Select the object/person in the dict according to the object index.
Currently it's used to select the target person in OpenPose dict.
Args:
data (dict): Input data.
obj_indices (list of int): Indices for the objects to select.
Returns:
data (dict): Output data.
"""
op_keys = ['poses-openpose', 'captions-clip']
for op_key in op_keys:
if op_key in data:
for i in range(len(data[op_key])):
# data[op_key] is a list of dicts for different frames.
# people = data[op_key][i]['people']
people = data[op_key][i]
# "people" is a list of people dicts found by OpenPose. We will
# use the obj_index to get the target person from the list, and
# write it back to the dict.
# data[op_key][i]['people'] = [people[obj_indices[i]]]
if obj_indices is not None:
data[op_key][i] = people[obj_indices[i]]
else:
if op_key == 'poses-openpose':
data[op_key][i] = people[0]
else:
idx = random.randint(0, len(people) - 1)
data[op_key][i] = people[idx]
return data
def concat_frames(prev, now, n_frames):
r"""Concat previous and current frames and only keep the latest $(n_frames).
If concatenated frames are longer than $(n_frames), drop the oldest one.
Args:
prev (NxTxCxHxW tensor): Tensor for previous frames.
now (NxCxHxW tensor): Tensor for current frame.
n_frames (int): Max number of frames to store.
Returns:
result (NxTxCxHxW tensor): Updated tensor.
"""
now = now.unsqueeze(1)
if prev is None:
return now
if prev.shape[1] == n_frames:
prev = prev[:, 1:]
return torch.cat([prev, now], dim=1)
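
# Illustrative use of `concat_frames` (shapes are assumptions): maintain a
# rolling buffer that keeps at most the 2 most recent frames.
#
#     buf = None
#     for _ in range(3):
#         frame = torch.randn(1, 3, 64, 64)
#         buf = concat_frames(buf, frame, n_frames=2)
#     # buf.shape == (1, 2, 3, 64, 64); the oldest frame was dropped.
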
def combine_fg_mask(fg_mask, ref_fg_mask, has_fg):
r"""Get the union of target and reference foreground masks.
Args:
fg_mask (tensor): Foreground mask for target image.
ref_fg_mask (tensor): Foreground mask for reference image.
has_fg (bool): Whether the image can be classified into fg/bg.
Returns:
output (tensor or int): Combined foreground mask.
"""
return ((fg_mask > 0) | (ref_fg_mask > 0)).float() if has_fg else 1
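
# Illustrative use of `combine_fg_mask` (values are assumptions): the result
# is 1 wherever either mask is positive; with has_fg=False, the scalar 1 acts
# as an all-foreground mask in later multiplications.
#
#     combine_fg_mask(torch.tensor([[0., 1.]]), torch.tensor([[1., 0.]]),
#                     has_fg=True)  # tensor([[1., 1.]])
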
def get_fg_mask(densepose_map, has_fg):
r"""Obtain the foreground mask for pose sequences, which only includes
the human. This is done by looking at the body part map from DensePose.
Args:
densepose_map (NxCxHxW tensor): DensePose map.
has_fg (bool): Whether data has foreground or not.
Returns:
mask (Nx1xHxW tensor): fg mask.
"""
if type(densepose_map) == list:
return [get_fg_mask(label, has_fg) for label in densepose_map]
if not has_fg or densepose_map is None:
return 1
if densepose_map.dim() == 5:
densepose_map = densepose_map[:, 0]
# Get the body part map from DensePose.
mask = densepose_map[:, 2:3]
# Make the mask slightly larger.
mask = torch.nn.MaxPool2d(15, padding=7, stride=1)(mask)
mask = (mask > -1).float()
return mask
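
# Illustrative use of `get_fg_mask` (shapes and values are assumptions): the
# part channel of a normalized DensePose map is -1 on background, so any
# pixel above -1 (plus a max-pool dilation margin) becomes foreground.
#
#     densepose = -torch.ones(1, 3, 64, 64)
#     densepose[:, 2, 20:40, 20:40] = 0.5       # a square "person" region
#     fg = get_fg_mask(densepose, has_fg=True)  # 1 x 1 x 64 x 64 dilated mask
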
def get_part_mask(densepose_map):
r"""Obtain mask of different body parts of humans. This is done by
looking at the body part map from DensePose.
Args:
densepose_map (NxCxHxW tensor): DensePose map.
Returns:
mask (NxKxHxW tensor): Body part mask, where K is the number of parts.
"""
# Groups of body parts. Each group contains IDs of body part labels in
# DensePose. The 9 groups here are: background, torso, hands, feet,
# upper legs, lower legs, upper arms, lower arms, head.
part_groups = [[0], [1, 2], [3, 4], [5, 6], [7, 9, 8, 10], [11, 13, 12, 14],
[15, 17, 16, 18], [19, 21, 20, 22], [23, 24]]
n_parts = len(part_groups)
need_reshape = densepose_map.dim() == 4
if need_reshape:
bo, t, h, w = densepose_map.size()
densepose_map = densepose_map.view(-1, h, w)
b, h, w = densepose_map.size()
part_map = (densepose_map / 2 + 0.5) * 24
assert (part_map >= 0).all() and (part_map < 25).all()
mask = torch.cuda.ByteTensor(b, n_parts, h, w).fill_(0)
for i in range(n_parts):
for j in part_groups[i]:
# Account for numerical errors.
mask[:, i] = mask[:, i] | (
(part_map > j - 0.1) & (part_map < j + 0.1)).byte()
if need_reshape:
mask = mask.view(bo, t, -1, h, w)
return mask.float()
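
# Illustrative use of `get_part_mask` (shapes are assumptions): the input is
# the part-ID channel with IDs 0-24 rescaled to [-1, 1], and the output
# regroups those 25 IDs into the 9 part groups listed above.
#
#     part_channel = -torch.ones(2, 16, 16, device='cuda')  # all background
#     masks = get_part_mask(part_channel)  # 2 x 9 x 16 x 16; masks[:, 0] == 1
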
def get_face_mask(densepose_map):
r"""Obtain mask of faces.
Args:
densepose_map (3D or 4D tensor): DensePose map.
Returns:
mask (3D or 4D tensor): Face mask.
"""
need_reshape = densepose_map.dim() == 4
if need_reshape:
bo, t, h, w = densepose_map.size()
densepose_map = densepose_map.view(-1, h, w)
b, h, w = densepose_map.size()
part_map = (densepose_map / 2 + 0.5) * 24
assert (part_map >= 0).all() and (part_map < 25).all()
if densepose_map.is_cuda:
mask = torch.cuda.ByteTensor(b, h, w).fill_(0)
else:
mask = torch.ByteTensor(b, h, w).fill_(0)
for j in [23, 24]:
mask = mask | ((part_map > j - 0.1) & (part_map < j + 0.1)).byte()
if need_reshape:
mask = mask.view(bo, t, h, w)
return mask.float()
def extract_valid_pose_labels(pose_map, pose_type, remove_face_labels,
do_remove=True):
r"""Remove some labels (e.g. face regions) in the pose map if necessary.
Args:
pose_map (3D, 4D or 5D tensor): Input pose map.
pose_type (str): 'both' or 'open'.
remove_face_labels (bool): Whether to remove labels for the face region.
        do_remove (bool): Whether to actually remove the face labels.
Returns:
pose_map (3D, 4D or 5D tensor): Output pose map.
"""
if pose_map is None:
return pose_map
if type(pose_map) == list:
return [extract_valid_pose_labels(p, pose_type, remove_face_labels,
do_remove) for p in pose_map]
orig_dim = pose_map.dim()
assert (orig_dim >= 3 and orig_dim <= 5)
if orig_dim == 3:
pose_map = pose_map.unsqueeze(0).unsqueeze(0)
elif orig_dim == 4:
pose_map = pose_map.unsqueeze(0)
if pose_type == 'open':
# If input is only openpose, remove densepose part.
pose_map = pose_map[:, :, 3:]
elif remove_face_labels and do_remove:
# Remove face part for densepose input.
densepose, openpose = pose_map[:, :, :3], pose_map[:, :, 3:]
face_mask = get_face_mask(pose_map[:, :, 2]).unsqueeze(2)
pose_map = torch.cat([densepose * (1 - face_mask) - face_mask,
openpose], dim=2)
if orig_dim == 3:
pose_map = pose_map[0, 0]
elif orig_dim == 4:
pose_map = pose_map[0]
return pose_map
def normalize_faces(keypoints, ref_keypoints,
dist_scale_x=None, dist_scale_y=None):
r"""Normalize face keypoints w.r.t. the reference face keypoints.
Args:
keypoints (Kx2 numpy array): target facial keypoints.
ref_keypoints (Kx2 numpy array): reference facial keypoints.
Returns:
keypoints (Kx2 numpy array): normalized facial keypoints.
"""
if keypoints.shape[0] == 68:
central_keypoints = [8]
add_upper_face = False
part_list = [[0, 16], [1, 15], [2, 14], [3, 13], [4, 12],
[5, 11], [6, 10], [7, 9, 8],
[17, 26], [18, 25], [19, 24], [20, 23], [21, 22],
[27], [28], [29], [30], [31, 35], [32, 34], [33],
[36, 45], [37, 44], [38, 43], [39, 42], [40, 47], [41, 46],
[48, 54], [49, 53], [50, 52], [51], [55, 59], [56, 58],
[57],
[60, 64], [61, 63], [62], [65, 67], [66]
]
if add_upper_face:
part_list += [[68, 82], [69, 81], [70, 80], [71, 79], [72, 78],
[73, 77], [74, 76, 75]]
elif keypoints.shape[0] == 126:
central_keypoints = [16]
part_list = [[i] for i in range(126)]
else:
raise ValueError('Input keypoints type not supported.')
face_cen = np.mean(keypoints[central_keypoints, :], axis=0)
ref_face_cen = np.mean(ref_keypoints[central_keypoints, :], axis=0)
def get_mean_dists(pts, face_cen):
r"""Get the mean xy distances of keypoints wrt face center."""
mean_dists_x, mean_dists_y = [], []
pts_cen = np.mean(pts, axis=0)
for p, pt in enumerate(pts):
mean_dists_x.append(np.linalg.norm(pt - pts_cen))
mean_dists_y.append(np.linalg.norm(pts_cen - face_cen))
mean_dist_x = sum(mean_dists_x) / len(mean_dists_x) + 1e-3
mean_dist_y = sum(mean_dists_y) / len(mean_dists_y) + 1e-3
return mean_dist_x, mean_dist_y
if dist_scale_x is None:
dist_scale_x, dist_scale_y = [None] * len(part_list), \
[None] * len(part_list)
for i, pts_idx in enumerate(part_list):
pts = keypoints[pts_idx]
if dist_scale_x[i] is None:
ref_pts = ref_keypoints[pts_idx]
mean_dist_x, mean_dist_y = get_mean_dists(pts, face_cen)
ref_dist_x, ref_dist_y = get_mean_dists(ref_pts, ref_face_cen)
dist_scale_x[i] = ref_dist_x / mean_dist_x
dist_scale_y[i] = ref_dist_y / mean_dist_y
pts_cen = np.mean(pts, axis=0)
pts = (pts - pts_cen) * dist_scale_x[i] + \
(pts_cen - face_cen) * dist_scale_y[i] + face_cen
keypoints[pts_idx] = pts
return keypoints, [dist_scale_x, dist_scale_y]
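
# Illustrative call pattern for `normalize_faces` (variable names are
# assumptions): compute the per-part scales on the first frame and reuse them
# for later frames of the same video. Note the keypoints array is modified in
# place, hence the copies.
#
#     kp, scales = normalize_faces(target_kp.copy(), ref_kp)
#     kp_next, _ = normalize_faces(next_kp.copy(), ref_kp,
#                                  dist_scale_x=scales[0],
#                                  dist_scale_y=scales[1])
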
def crop_face_from_output(data_cfg, image, input_label, crop_smaller=0):
r"""Crop out the face region of the image (and resize if necessary to feed
into generator/discriminator).
Args:
data_cfg (obj): Data configuration.
image (NxC1xHxW tensor or list of tensors): Image to crop.
input_label (NxC2xHxW tensor): Input label map.
        crop_smaller (int): Number of pixels by which to shrink the cropped
            region on each side.
Returns:
output (NxC1xHxW tensor or list of tensors): Cropped image.
"""
if type(image) == list:
return [crop_face_from_output(data_cfg, im, input_label, crop_smaller)
for im in image]
output = None
face_size = image.shape[-2] // 32 * 8
for i in range(input_label.size(0)):
ys, ye, xs, xe = get_face_bbox_for_output(data_cfg,
input_label[i:i + 1],
crop_smaller=crop_smaller)
output_i = F.interpolate(image[i:i + 1, -3:, ys:ye, xs:xe],
size=(face_size, face_size), mode='bilinear',
align_corners=True)
# output_i = image[i:i + 1, -3:, ys:ye, xs:xe]
output = torch.cat([output, output_i]) if i != 0 else output_i
return output
def get_face_bbox_for_output(data_cfg, pose, crop_smaller=0):
r"""Get pixel coordinates of the face bounding box.
Args:
data_cfg (obj): Data configuration.
pose (NxCxHxW tensor): Pose label map.
        crop_smaller (int): Number of pixels by which to shrink the cropped
            region on each side.
Returns:
output (list of int): Face bbox.
"""
if pose.dim() == 3:
pose = pose.unsqueeze(0)
elif pose.dim() == 5:
pose = pose[-1, -1:]
_, _, h, w = pose.size()
use_openpose = 'pose_maps-densepose' not in data_cfg.input_labels
if use_openpose: # Use openpose face keypoints to identify face region.
for input_type in data_cfg.input_types:
if 'poses-openpose' in input_type:
num_ch = input_type['poses-openpose'].num_channels
if num_ch > 3:
face = (pose[:, -1] > 0).nonzero(as_tuple=False)
else:
raise ValueError('Not implemented yet.')
else: # Use densepose labels.
face = (pose[:, 2] > 0.9).nonzero(as_tuple=False)
ylen = xlen = h // 32 * 8
if face.size(0):
y, x = face[:, 1], face[:, 2]
ys, ye = y.min().item(), y.max().item()
xs, xe = x.min().item(), x.max().item()
if use_openpose:
xc, yc = (xs + xe) // 2, (ys * 3 + ye * 2) // 5
ylen = int((xe - xs) * 2.5)
else:
xc, yc = (xs + xe) // 2, (ys + ye) // 2
ylen = int((ye - ys) * 1.25)
ylen = xlen = min(w, max(32, ylen))
yc = max(ylen // 2, min(h - 1 - ylen // 2, yc))
xc = max(xlen // 2, min(w - 1 - xlen // 2, xc))
else:
yc = h // 4
xc = w // 2
ys, ye = yc - ylen // 2, yc + ylen // 2
xs, xe = xc - xlen // 2, xc + xlen // 2
if crop_smaller != 0: # Crop slightly smaller region inside face.
ys += crop_smaller
xs += crop_smaller
ye -= crop_smaller
xe -= crop_smaller
return [ys, ye, xs, xe]
def crop_hand_from_output(data_cfg, image, input_label):
r"""Crop out the hand region of the image.
Args:
data_cfg (obj): Data configuration.
image (NxC1xHxW tensor or list of tensors): Image to crop.
input_label (NxC2xHxW tensor): Input label map.
Returns:
output (NxC1xHxW tensor or list of tensors): Cropped image.
"""
if type(image) == list:
return [crop_hand_from_output(data_cfg, im, input_label)
for im in image]
output = None
for i in range(input_label.size(0)):
coords = get_hand_bbox_for_output(data_cfg, input_label[i:i + 1])
if coords:
for coord in coords:
ys, ye, xs, xe = coord
output_i = image[i:i + 1, -3:, ys:ye, xs:xe]
output = torch.cat([output, output_i]) \
if output is not None else output_i
return output
def get_hand_bbox_for_output(data_cfg, pose):
r"""Get coordinates of the hand bounding box.
Args:
data_cfg (obj): Data configuration.
pose (NxCxHxW tensor): Pose label map.
Returns:
output (list of int): Hand bbox.
"""
if pose.dim() == 3:
pose = pose.unsqueeze(0)
elif pose.dim() == 5:
pose = pose[-1, -1:]
_, _, h, w = pose.size()
ylen = xlen = h // 64 * 8
coords = []
colors = [[0.95, 0.5, 0.95], [0.95, 0.95, 0.5]]
for i, color in enumerate(colors):
if pose.shape[1] > 6: # Using one-hot encoding for openpose.
idx = -3 if i == 0 else -2
hand = (pose[:, idx] == 1).nonzero(as_tuple=False)
else:
raise ValueError('Not implemented yet.')
if hand.size(0):
y, x = hand[:, 1], hand[:, 2]
ys, ye, xs, xe = y.min().item(), y.max().item(), \
x.min().item(), x.max().item()
xc, yc = (xs + xe) // 2, (ys + ye) // 2
yc = max(ylen // 2, min(h - 1 - ylen // 2, yc))
xc = max(xlen // 2, min(w - 1 - xlen // 2, xc))
ys, ye, xs, xe = yc - ylen // 2, yc + ylen // 2, \
xc - xlen // 2, xc + xlen // 2
coords.append([ys, ye, xs, xe])
return coords
def pre_process_densepose(pose_cfg, pose_map, is_infer=False):
r"""Pre-process the DensePose part of input label map.
Args:
pose_cfg (obj): Pose data configuration.
pose_map (NxCxHxW tensor): Pose label map.
is_infer (bool): Is doing inference.
Returns:
pose_map (NxCxHxW tensor): Processed pose label map.
"""
part_map = pose_map[:, :, 2] * 255 # should be within [0-24]
assert (part_map >= 0).all() and (part_map < 25).all()
    # Randomly drop some body parts during training.
if not is_infer:
random_drop_prob = getattr(pose_cfg, 'random_drop_prob', 0)
else:
random_drop_prob = 0
if random_drop_prob > 0:
densepose_map = pose_map[:, :, :3]
for part_id in range(1, 25):
if (random.random() < random_drop_prob):
part_mask = abs(part_map - part_id) < 0.1
densepose_map[part_mask.unsqueeze(2).expand_as(
densepose_map)] = 0
pose_map[:, :, :3] = densepose_map
    # Renormalize the DensePose part channel so the part IDs 0-24 span [0, 1].
pose_map[:, :, 2] = pose_map[:, :, 2] * (255 / 24)
# Normalize from [0, 1] to [-1, 1].
pose_map = pose_map * 2 - 1
return pose_map
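
# Worked example of the renormalization above (values are illustrative): a
# pixel storing part ID 12 enters as 12 / 255 ~= 0.047, becomes 12 / 24 = 0.5
# after multiplying by 255 / 24, and lands at 0.0 after the final `* 2 - 1`
# mapping to the [-1, 1] output range.
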
def random_roll(tensors):
r"""Randomly roll the input tensors along x and y dimensions. Also randomly
flip the tensors.
Args:
tensors (list of 4D tensors): Input tensors.
Returns:
output (list of 4D tensors): Rolled tensors.
"""
h, w = tensors[0].shape[2:]
ny = np.random.choice([np.random.randint(h//16),
h-np.random.randint(h//16)])
nx = np.random.choice([np.random.randint(w//16),
w-np.random.randint(w//16)])
flip = np.random.rand() > 0.5
return [roll(t, ny, nx, flip) for t in tensors]
def roll(t, ny, nx, flip=False):
r"""Roll and flip the tensor by specified amounts.
Args:
t (4D tensor): Input tensor.
ny (int): Amount to roll along y dimension.
nx (int): Amount to roll along x dimension.
flip (bool): Whether to flip input.
Returns:
t (4D tensor): Output tensor.
"""
t = torch.cat([t[:, :, -ny:], t[:, :, :-ny]], dim=2)
t = torch.cat([t[:, :, :, -nx:], t[:, :, :, :-nx]], dim=3)
if flip:
t = torch.flip(t, dims=[3])
return t
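
# Illustrative use of `roll` (values are assumptions): rolling by (ny, nx)
# wraps the last ny rows and last nx columns around to the front.
#
#     t = torch.arange(16.).view(1, 1, 4, 4)
#     rolled = roll(t, ny=1, nx=1)
#     # rolled[0, 0, 0, 0] == 15.  (bottom-right pixel wrapped to top-left)
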
def detach(output):
r"""Detach tensors in the dict.
Args:
output (dict): Output dict.
Returns:
output (dict): Detached output dict.
"""
if type(output) == dict:
new_dict = dict()
for k, v in output.items():
new_dict[k] = detach(v)
return new_dict
elif type(output) == torch.Tensor:
return output.detach()
return output
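
# Illustrative use of `detach` (shapes are assumptions): detach every tensor
# in a (possibly nested) dict of network outputs so it can be stored without
# keeping the autograd graph alive.
#
#     out = {'fake_images': torch.randn(1, 3, 4, 4, requires_grad=True)}
#     out = detach(out)
#     # out['fake_images'].requires_grad is now False
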