File size: 7,711 Bytes
d015578
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
import os
import pkg_resources
import numpy as np
import cv2

# My libs
from spiga.data.loaders.augmentors.utils import rotation_matrix_to_euler

# Model file nomenclature: path template for the packaged 3D mean-face model.
# '{num_ldm}' is filled with the landmark count of the target database
# (e.g. mean_face_3D_68.txt for a 68-landmark configuration).
# NOTE(review): pkg_resources is deprecated in modern setuptools — consider
# importlib.resources once the supported Python floor allows it.
model_file_dft = pkg_resources.resource_filename('spiga', 'data/models3D') + '/mean_face_3D_{num_ldm}.txt'


class PositPose:
    """Head-pose annotator based on the POSIT algorithm.

    Pairs the labelled 2D landmarks of a sample with a 3D mean-face model,
    recovers a rotation/translation, and writes the results back into the
    sample dictionary (``cam_matrix``, ``model3d``, ``pose``,
    ``model3d_proj``).
    """

    def __init__(self, ldm_ids, focal_ratio=1, selected_ids=None, max_iter=100,
                  fix_bbox=True, model_file=model_file_dft):

        # 3D mean-face model matching the database landmark ids
        model3d_world, model3d_ids = self._load_world_shape(ldm_ids, model_file)

        # Binary mask selecting the landmarks robust enough to feed POSIT
        if selected_ids is None:
            model3d_mask = np.ones(len(ldm_ids))
        else:
            model3d_mask = np.zeros(len(ldm_ids))
            for idx, model_id in enumerate(model3d_ids):
                if model_id in selected_ids:
                    model3d_mask[idx] = 1

        self.ldm_ids = ldm_ids                  # Ids from the database
        self.model3d_world = model3d_world      # Model data
        self.model3d_ids = model3d_ids          # Model ids
        self.model3d_mask = model3d_mask        # Model mask ids
        self.max_iter = max_iter                # Refinement iterations
        self.focal_ratio = focal_ratio          # Camera matrix focal length ratio
        self.fix_bbox = fix_bbox                # Camera matrix centered on image (False to centered on bbox)

    def __call__(self, sample):
        """Estimate the head pose of ``sample`` and annotate it in place."""
        landmarks = sample['landmarks']
        mask = sample['mask_ldm']

        # Intrinsics are sized from the image (or the feature map, when a
        # scale factor to it is provided)
        img_shape = np.array(sample['image'].shape)[0:2]
        if 'img2map_scale' in sample:
            img_shape = img_shape * sample['img2map_scale']

        if self.fix_bbox:
            # Shapes given are inverted (y, x)
            cam_matrix = self._camera_matrix([0, 0, img_shape[1], img_shape[0]])
        else:
            # Scale error when ftshape and img_shape mismatch
            cam_matrix = self._camera_matrix(sample['bbox'])

        # Save intrinsic matrix and 3D model landmarks
        sample['cam_matrix'] = cam_matrix
        sample['model3d'] = self.model3d_world

        world_pts, image_pts = self._set_correspondences(landmarks, mask)

        if image_pts.shape[0] < 4:
            # POSIT needs at least four correspondences; fall back to identity
            print('POSIT does not work without landmarks')
            rot_matrix = np.eye(3, dtype=float)
            trl_matrix = np.array([0, 0, 0])
        else:
            rot_matrix, trl_matrix = self._modern_posit(world_pts, image_pts, cam_matrix)

        euler = rotation_matrix_to_euler(rot_matrix)
        sample['pose'] = np.hstack((euler[:3], trl_matrix[:3]))
        sample['model3d_proj'] = self._project_points(rot_matrix, trl_matrix, cam_matrix, norm=img_shape)
        return sample

    def _load_world_shape(self, ldm_ids, model_file):
        """Delegate to the module-level model loader."""
        return load_world_shape(ldm_ids, model_file=model_file)

    def _camera_matrix(self, bbox):
        """Build a pinhole intrinsic matrix centered on ``bbox`` (x, y, w, h)."""
        fx = bbox[2] * self.focal_ratio
        fy = bbox[3] * self.focal_ratio
        cx = bbox[0] + (bbox[2] * 0.5)
        cy = bbox[1] + (bbox[3] * 0.5)
        return np.array([[fx, 0, cx],
                         [0, fy, cy],
                         [0, 0, 1]])

    def _set_correspondences(self, landmarks, mask):
        """Keep only landmarks that are both labelled and model-robust."""
        valid = np.logical_and(mask, self.model3d_mask).astype(bool)
        return self.model3d_world[valid], landmarks[valid]

    def _modern_posit(self, world_pts, image_pts, cam_matrix):
        """Delegate to the module-level POSIT solver."""
        return modern_posit(world_pts, image_pts, cam_matrix, self.max_iter)

    def _project_points(self, rot, trl, cam_matrix, norm=None):
        """Project the 3D model with pose (rot, trl); optionally divide by ``norm``."""
        # Projection matrix P = K [R | t]
        extrinsics = np.concatenate((rot, np.reshape(trl, (3, 1))), axis=1)
        proj_matrix = np.matmul(cam_matrix, extrinsics)

        # Homogeneous model landmarks
        pts = self.model3d_world
        pts_hom = np.concatenate((pts, np.ones((pts.shape[0], 1))), axis=1)

        # Perspective projection, normalized by the homogeneous coordinate (lambda = 1)
        pts_proj = np.matmul(proj_matrix, pts_hom.T).T
        pts_proj = pts_proj / pts_proj[:, 2:3]

        if norm is not None:
            pts_proj[:, 0] /= norm[0]
            pts_proj[:, 1] /= norm[1]
        return pts_proj[:, :-1]


def load_world_shape(db_landmarks, model_file=None):
    """Load the 3D mean-face model matching a database landmark layout.

    Args:
        db_landmarks: List of database landmark ids. Its length selects the
            model file and its order defines the output row order.
        model_file: Path template containing a ``{num_ldm}`` placeholder.
            Defaults to the packaged ``model_file_dft`` template (resolved
            lazily so the default is picked up at call time).

    Returns:
        Tuple ``(world, ids)``: ``world`` is an (N, 3) array of 3D points in
        camera-style coordinates and ``ids`` the landmark id stored per row.

    Raises:
        ValueError: If no model file exists for this landmark count.
    """
    if model_file is None:
        model_file = model_file_dft

    # Resolve the model file from the number of landmarks
    num_ldm = len(db_landmarks)
    filename = model_file.format(num_ldm=num_ldm)
    if not os.path.exists(filename):
        raise ValueError('No 3D model found for %i landmarks' % num_ldm)

    # Column 0 holds the landmark id, columns 1-3 its 3D coordinates
    posit_landmarks = np.genfromtxt(filename, delimiter='|', dtype=int, usecols=0).tolist()
    mean_face_3D = np.genfromtxt(filename, delimiter='|', dtype=(float, float, float), usecols=(1, 2, 3)).tolist()
    world_all = len(mean_face_3D)*[None]
    index_all = len(mean_face_3D)*[None]

    for cont, elem in enumerate(mean_face_3D):
        # Axis remap from the file convention to (z, -x, -y)
        pt3d = [elem[2], -elem[0], -elem[1]]
        # Place the point at the row matching the database landmark order
        lnd_idx = db_landmarks.index(posit_landmarks[cont])
        world_all[lnd_idx] = pt3d
        index_all[lnd_idx] = posit_landmarks[cont]

    return np.array(world_all), np.array(index_all)


def modern_posit(world_pts, image_pts, cam_matrix, max_iters):
    """Estimate pose from 2D-3D correspondences with the POSIT algorithm.

    Starting from a scaled-orthographic approximation, iteratively refines
    the perspective correction term until the reprojected update becomes
    negligible or ``max_iters`` is reached.

    Args:
        world_pts: (N, 3) model points in object coordinates.
        image_pts: (N, 2) matching image points in pixels.
        cam_matrix: 3x3 intrinsic camera matrix.
        max_iters: Maximum number of refinement iterations.

    Returns:
        Tuple ``(rot_matrix, trl_matrix)``: 3x3 rotation matrix and
        length-3 translation [Tx, Ty, Tz].
    """
    # Homogeneous world points and their pseudo-inverse (the object matrix)
    num_landmarks = image_pts.shape[0]
    ones = np.ones((num_landmarks, 1))
    A = np.concatenate((world_pts, ones), axis=1)
    B = np.linalg.pinv(A)

    # Normalize image points to the unit focal-length camera
    focal_length = cam_matrix[0, 0]
    img_center = (cam_matrix[0, 2], cam_matrix[1, 2])
    centered_pts = np.zeros((num_landmarks, 2))
    centered_pts[:, 0] = (image_pts[:, 0] - img_center[0]) / focal_length
    centered_pts[:, 1] = (image_pts[:, 1] - img_center[1]) / focal_length
    Ui = centered_pts[:, 0]
    Vi = centered_pts[:, 1]

    # POSIT loop
    Tx, Ty, Tz = 0.0, 0.0, 0.0
    r1, r2, r3 = np.zeros(3), np.zeros(3), np.zeros(3)
    for iteration in range(max_iters):  # renamed from 'iter' (shadowed builtin)
        I = np.dot(B, Ui)
        J = np.dot(B, Vi)

        # Scale from the geometric average of the row norms
        # (classic POSIT uses the arithmetic average)
        normI = 1.0 / np.linalg.norm(I[0:3])
        normJ = 1.0 / np.linalg.norm(J[0:3])
        Tz = np.sqrt(normI * normJ)
        r1N = I * Tz
        r2N = J * Tz
        # Clip so noisy data cannot push the direction cosines out of range
        r1 = np.clip(r1N[0:3], -1, 1)
        r2 = np.clip(r2N[0:3], -1, 1)
        r3 = np.cross(r1, r2)
        r3T = np.concatenate((r3, [Tz]), axis=0)
        Tx = r1N[3]
        Ty = r2N[3]

        # Compute epsilon, update Ui and Vi and check convergence
        eps = np.dot(A, r3T) / Tz
        oldUi = Ui
        oldVi = Vi
        Ui = eps * centered_pts[:, 0]
        Vi = eps * centered_pts[:, 1]
        deltaUi = Ui - oldUi
        deltaVi = Vi - oldVi
        # Pixel-space magnitude of the update; converged when negligible
        delta = focal_length * focal_length * (np.dot(deltaUi, deltaUi) + np.dot(deltaVi, deltaVi))
        if iteration > 0 and delta < 0.01:  # converged
            break

    rot_matrix = np.array([r1, r2, r3])
    trl_matrix = np.array([Tx, Ty, Tz])
    # Snap to the nearest orthogonal rotation matrix: for R = U*D*Vt the
    # orthogonal polar factor is U*Vt. (np.linalg.svd replaces the previous
    # cv2.SVDecomp call; the identity factor in the old product was a no-op.)
    u, _, vt = np.linalg.svd(rot_matrix)
    rot_matrix = np.matmul(u, vt)
    return rot_matrix, trl_matrix