import cv2
import numpy as np
import torch

from spiga.data.loaders.augmentors.modern_posit import PositPose
from spiga.data.loaders.augmentors.heatmaps import Heatmaps
from spiga.data.loaders.augmentors.boundary import AddBoundary
from spiga.data.loaders.augmentors.landmarks import HorizontalFlipAug, RSTAug, OcclusionAug, \
                                              LightingAug, BlurAug, TargetCropAug


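def get_transformers(data_config):
    """Build the ordered list of sample transforms described by ``data_config``.

    Optional augmentations are selected by name from ``data_config.aug_names``.
    The target crop and the OpenCV (BGR) conversion are always added. Gaussian
    heatmaps, boundaries and the POSIT pose are optional and run after the crop.
    """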
    # Data augmentation
    aug_names = data_config.aug_names
    augmentors = []

    if 'flip' in aug_names:
        augmentors.append(HorizontalFlipAug(data_config.database.ldm_flip_order, data_config.hflip_prob))
    if 'rotate_scale' in aug_names:
        augmentors.append(RSTAug(data_config.angle_range, data_config.scale_min,
                                 data_config.scale_max, data_config.trl_ratio))
    if 'occlusion' in aug_names:
        augmentors.append(OcclusionAug(data_config.occluded_min_len,
                                       data_config.occluded_max_len,
                                       data_config.database.num_landmarks))
    if 'lighting' in aug_names:
        augmentors.append(LightingAug(data_config.hsv_range_min, data_config.hsv_range_max))
    if 'blur' in aug_names:
        augmentors.append(BlurAug(data_config.blur_prob, data_config.blur_kernel_range))

    # Target crop (always applied)
    augmentors.append(TargetCropAug(data_config.image_size, data_config.ftmap_size, data_config.target_dist))
    # Convert to an OpenCV-style BGR numpy array
    augmentors.append(ToOpencv())

    # Gaussian heatmaps
    if 'heatmaps2D' in aug_names:
        augmentors.append(Heatmaps(data_config.database.num_landmarks, data_config.ftmap_size,
                                   data_config.sigma2D, norm=data_config.heatmap2D_norm))

    if 'boundaries' in aug_names:
        augmentors.append(AddBoundary(num_landmarks=data_config.database.num_landmarks,
                                      map_size=data_config.ftmap_size,
                                      sigma=data_config.sigmaBD))
    # Pose generator
    if data_config.generate_pose:
        augmentors.append(PositPose(data_config.database.ldm_ids,
                                    focal_ratio=data_config.focal_ratio,
                                    selected_ids=data_config.posit_ids,
                                    max_iter=data_config.posit_max_iter))

    return augmentors
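

# Hypothetical helper (not part of SPIGA): a minimal sketch of how the list
# returned by get_transformers() could be applied. It assumes each augmentor is
# a callable that takes a sample dict and returns the (possibly modified) sample,
# which is how ToOpencv and AddModel3D below behave. Order matters: the optional
# augmentations run before the target crop, and heatmaps, boundaries and pose
# are computed after it.
def apply_augmentors(sample, augmentors):
    # Feed the output of each augmentor into the next one, in list order
    for augmentor in augmentors:
        sample = augmentor(sample)
    return sample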


class ToOpencv:
    def __call__(self, sample):
        # Convert to a numpy array and change the channel order from RGB to BGR
        image = np.array(sample['image'])
        image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
        sample['image'] = image
        return sample
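

# Hypothetical helper (not part of SPIGA): for a 3-channel image, the RGB -> BGR
# conversion that ToOpencv performs with cv2.cvtColor(..., cv2.COLOR_RGB2BGR)
# amounts to reversing the channel axis. This NumPy-only sketch makes that
# explicit; ToOpencv itself keeps using OpenCV.
def rgb_to_bgr(image):
    image = np.asarray(image)
    # Reverse the last (channel) axis and return a contiguous copy, because
    # many OpenCV functions reject non-contiguous (negatively strided) arrays
    return np.ascontiguousarray(image[..., ::-1])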


class TargetCrop(TargetCropAug):
    def __init__(self, crop_size=256, target_dist=1.6):
        super(TargetCrop, self).__init__(crop_size, crop_size, target_dist)


class AddModel3D(PositPose):
    def __init__(self, ldm_ids, ftmap_size=(256, 256), focal_ratio=1.5, totensor=False):
        super(AddModel3D, self).__init__(ldm_ids, focal_ratio=focal_ratio)
        img_bbox = [0, 0, ftmap_size[1], ftmap_size[0]]  # ftmap_size is (y, x), i.e. (height, width); bbox needs x first
        self.cam_matrix = self._camera_matrix(img_bbox)

        if totensor:
            self.cam_matrix = torch.tensor(self.cam_matrix, dtype=torch.float)
            self.model3d_world = torch.tensor(self.model3d_world, dtype=torch.float)

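    def __call__(self, sample=None):
        # Avoid a shared mutable default: create a fresh dict on each call
        if sample is None:
            sample = {}
        # Save intrinsic matrix and 3D model landmarks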
        sample['cam_matrix'] = self.cam_matrix
        sample['model3d'] = self.model3d_world
        return sample
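

# Hypothetical sketch (not SPIGA's PositPose._camera_matrix, whose exact formula
# is not shown in this file): a standard pinhole intrinsic matrix built from a
# bbox ordered like the img_bbox used in AddModel3D, i.e. [x, y, width, height].
# The focal lengths are assumed to be focal_ratio times the box width/height and
# the principal point is assumed to be the box centre.
def pinhole_camera_matrix(bbox, focal_ratio=1.5):
    x, y, w, h = bbox
    fx = focal_ratio * w          # assumed horizontal focal length (pixels)
    fy = focal_ratio * h          # assumed vertical focal length (pixels)
    cx = x + 0.5 * w              # principal point at the box centre
    cy = y + 0.5 * h
    return np.array([[fx, 0.0, cx],
                     [0.0, fy, cy],
                     [0.0, 0.0, 1.0]])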