# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# This work is made available under the Nvidia Source Code License-NC.
# To view a copy of this license, check out LICENSE.md
import numpy as np
import cv2
import torch
from scipy.optimize import curve_fit
from scipy.signal import medfilt
import warnings
from imaginaire.utils.io import get_checkpoint
def connect_face_keypoints(resize_h, resize_w, crop_h, crop_w, original_h,
original_w, is_flipped, cfgdata, keypoints):
r"""Connect the face keypoints to edges and draw the sketch.
Args:
resize_h (int): Height the input image was resized to.
resize_w (int): Width the input image was resized to.
        crop_h (int): Height the input image was cropped to.
        crop_w (int): Width the input image was cropped to.
original_h (int): Original height of the input image.
original_w (int): Original width of the input image.
        is_flipped (bool): Whether the input image is flipped.
cfgdata (obj): Data configuration.
keypoints (NxKx2 numpy array): Facial landmarks (with K keypoints).
Returns:
        (list of HxWxC numpy array): Drawn label maps.
"""
if hasattr(cfgdata, 'for_face_dataset'):
face_cfg = cfgdata.for_face_dataset
# Whether to add the upper part of face to label map.
add_upper_face = getattr(face_cfg, 'add_upper_face', False)
# Whether to add distance transform output to label map.
add_dist_map = getattr(face_cfg, 'add_distance_transform', False)
# Whether to add positional encoding to label map.
add_pos_encode = add_dist_map and getattr(
face_cfg, 'add_positional_encode', False)
else:
add_upper_face = add_dist_map = add_pos_encode = False
# Mapping from keypoint index to facial part.
    part_list = [[list(range(0, 17)) + (
        (list(range(68, 83)) + [0]) if add_upper_face else [])],  # face contour
[range(17, 22)], # right eyebrow
[range(22, 27)], # left eyebrow
[[28, 31], range(31, 36), [35, 28]], # nose
[[36, 37, 38, 39], [39, 40, 41, 36]], # right eye
[[42, 43, 44, 45], [45, 46, 47, 42]], # left eye
[range(48, 55), [54, 55, 56, 57, 58, 59, 48],
range(60, 65), [64, 65, 66, 67, 60]], # mouth and tongue
]
if add_upper_face:
pts = keypoints[:, :17, :].astype(np.int32)
baseline_y = (pts[:, 0:1, 1] + pts[:, -1:, 1]) / 2
upper_pts = pts[:, 1:-1, :].copy()
upper_pts[:, :, 1] = baseline_y + (
baseline_y - upper_pts[:, :, 1]) * 2 // 3
keypoints = np.hstack((keypoints, upper_pts[:, ::-1, :]))
edge_len = 3 # Interpolate 3 keypoints to form a curve when drawing edges.
bw = max(1, resize_h // 256) # Width of the stroke.
outputs = []
for t_idx in range(keypoints.shape[0]):
# Edge map for the face region from keypoints.
im_edges = np.zeros((resize_h, resize_w, 1), np.uint8)
im_dists = np.zeros((resize_h, resize_w, 0), np.uint8)
for edge_list in part_list:
for e, edge in enumerate(edge_list):
# Edge map for the current edge.
im_edge = np.zeros((resize_h, resize_w, 1), np.uint8)
# Divide a long edge into multiple small edges when drawing.
for i in range(0, max(1, len(edge) - 1), edge_len - 1):
sub_edge = edge[i:i + edge_len]
x = keypoints[t_idx, sub_edge, 0]
y = keypoints[t_idx, sub_edge, 1]
# Interp keypoints to get the curve shape.
curve_x, curve_y = interp_points(x, y)
draw_edge(im_edges, curve_x, curve_y, bw=bw)
if add_dist_map:
draw_edge(im_edge, curve_x, curve_y, bw=bw)
if add_dist_map:
# Add distance transform map on each facial part.
im_dist = cv2.distanceTransform(255 - im_edge,
cv2.DIST_L1, 3)
im_dist = np.clip((im_dist / 3), 0, 255)
im_dists = np.dstack((im_dists, im_dist))
if add_pos_encode and e == 0:
# Add positional encoding for the first edge.
from math import pi
im_pos = np.zeros((resize_h, resize_w, 0), np.float32)
for l in range(10): # noqa: E741
dist = (im_dist.astype(np.float32) - 127.5) / 127.5
sin = np.sin(pi * (2 ** l) * dist)
cos = np.cos(pi * (2 ** l) * dist)
im_pos = np.dstack((im_pos, sin, cos))
# Combine all components to form the final label map.
if add_dist_map:
im_edges = np.dstack((im_edges, im_dists))
im_edges = im_edges.astype(np.float32) / 255.0
if add_pos_encode:
im_edges = np.dstack((im_edges, im_pos))
outputs.append(im_edges)
return outputs
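
# A minimal usage sketch for `connect_face_keypoints` (illustrative only;
# `dummy_cfg` below is a hypothetical stand-in for a real data config, so all
# optional channels are disabled and the output is a single edge map):
#
#   import types
#   dummy_cfg = types.SimpleNamespace()
#   kpts = np.random.rand(1, 68, 2) * 256  # hypothetical 68-point landmarks
#   maps = connect_face_keypoints(256, 256, None, None, None, None, False,
#                                 dummy_cfg, kpts)
#   assert maps[0].shape == (256, 256, 1)
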
def normalize_and_connect_face_keypoints(cfg, is_inference, data):
r"""Normalize face keypoints w.r.t. reference face keypoints and connect
keypoints to form 2D images.
Args:
cfg (obj): Data configuration.
is_inference (bool): Is doing inference or not.
data (dict): Input data.
Returns:
(dict): Output data.
"""
    # This path is only used at inference time.
    assert is_inference
resize_h, resize_w = data['images'][0].shape[-2:]
keypoints = data['label'].numpy()[0]
ref_keypoints = data['few_shot_label'].numpy()[0]
# Get the normalization params and prev data if it's been computed before.
dist_scales = prev_keypoints = None
if 'common_attr' in data and 'prev_data' in data['common_attr']:
dist_scales = data['common_attr']['dist_scales']
prev_keypoints = data['common_attr']['prev_data']
def concat(prev, now, t):
r"""Concat prev and now frames in first dimension, up to t frames."""
if prev is None:
return now
return np.vstack([prev, now])[-t:]
# Normalize face keypoints w.r.t. reference face keypoints.
keypoints, dist_scales = \
normalize_face_keypoints(keypoints[0], ref_keypoints[0], dist_scales,
momentum=getattr(cfg.for_face_dataset,
'normalize_momentum', 0.9))
keypoints = keypoints[np.newaxis, :]
# Temporally smooth the face keypoints by median filtering.
ks = getattr(cfg.for_face_dataset, 'smooth_kernel_size', 5)
concat_keypoints = concat(prev_keypoints, keypoints, ks)
if ks > 1 and concat_keypoints.shape[0] == ks:
keypoints = smooth_face_keypoints(concat_keypoints, ks)
# Store the computed params.
if 'common_attr' not in data:
data['common_attr'] = dict()
data['common_attr']['dist_scales'] = dist_scales
data['common_attr']['prev_data'] = concat_keypoints
# Draw the keypoints to turn them into images.
labels = []
for kpt in [keypoints, ref_keypoints]:
label = connect_face_keypoints(resize_h, resize_w, None, None, None,
None, False, cfg, kpt)
labels += [torch.from_numpy(label[0]).permute(2, 0, 1).unsqueeze(0)]
data['label'], data['few_shot_label'] = labels
return data
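
# Illustrative per-frame driver for `normalize_and_connect_face_keypoints`
# (hypothetical `cfg` and frame dicts; in practice they come from the
# imaginaire few-shot vid2vid dataloader):
#
#   common_attr = {}
#   for frame_data in video_frames:
#       frame_data['common_attr'] = common_attr
#       frame_data = normalize_and_connect_face_keypoints(cfg, True,
#                                                         frame_data)
#       # 'common_attr' now carries dist_scales and the keypoint history, so
#       # reusing it links normalization and smoothing across frames.
#       common_attr = frame_data['common_attr']
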
def smooth_face_keypoints(concat_keypoints, ks):
r""" Temporally smooth the face keypoints by median filtering.
Args:
concat_keypoints (TxKx2 numpy array): Face keypoints to be filtered.
ks (int): Filter kernel size.
Returns:
(1xKx2 numpy array): Output face keypoints.
"""
# Median filtering.
filtered_keypoints = medfilt(concat_keypoints, kernel_size=[ks, 1, 1])
    # Fill in any zero keypoints with the value from the previous frame.
if (filtered_keypoints == 0).any():
for t in range(1, filtered_keypoints.shape[0]):
kpt_prev = filtered_keypoints[t - 1]
kpt_cur = filtered_keypoints[t]
kpt_max = np.maximum(kpt_cur, kpt_prev)
kpt_cur[kpt_cur == 0] = kpt_max[kpt_cur == 0]
filtered_keypoints[t] = kpt_cur
keypoints = filtered_keypoints[ks // 2: ks // 2 + 1]
return keypoints
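
# Minimal sketch of the smoothing behavior (runnable with numpy/scipy alone):
#
#   traj = np.tile(np.arange(5, dtype=float)[:, None, None], (1, 68, 2))
#   traj[2] += 100  # a single-frame outlier
#   smoothed = smooth_face_keypoints(traj, ks=5)  # 1x68x2
#   # The middle frame is returned with the outlier replaced by the temporal
#   # median of its 5-frame window.
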
def normalize_face_keypoints(keypoints, ref_keypoints, dist_scales=None,
momentum=0.9):
r"""Normalize face keypoints w.r.t. the reference face keypoints.
Args:
keypoints (Kx2 numpy array): Target facial keypoints to be normalized.
ref_keypoints (Kx2 numpy array): Reference facial keypoints.
dist_scales (list of list of floats): Normalization params.
momentum (float): Temporal momentum for the normalization params.
Returns:
(Kx2 numpy array): Normalized facial keypoints.
"""
if keypoints.shape[0] == 68:
central_keypoints = [8]
part_list = [[0, 16], [1, 15], [2, 14], [3, 13], [4, 12],
[5, 11], [6, 10], [7, 9, 8],
[17, 26], [18, 25], [19, 24], [20, 23], [21, 22],
[27], [28], [29], [30], [31, 35], [32, 34], [33],
[36, 45], [37, 44], [38, 43], [39, 42], [40, 47], [41, 46],
[48, 54], [49, 53], [50, 52], [51], [55, 59], [56, 58],
[57],
[60, 64], [61, 63], [62], [65, 67], [66]
]
else:
raise ValueError('Input keypoints type not supported.')
face_cen = np.mean(keypoints[central_keypoints, :], axis=0)
ref_face_cen = np.mean(ref_keypoints[central_keypoints, :], axis=0)
def get_mean_dists(pts, face_cen):
r"""Get mean distances of the points from face center."""
mean_dists_x, mean_dists_y = [], []
pts_cen = np.mean(pts, axis=0)
for p, pt in enumerate(pts):
mean_dists_x.append(np.linalg.norm(pt - pts_cen))
mean_dists_y.append(np.linalg.norm(pts_cen - face_cen))
mean_dist_x = sum(mean_dists_x) / len(mean_dists_x) + 1e-3
mean_dist_y = sum(mean_dists_y) / len(mean_dists_y) + 1e-3
return mean_dist_x, mean_dist_y
dist_scale_x, dist_scale_y = [None] * len(part_list), \
[None] * len(part_list)
if dist_scales is None:
dist_scale_x_prev = dist_scale_y_prev = img_scale = None
else:
dist_scale_x_prev, dist_scale_y_prev, img_scale = dist_scales
if img_scale is None:
img_scale = (keypoints[:, 0].max() - keypoints[:, 0].min()) \
/ (ref_keypoints[:, 0].max() - ref_keypoints[:, 0].min())
for i, pts_idx in enumerate(part_list):
pts = keypoints[pts_idx]
pts = pts[pts[:, 0] != 0]
if pts.shape[0]:
ref_pts = ref_keypoints[pts_idx]
mean_dist_x, mean_dist_y = get_mean_dists(pts, face_cen)
ref_dist_x, ref_dist_y = get_mean_dists(ref_pts, ref_face_cen)
dist_scale_x[i] = ref_dist_x / mean_dist_x * img_scale
dist_scale_y[i] = ref_dist_y / mean_dist_y * img_scale
if dist_scale_x_prev is not None:
dist_scale_x[i] = dist_scale_x_prev[i] * momentum + \
dist_scale_x[i] * (1 - momentum)
dist_scale_y[i] = dist_scale_y_prev[i] * momentum + \
dist_scale_y[i] * (1 - momentum)
pts_cen = np.mean(pts, axis=0)
pts = (pts - pts_cen) * dist_scale_x[i] + \
(pts_cen - face_cen) * dist_scale_y[i] + face_cen
keypoints[pts_idx] = pts
return keypoints, [dist_scale_x, dist_scale_y, img_scale]
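
# Usage sketch (hypothetical Kx2 landmark arrays from any 68-point detector):
#
#   tgt_norm, dist_scales = normalize_face_keypoints(tgt_kpts, ref_kpts)
#   # Passing `dist_scales` back in on the next frame blends the new per-part
#   # scales with the previous ones using the given momentum:
#   nxt_norm, dist_scales = normalize_face_keypoints(nxt_kpts, ref_kpts,
#                                                    dist_scales)
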
def npy_to_tensor(keypoints):
r"""Convert numpy array to pytorch tensor."""
return torch.from_numpy(keypoints).unsqueeze(0)
def get_dlib_landmarks_from_image(
imgs, predictor_path='shape_predictor_68_face_landmarks.dat'):
r"""Get face keypoints from an image.
Args:
imgs (N x 3 x H x W tensor or N x H x W x 3 numpy array): Input images.
predictor_path (str): Path to the predictor model.
"""
import dlib
predictor_path = get_checkpoint(predictor_path,
url='1l9zT-AI1yKlfyAb_wl_RjLBSaiWQr8dr')
    if isinstance(imgs, torch.Tensor):
        # Convert [-1, 1] float tensors to HxWx3 uint8 numpy arrays.
        imgs = ((imgs + 1) / 2 * 255).byte()
        imgs = np.transpose(imgs.cpu().numpy(), (0, 2, 3, 1))
detector = dlib.get_frontal_face_detector()
predictor = dlib.shape_predictor(predictor_path)
points = np.zeros([imgs.shape[0], 68, 2], dtype=int)
for i in range(imgs.shape[0]):
img = imgs[i]
dets = detector(img, 1)
if len(dets) > 0:
# Only returns the first face.
shape = predictor(img, dets[0])
for b in range(68):
points[i, b, 0] = shape.part(b).x
points[i, b, 1] = shape.part(b).y
return points
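
# Usage sketch (requires dlib; the 68-point shape predictor is fetched via
# `get_checkpoint` on first use):
#
#   img = cv2.cvtColor(cv2.imread('face.jpg'), cv2.COLOR_BGR2RGB)
#   pts = get_dlib_landmarks_from_image(img[np.newaxis])  # 1x68x2
#   # Rows stay all-zero for images in which no face is detected.
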
def get_126_landmarks_from_image(imgs, landmarks_network):
r"""Get face keypoints from an image.
Args:
imgs (Nx3xHxW tensor or NxHxWx3 numpy array):
Input images.
landmarks_network (obj): The landmark detection network.
Return:
(Nx126x2 numpy array): Predicted landmarks.
"""
    if isinstance(imgs, torch.Tensor):
        # Convert [-1, 1] float tensors to HxWx3 uint8 numpy arrays.
        imgs = ((imgs + 1) / 2 * 255).byte()
        imgs = np.transpose(imgs.cpu().numpy(), (0, 2, 3, 1))
landmarks = []
for i in range(imgs.shape[0]):
img = imgs[i]
out_boxes, landmark = \
landmarks_network.get_face_boxes_and_landmarks(img)
if len(landmark) > 1:
# Pick the largest face in the image.
            face_size_max = face_index = 0
            for face_i, out_box in enumerate(out_boxes):
                # Boxes are (x1, y1, x2, y2); use the longer side as the size.
                face_size = max(out_box[2] - out_box[0],
                                out_box[3] - out_box[1])
                if face_size > face_size_max:
                    face_size_max = face_size
                    face_index = face_i
landmark = landmark[face_index]
elif len(landmark) == 1:
landmark = landmark[0]
else:
landmark = np.zeros((126, 2), dtype=np.float32)
landmarks += [landmark[np.newaxis]]
landmarks = np.vstack(landmarks).astype(np.float32)
return landmarks
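
# Usage sketch (the landmark network is external to this file; any detector
# exposing `get_face_boxes_and_landmarks(img)` with the same return
# convention works):
#
#   lmk = get_126_landmarks_from_image(imgs, landmarks_network)  # Nx126x2
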
def convert_face_landmarks_to_image(cfgdata, landmarks, output_size,
output_tensor=True, cpu_only=False):
r"""Convert the facial landmarks to a label map.
Args:
cfgdata (obj): Data configuration.
        landmarks (NxKx2 numpy array): Facial landmarks.
output_size (tuple of int): H, W of output label map.
output_tensor (bool): Output tensors instead of numpy arrays.
cpu_only (bool): Output CPU tensor only.
Returns:
(NxCxHxW tensor or list of HxWxC numpy arrays): Label maps.
"""
h, w = output_size
labels = connect_face_keypoints(h, w, None, None, None, None, False,
cfgdata, landmarks)
if not output_tensor:
return labels
labels = [torch.from_numpy(label).permute(2, 0, 1).unsqueeze(0)
for label in labels]
labels = torch.cat(labels)
if cpu_only:
return labels
return labels.cuda()
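
# Usage sketch chaining detection and drawing (hypothetical `cfg.data`):
#
#   pts = get_dlib_landmarks_from_image(imgs)                  # Nx68x2
#   maps = convert_face_landmarks_to_image(cfg.data, pts, (256, 256),
#                                          cpu_only=True)      # NxCx256x256
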
def add_face_keypoints(label_map, image, keypoints):
r"""Add additional keypoints to label map.
Args:
        label_map (Nx1xHxW tensor or None): Label map to draw on; a zero map
            is created if None.
        image (Nx3xHxW tensor): Input image, only used to get output size.
        keypoints (NxKx2 tensor): Keypoint coordinates, normalized to [-1, 1].
    Returns:
        (Nx1xHxW tensor): Output label map.
"""
if label_map is None:
label_map = torch.zeros_like(image)[:, :1]
x, y = keypoints[:, :, 0], keypoints[:, :, 1]
h, w = image.shape[-2:]
x = ((x + 1) / 2 * w).long()
y = ((y + 1) / 2 * h).long()
bs = torch.arange(label_map.shape[0]).cuda().view(-1, 1).expand_as(x)
label_map[bs, :, y, x] = 1
return label_map
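
# Usage sketch (assumes CUDA, since the batch-index tensor above is created
# on the GPU; keypoints must already be normalized to [-1, 1]):
#
#   image = torch.zeros(2, 3, 256, 256).cuda()
#   kpts = torch.rand(2, 10, 2).cuda() * 2 - 1  # hypothetical extra points
#   label = add_face_keypoints(None, image, kpts)  # 2x1x256x256
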
def draw_edge(im, x, y, bw=1, color=(255, 255, 255), draw_end_points=False):
r"""Set colors given a list of x and y coordinates for the edge.
Args:
im (HxWxC numpy array): Canvas to draw.
x (1D numpy array): x coordinates of the edge.
y (1D numpy array): y coordinates of the edge.
bw (int): Width of the stroke.
color (list or tuple of int): Color to draw.
draw_end_points (bool): Whether to draw end points of the edge.
"""
if x is not None and x.size:
h, w = im.shape[0], im.shape[1]
# Draw edge.
for i in range(-bw, bw):
for j in range(-bw, bw):
yy = np.maximum(0, np.minimum(h - 1, y + i))
xx = np.maximum(0, np.minimum(w - 1, x + j))
set_color(im, yy, xx, color)
# Draw endpoints.
if draw_end_points:
for i in range(-bw * 2, bw * 2):
for j in range(-bw * 2, bw * 2):
if (i ** 2) + (j ** 2) < (4 * bw ** 2):
yy = np.maximum(0, np.minimum(h - 1, np.array(
[y[0], y[-1]]) + i))
xx = np.maximum(0, np.minimum(w - 1, np.array(
[x[0], x[-1]]) + j))
set_color(im, yy, xx, color)
def set_color(im, yy, xx, color):
r"""Set pixels of the image to the given color.
Args:
im (HxWxC numpy array): Canvas to draw.
        yy (1D numpy array): y coordinates of the pixels.
        xx (1D numpy array): x coordinates of the pixels.
color (list or tuple of int): Color to draw.
"""
    if not isinstance(color, (list, tuple)):
color = [color] * 3
if len(im.shape) == 3 and im.shape[2] == 3:
if (im[yy, xx] == 0).all():
im[yy, xx, 0], im[yy, xx, 1], im[yy, xx, 2] = \
color[0], color[1], color[2]
else:
for c in range(3):
im[yy, xx, c] = ((im[yy, xx, c].astype(float)
+ color[c]) / 2).astype(np.uint8)
else:
im[yy, xx] = color[0]
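
# Minimal sketch of the drawing primitives (runnable with numpy alone):
#
#   canvas = np.zeros((64, 64, 1), np.uint8)
#   xs = np.arange(10, 50)          # a horizontal stroke at y = 32
#   ys = np.full_like(xs, 32)
#   draw_edge(canvas, xs, ys, bw=2, draw_end_points=True)
#   assert canvas[32, 30] == 255
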
def interp_points(x, y):
r"""Given the start and end points, interpolate to get a curve/line.
Args:
x (1D array): x coordinates of the points to interpolate.
y (1D array): y coordinates of the points to interpolate.
Returns:
        curve_x (1D array): x coordinates of the interpolated points, or
            None if the curve fitting fails.
        curve_y (1D array): y coordinates of the interpolated points, or
            None if the curve fitting fails.
"""
if abs(x[:-1] - x[1:]).max() < abs(y[:-1] - y[1:]).max():
curve_y, curve_x = interp_points(y, x)
if curve_y is None:
return None, None
else:
with warnings.catch_warnings():
warnings.simplefilter("ignore")
try:
if len(x) < 3:
popt, _ = curve_fit(linear, x, y)
else:
popt, _ = curve_fit(func, x, y)
if abs(popt[0]) > 1:
return None, None
except Exception:
return None, None
if x[0] > x[-1]:
x = list(reversed(x))
y = list(reversed(y))
curve_x = np.linspace(x[0], x[-1], int(np.round(x[-1]-x[0])))
if len(x) < 3:
curve_y = linear(curve_x, *popt)
else:
curve_y = func(curve_x, *popt)
return curve_x.astype(int), curve_y.astype(int)
def func(x, a, b, c):
r"""Quadratic fitting function."""
return a * x**2 + b * x + c
def linear(x, a, b):
r"""Linear fitting function."""
return a * x + b
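
# Sketch of the interpolation helpers (runnable with numpy/scipy alone):
#
#   x = np.array([0., 10., 20.])
#   y = np.array([0., 5., 0.])     # three points on a shallow arc
#   cx, cy = interp_points(x, y)   # quadratic fit, ~20 integer samples
#   # With only two points, a straight line is fitted instead:
#   lx, ly = interp_points(np.array([0., 20.]), np.array([0., 10.]))
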