import cv2
import random
import numpy as np
import spiga.data.loaders.dl_config as dl_cfg
import spiga.data.loaders.dataloader as dl
import spiga.data.visualize.plotting as plot


def inspect_parser():
    import argparse
    pars = argparse.ArgumentParser(description='Data augmentation and dataset visualization. '
                                               'Press Q to quit, '
                                               'N to visualize the next image '
                                               'and any other key to visualize the next default data.')
    pars.add_argument('database', type=str,
                      choices=['wflw', '300wpublic', '300wprivate', 'cofw68', 'merlrav'], help='Database name')
    pars.add_argument('-a', '--anns', type=str, default='train', help='Annotation type: test, train or valid')
    pars.add_argument('-np', '--nopose', action='store_false', default=True, help='Avoid pose generation')
    pars.add_argument('-c', '--clean', action='store_true', help='Process without data augmentation for train')
    pars.add_argument('--shape', nargs='+', type=int, default=[256, 256], help='Image cropped shape (W, H)')
    pars.add_argument('--img', nargs='+', type=int, default=None, help='Select specific image ids')
    return pars.parse_args()
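
# Example command-line usage (a sketch; the script filename below is an assumption,
# adjust it to the actual path of this file in the repository, and the chosen
# database annotations must be available locally):
#   python inspect_dataset.py wflw -a train
#   python inspect_dataset.py 300wpublic --clean --shape 256 256 --img 0 1 2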


class DatasetInspector:

    def __init__(self, database, anns_type, data_aug=True, pose=True, image_shape=(256, 256)):
        data_config = dl_cfg.AlignConfig(database, anns_type)
        data_config.image_size = image_shape
        data_config.ftmap_size = image_shape
        data_config.generate_pose = pose
        if not data_aug:
            data_config.aug_names = []

        self.data_config = data_config
        dataloader, dataset = dl.get_dataloader(1, data_config, debug=True)
        self.dataset = dataset
        self.dataloader = dataloader
        self.colors_dft = {'lnd': (plot.GREEN, plot.RED), 'pose': (plot.BLUE, plot.GREEN, plot.RED)}
    def show_dataset(self, ids_list=None):
        if ids_list is None:
            ids = self.get_idx(shuffle=self.data_config.shuffle)
        else:
            ids = ids_list

        for img_id in ids:
            data_dict = self.dataset[img_id]
            crop_imgs, full_img = self.plot_features(data_dict)

            # Plot crop
            if 'merge' in crop_imgs.keys():
                crop = crop_imgs['merge']
            else:
                crop = crop_imgs['lnd']
            cv2.imshow('crop', crop)

            # Plot full
            cv2.imshow('image', full_img['lnd'])
            key = cv2.waitKey()
            if key == ord('q'):
                break
    def plot_features(self, data_dict, colors=None):
        # Init variables
        crop_imgs = {}
        full_imgs = {}
        if colors is None:
            colors = self.colors_dft

        # Cropped image
        image = data_dict['image']
        landmarks = data_dict['landmarks']
        visible = data_dict['visible']
        if np.any(np.isnan(visible)):
            visible = None
        mask = data_dict['mask_ldm']

        # Full image
        if 'image_ori' in data_dict.keys():
            image_ori = data_dict['image_ori']
        else:
            image_ori = cv2.imread(data_dict['imgpath'])
        landmarks_ori = data_dict['landmarks_ori']
        visible_ori = data_dict['visible_ori']
        if np.any(np.isnan(visible_ori)):
            visible_ori = None
        mask_ori = data_dict['mask_ldm_ori']

        # Plot landmarks
        crop_imgs['lnd'] = self._plot_lnd(image, landmarks, visible, mask, colors=colors['lnd'])
        full_imgs['lnd'] = self._plot_lnd(image_ori, landmarks_ori, visible_ori, mask_ori, colors=colors['lnd'])

        if self.data_config.generate_pose:
            rot, trl, cam_matrix = self._extract_pose(data_dict)
            # Plot pose
            crop_imgs['pose'] = plot.draw_pose(image, rot, trl, cam_matrix, euler=True, colors=colors['pose'])
            # Plot merge features
            crop_imgs['merge'] = plot.draw_pose(crop_imgs['lnd'], rot, trl, cam_matrix,
                                                euler=True, colors=colors['pose'])

        return crop_imgs, full_imgs
    def get_idx(self, shuffle=False):
        ids = list(range(len(self.dataset)))
        if shuffle:
            random.shuffle(ids)
        return ids
    def reload_dataset(self, data_config=None):
        if data_config is None:
            data_config = self.data_config
        dataloader, dataset = dl.get_dataloader(1, data_config, debug=True)
        self.dataset = dataset
        self.dataloader = dataloader
    def _extract_pose(self, data_dict):
        # Rotation and translation matrix
        pose = data_dict['pose']
        rot = pose[:3]
        trl = pose[3:]
        # Camera matrix
        cam_matrix = data_dict['cam_matrix']

        # Check for ground-truth annotations (only used when no augmentation is applied)
        if 'headpose_ori' in data_dict.keys():
            if len(self.data_config.aug_names) == 0:
                print('Image headpose generated by ground truth data')
                pose_ori = data_dict['headpose_ori']
                rot = pose_ori

        return rot, trl, cam_matrix
    def _plot_lnd(self, image, landmarks, visible, mask, max_shape_thr=720, colors=None):
        if colors is None:
            colors = self.colors_dft['lnd']

        # Full image plots
        H, W, C = image.shape  # OpenCV images are (height, width, channels)
        # Resize the original image if it exceeds the display threshold
        if H > max_shape_thr or W > max_shape_thr:
            max_shape = max(H, W)
            scale_factor = max_shape_thr / max_shape
            resize_shape = (int(W * scale_factor), int(H * scale_factor))  # cv2.resize expects (width, height)
            image_out = plot.draw_landmarks(image, landmarks, visible=visible, mask=mask,
                                            thick_scale=1 / scale_factor, colors=colors)
            image_out = cv2.resize(image_out, resize_shape)
        else:
            image_out = plot.draw_landmarks(image, landmarks, visible=visible, mask=mask, colors=colors)
        return image_out
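
# Programmatic usage sketch (an illustration, assuming the chosen database and its
# annotations are reachable by the SPIGA dataloader on the local machine):
#   inspector = DatasetInspector('wflw', 'train', data_aug=False, pose=True)
#   inspector.show_dataset(ids_list=[0, 1, 2])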


if __name__ == '__main__':
    args = inspect_parser()
    data_aug = True
    database = args.database
    anns_type = args.anns
    pose = args.nopose
    select_img = args.img
    if args.clean:
        data_aug = False
    if len(args.shape) != 2:
        raise ValueError('--shape requires two values: width and height. E.g.: --shape 256 256')
    else:
        img_shape = tuple(args.shape)

    visualizer = DatasetInspector(database, anns_type, data_aug=data_aug, pose=pose, image_shape=img_shape)
    visualizer.show_dataset(ids_list=select_img)