File size: 1,620 Bytes
17cd746
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
import torch
import numpy as np

class NME:
    """Normalized mean error between predicted and ground-truth landmarks.

    For every sample the mean Euclidean landmark error is divided by a
    reference distance measured on the ground-truth landmarks between a
    "left" and a "right" reference point.  Each reference point is given
    either as a single landmark index (integer) or as a list of landmark
    indices whose mean position is used.
    """

    def __init__(self, nme_left_index, nme_right_index):
        """Store the landmark indices that define the reference distance.

        Args:
            nme_left_index: integer index, or list of indices, of the left
                reference landmark(s).
            nme_right_index: integer index, or list of indices, of the right
                reference landmark(s).
        """
        self.nme_left_index = nme_left_index
        self.nme_right_index = nme_right_index

    def __repr__(self):
        return "NME()"

    def get_norm_distance(self, landmarks):
        """Return the distance between the averaged left and right landmarks.

        Args:
            landmarks: NumPy array holding one sample, indexed as
                ``[landmark, coordinate]``.

        Returns:
            Euclidean distance between the mean of the rows selected by
            ``nme_right_index`` and the mean of the rows selected by
            ``nme_left_index``.

        Raises:
            AssertionError: if either stored index is not a list.
        """
        assert isinstance(self.nme_right_index, list), 'the nme_right_index is not list.'
        assert isinstance(self.nme_left_index, list), 'the nme_left_index is not list.'
        right_center = landmarks[self.nme_right_index, :].mean(0)
        left_center = landmarks[self.nme_left_index, :].mean(0)
        return np.linalg.norm(right_center - left_center)

    @staticmethod
    def _to_numpy(labels):
        """Convert a torch tensor (or any array-like) to a NumPy array."""
        if isinstance(labels, torch.Tensor):
            # detach() drops the autograd graph; cpu() is required before
            # numpy() for tensors living on another device.
            return labels.detach().cpu().numpy()
        return np.asarray(labels)

    def _reference_distance(self, landmarks_gt):
        """Return the normalisation distance for one ground-truth sample."""
        if isinstance(self.nme_right_index, list):
            return self.get_norm_distance(landmarks_gt)
        # np.integer covers NumPy scalar indices (e.g. np.int64), which are
        # not instances of the built-in int.
        if isinstance(self.nme_right_index, (int, np.integer)):
            return np.linalg.norm(
                landmarks_gt[self.nme_left_index] - landmarks_gt[self.nme_right_index]
            )
        raise NotImplementedError

    def test(self, label_pd, label_gt):
        """Compute the normalized mean error of every sample in a batch.

        Args:
            label_pd: predicted landmarks, a torch tensor or array-like
                indexed as ``[sample, landmark, coordinate]``.
            label_gt: ground-truth landmarks with the same layout.

        Returns:
            A list with one value per sample of ``label_gt``: the mean over
            landmarks of the prediction error divided by the reference
            distance.  A zero reference distance is not special-cased, so
            NumPy's division semantics (inf/nan) apply.

        Raises:
            NotImplementedError: if ``nme_right_index`` is neither a list nor
                an integer.
            AssertionError: if ``nme_right_index`` is a list but
                ``nme_left_index`` is not.
        """
        predictions = self._to_numpy(label_pd)
        targets = self._to_numpy(label_gt)

        nme_list = []
        # The ground truth drives the iteration; predictions are looked up by
        # the same sample position.
        for i, landmarks_gt in enumerate(targets):
            norm_distance = self._reference_distance(landmarks_gt)
            landmarks_delta = predictions[i] - landmarks_gt
            per_landmark_error = np.linalg.norm(landmarks_delta, axis=1)
            nme_list.append((per_landmark_error / norm_distance).mean())
        return nme_list