import cv2
import numpy as np
import torch

from spiga.data.loaders.augmentors.modern_posit import PositPose
from spiga.data.loaders.augmentors.heatmaps import Heatmaps
from spiga.data.loaders.augmentors.boundary import AddBoundary
from spiga.data.loaders.augmentors.landmarks import HorizontalFlipAug, RSTAug, OcclusionAug, \
    LightingAug, BlurAug, TargetCropAug


def get_transformers(data_config):
    """Build the ordered list of sample transformers from a data config.

    Order matters: the named geometric/photometric augmentations run first,
    then the mandatory target crop and the conversion to OpenCV image layout,
    followed by optional 2D heatmap / boundary generation and, finally, the
    optional head-pose annotator.

    Args:
        data_config: configuration object exposing ``aug_names``, the
            per-augmentation parameters read below, and a ``database``
            attribute with dataset-specific fields (landmark ids, flip order,
            number of landmarks).

    Returns:
        list: transformer callables to be applied sequentially to a sample.
    """
    # Data augmentation (each entry is enabled by name in aug_names)
    aug_names = data_config.aug_names
    augmentors = []
    if 'flip' in aug_names:
        augmentors.append(HorizontalFlipAug(data_config.database.ldm_flip_order,
                                            data_config.hflip_prob))
    if 'rotate_scale' in aug_names:
        augmentors.append(RSTAug(data_config.angle_range, data_config.scale_min,
                                 data_config.scale_max, data_config.trl_ratio))
    if 'occlusion' in aug_names:
        augmentors.append(OcclusionAug(data_config.occluded_min_len,
                                       data_config.occluded_max_len,
                                       data_config.database.num_landmarks))
    if 'lighting' in aug_names:
        augmentors.append(LightingAug(data_config.hsv_range_min,
                                      data_config.hsv_range_max))
    if 'blur' in aug_names:
        augmentors.append(BlurAug(data_config.blur_prob,
                                  data_config.blur_kernel_range))

    # Crop mandatory (always applied regardless of aug_names)
    augmentors.append(TargetCropAug(data_config.image_size, data_config.ftmap_size,
                                    data_config.target_dist))
    # Opencv style
    augmentors.append(ToOpencv())

    # Gaussian heatmaps
    if 'heatmaps2D' in aug_names:
        augmentors.append(Heatmaps(data_config.database.num_landmarks,
                                   data_config.ftmap_size, data_config.sigma2D,
                                   norm=data_config.heatmap2D_norm))
    if 'boundaries' in aug_names:
        augmentors.append(AddBoundary(num_landmarks=data_config.database.num_landmarks,
                                      map_size=data_config.ftmap_size,
                                      sigma=data_config.sigmaBD))
    # Pose generator
    if data_config.generate_pose:
        augmentors.append(PositPose(data_config.database.ldm_ids,
                                    focal_ratio=data_config.focal_ratio,
                                    selected_ids=data_config.posit_ids,
                                    max_iter=data_config.posit_max_iter))
    return augmentors


class ToOpencv:
    """Convert a sample's image to an OpenCV-style BGR numpy array."""

    def __call__(self, sample):
        # Convert to a numpy array and change channel order from RGB to BGR
        image = np.array(sample['image'])
        image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
        sample['image'] = image
        return sample


class TargetCrop(TargetCropAug):
    """Convenience square crop: TargetCropAug with equal width and height."""

    def __init__(self, crop_size=256, target_dist=1.6):
        super(TargetCrop, self).__init__(crop_size, crop_size, target_dist)


class AddModel3D(PositPose):
    """Attach the camera intrinsic matrix and the 3D landmark model to samples.

    Args:
        ldm_ids: landmark identifiers passed through to PositPose.
        ftmap_size: feature-map size given as (height, width).
        focal_ratio: focal-length ratio used to build the camera matrix.
        totensor: if True, store ``cam_matrix`` and ``model3d_world`` as
            float torch tensors instead of their native (numpy) form.
    """

    def __init__(self, ldm_ids, ftmap_size=(256, 256), focal_ratio=1.5, totensor=False):
        super(AddModel3D, self).__init__(ldm_ids, focal_ratio=focal_ratio)
        img_bbox = [0, 0, ftmap_size[1], ftmap_size[0]]  # Shapes given are inverted (y,x)
        self.cam_matrix = self._camera_matrix(img_bbox)

        if totensor:
            self.cam_matrix = torch.tensor(self.cam_matrix, dtype=torch.float)
            self.model3d_world = torch.tensor(self.model3d_world, dtype=torch.float)

    def __call__(self, sample=None):
        # BUGFIX: the original signature used a mutable default (sample={}),
        # so the same dict object was shared and mutated across calls.
        # The None sentinel keeps the call signature backward-compatible.
        if sample is None:
            sample = {}
        # Save intrinsic matrix and 3D model landmarks
        sample['cam_matrix'] = self.cam_matrix
        sample['model3d'] = self.model3d_world
        return sample