import os
import json
from collections import OrderedDict

import numpy as np
try:
    from scipy.integrate import simpson  # SciPy >= 1.6
except ImportError:
    from scipy.integrate import simps as simpson  # 'simps' renamed in newer SciPy

from spiga.data.loaders.dl_config import db_anns_path
from spiga.eval.benchmark.metrics.metrics import Metrics


class MetricsLandmarks(Metrics):
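    """Landmark localization metrics (NME, AUC, FR).

    Per-landmark errors are normalized point-to-point distances expressed
    as a percentage:

        ne = 100 * ||pred - gt||_2 / d_norm

    where d_norm is the per-image normalization distance selected by
    ``nme_norm`` (see _get_img_norm). Per-image NMEs average the evaluated
    landmarks, and the cumulative error distribution (CED) of those NMEs
    yields the AUC and failure rate up to ``nme_thr``.
    """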

    def __init__(self, name='landmarks'):
        super().__init__(name)

        self.db_info = None
        self.nme_norm = "corners"
        self.nme_thr = 8
        self.percentile = [90, 95, 99]
        # Number of histogram bins used for the cumulative error (CED) curve
        self.bins = 10000

    def compute_error(self, data_anns, data_pred, database, select_ids=None):
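        """Compute normalized errors per landmark and per image.

        Args:
            data_anns: list of per-image annotation dicts with 'ids',
                'landmarks' and 'bbox' entries.
            data_pred: list of per-image prediction dicts with 'ids' and
                'landmarks', aligned index-wise with data_anns.
            database: database name used to locate its db_info.json.
            select_ids: optional subset of landmark ids to evaluate.
        """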

        # Initialize global logs and variables of the compute_error function
        self.init_ce(data_anns, data_pred, database)
        self._update_lnd_param()

        # Order data and compute nme
        self.error['nme_per_img'] = []
        self.error['ne_per_img'] = OrderedDict()
        self.error['ne_per_ldm'] = OrderedDict()
        for img_id, anns in enumerate(data_anns):
            # Init variables per img
            pred = data_pred[img_id]

            # Select the landmark ids to evaluate
            if select_ids is None:
                selected_ldm = anns['ids']
            else:
                selected_ldm = list(set(select_ids) & set(anns['ids']))

            norm = self._get_img_norm(anns)
            for ldm_id in selected_ldm:
                # Compute the normalized error (percentage of the norm distance)
                anns_ldm = self._get_lnd_from_id(anns, ldm_id)
                pred_ldm = self._get_lnd_from_id(pred, ldm_id)
                ne = self._dist_l2(anns_ldm, pred_ldm)/norm * 100
                self.error['ne_per_img'].setdefault(img_id, []).append(ne)
                self.error['ne_per_ldm'].setdefault(ldm_id, []).append(ne)

            # NME per image
            if self.database in ['merlrav']:
                # LUVLi on MERL-RAV divides by 68 regardless of how many landmarks are annotated in the image.
                self.error['nme_per_img'].append(np.sum(self.error['ne_per_img'][img_id])/68)
            else:
                self.error['nme_per_img'].append(np.mean(self.error['ne_per_img'][img_id]))

        # Cumulative NME
        self.error['cumulative_nme'] = self._cumulative_error(self.error['nme_per_img'], bins=self.bins)

        return self.error

    def metrics(self):

        # Initialize global logs and variables of the metrics function
        self.init_metrics()

        # Basic metrics (NME/NMPE/AUC/FR) for full dataset
        nme, nmpe, auc, fr, _, _ = self._basic_metrics()

        print('NME: %.3f' % nme)
        self.metrics_log['nme'] = nme
        for percent_id, percentile in enumerate(self.percentile):
            print('NME_P%i: %.3f' % (percentile, nmpe[percent_id]))
            self.metrics_log['nme_p%i' % percentile] = nmpe[percent_id]
        self.metrics_log['nme_thr'] = self.nme_thr
        self.metrics_log['nme_norm'] = self.nme_norm
        print('AUC_%i: %.3f' % (self.nme_thr, auc))
        self.metrics_log['auc'] = auc
        print('FR_%i: %.3f' % (self.nme_thr, fr))
        self.metrics_log['fr'] = fr

        # Subset basic metrics
        subsets = self.db_info['test_subsets']
        if self.data_type == 'test' and len(subsets) > 0:
            self.metrics_log['subset'] = OrderedDict()
            for subset, img_filter in subsets.items():
                self.metrics_log['subset'][subset] = OrderedDict()
                nme, nmpe, auc, fr, _, _ = self._basic_metrics(img_select=img_filter)
                print('> Landmarks subset: %s' % subset.upper())
                print('NME: %.3f' % nme)
                self.metrics_log['subset'][subset]['nme'] = nme
                for percent_id, percentile in enumerate(self.percentile):
                    print('NME_P%i: %.3f' % (percentile, nmpe[percent_id]))
                    self.metrics_log['subset'][subset]['nme_p%i' % percentile] = nmpe[percent_id]
                print('AUC_%i: %.3f' % (self.nme_thr, auc))
                self.metrics_log['subset'][subset]['auc'] = auc
                print('FR_%i: %.3f' % (self.nme_thr, fr))
                self.metrics_log['subset'][subset]['fr'] = fr

        # NME/NPE per landmark
        self.metrics_log['nme_per_ldm'] = OrderedDict()
        for percentile in self.percentile:
            self.metrics_log['npe%i_per_ldm' % percentile] = OrderedDict()
        for k, v in self.error['ne_per_ldm'].items():
            self.metrics_log['nme_per_ldm'][k] = np.mean(v)
            for percentile in self.percentile:
                self.metrics_log['npe%i_per_ldm' % percentile][k] = np.percentile(v, percentile)

        return self.metrics_log

    def get_pimg_err(self, data_dict=None, img_select=None):
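        """Return the per-image NME list, optionally filtered by img_select.
        When data_dict is given, the list is stored under '<name>/nme'."""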
        data = self.error['nme_per_img']
        if img_select is not None:
            data = [data[img_id] for img_id in img_select]
        name_dict = self.name + '/nme'
        if data_dict is not None:
            data_dict[name_dict] = data
        else:
            data_dict = data
        return data_dict

    def _update_lnd_param(self):
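        """Load the database db_info.json and interactively allow the user
        to override the default NME normalization and threshold."""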
        db_info_file = db_anns_path.format(database=self.database, file_name='db_info')
        if os.path.exists(db_info_file):
            with open(db_info_file) as jsonfile:
                self.db_info = json.load(jsonfile)

            norm_dict = self.db_info['norm']
            nme_norm, nme_thr = next(iter(norm_dict.items()))
            print('Default landmarks configuration: \n %s: %i' % (nme_norm, nme_thr))
            answer = input("Change default config? (N/Y) >>> ")
            if answer.lower() in ['yes', 'y']:
                answer = input("Normalization options: %s >>> " % str(list(norm_dict.keys())))
                if answer in norm_dict.keys():
                    nme_norm = answer
                    nme_thr = norm_dict[nme_norm]
                else:
                    print("Option %s not available keep in default one: %s" % (answer, nme_norm))
                answer = input("Change threshold ->%s:%i ? (N/Y) >>> " % (nme_norm, nme_thr))
                if answer.lower() in ['yes', 'y']:
                    answer = input('NME threshold: >>> ')
                    nme_thr = float(answer)
                else:
                    print("Keeping default threshold: %i" % nme_thr)

            self.nme_norm = nme_norm
            self.nme_thr = nme_thr

        else:
            raise ValueError('Database %s specifics not defined. Missing db_info.json' % self.database)

    def _dist_l2(self, pointA, pointB):
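        """Euclidean (L2) distance between two points."""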
        return float(((pointA - pointB) ** 2).sum() ** 0.5)

    def _get_lnd_from_id(self, anns, ids):
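        """Return the landmark coordinates annotated with the given id."""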
        idx = anns['ids'].index(ids)
        ref = np.array(anns['landmarks'][idx])
        return ref

    def _get_img_norm(self, anns):
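        """Return the normalization distance d_norm for one image.

        Supported options: 'pupils' (inter-pupil distance, 68-landmark
        layout only), 'corners' (outer eye-corner distance), 'diagonal'
        (bbox diagonal), 'height' (bbox height), 'lnd_bbox' (sqrt of the
        landmark bounding-box area) and 'bbox' (sqrt of the bbox area).
        """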
        if self.nme_norm == 'pupils':
            print('WARNING: Pupils norm is only implemented for the 68-landmark configuration')
            left_eye = [7, 138, 139, 8, 141, 142]
            right_eye = [11, 144, 145, 12, 147, 148]
            refA = np.zeros(2)
            refB = np.zeros(2)
            for i in range(len(left_eye)):
                refA += self._get_lnd_from_id(anns, left_eye[i])
                refB += self._get_lnd_from_id(anns, right_eye[i])
            refA = refA/len(left_eye)   # Left
            refB = refB/len(right_eye)  # Right
        elif self.nme_norm == 'corners':
            refA = self._get_lnd_from_id(anns, 12)  # Left
            refB = self._get_lnd_from_id(anns, 7)  # Right
        elif self.nme_norm == 'diagonal':
            refA = np.array(anns['bbox'][0:2])
            refB = refA + np.array(anns['bbox'][2:4])
        elif self.nme_norm == 'height':
            return anns['bbox'][3]
        elif self.nme_norm == 'lnd_bbox':
            lnd = np.array(anns['landmarks'])
            lnd_max = np.max(lnd, axis=0)
            lnd_min = np.min(lnd, axis=0)
            lnd_wh = lnd_max - lnd_min
            return (lnd_wh[0]*lnd_wh[1])**0.5
        elif self.nme_norm == 'bbox':
            return (anns['bbox'][2] * anns['bbox'][3]) ** 0.5
        else:
            raise ValueError('Normalization %s not implemented' % self.nme_norm)

        return self._dist_l2(refA, refB)

    def _cumulative_error(self, error, bins=10000):
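        """Histogram the per-image NMEs and return the cumulative error
        distribution (CED) together with its bin edges, clipped below
        the NME threshold."""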
        num_imgs, base = np.histogram(error, bins=bins)
        cumulative = [x / float(len(error)) for x in np.cumsum(num_imgs)]
        base = base[:bins]
        cumulative, base = self._filter_cumulative(cumulative, base)
        return [cumulative, base]

    def _filter_cumulative(self, cumulative, base):
        base = [x for x in base if (x < self.nme_thr)]
        cumulative = cumulative[:len(base)]
        return cumulative, base

    def _basic_metrics(self, img_select=None):
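        """Compute NME, NME percentiles, AUC and FR, optionally restricted
        to the image subset given by img_select."""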
        data = self.error['nme_per_img']
        if img_select is not None:
            data = [data[img_id] for img_id in img_select]
            [cumulative, base] = self._cumulative_error(data, bins=self.bins)
        else:
            [cumulative, base] = self.error['cumulative_nme']

        # Normalized mean error (NME) across images
        nme = np.mean(data)
        # Normalized mean percentile error (NMPE) across images
        nmpe = []
        for percentile in self.percentile:
            nmpe.append(np.percentile(data, percentile))

        # Area Under Curve and Failure Rate
        auc, fr = self._auc_fr_metrics(cumulative, base)

        return nme, nmpe, auc, fr, cumulative, base

    def _auc_fr_metrics(self, cumulative, base):
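        """AUC of the CED curve (Simpson integration normalized by the
        threshold) and failure rate (share of images above the threshold).
        If the CED saturates at 1 before reaching the threshold, the
        remaining rectangle is added to the AUC."""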
        if not base:
            auc = 0.
            fr = 100.
        else:
            auc = (simpson(cumulative, x=base) / self.nme_thr) * 100.0
            if base[-1] < self.nme_thr and cumulative[-1] == 1:
                auc += ((self.nme_thr - base[-1]) / self.nme_thr) * 100
            fr = (1 - cumulative[-1]) * 100.0
        return auc, fr
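
# ---------------------------------------------------------------------------
# Usage sketch (illustrative only). The toy annotations below are
# hypothetical and kept minimal; compute_error() additionally loads the
# database's db_info.json and prompts on stdin for the normalization config,
# so the snippet is left commented out:
#
#   metrics = MetricsLandmarks()
#   anns = [{'ids': [7, 12],
#            'landmarks': [[30., 40.], [70., 40.]],
#            'bbox': [0., 0., 100., 100.]}]
#   preds = [{'ids': [7, 12],
#             'landmarks': [[31., 40.], [70., 42.]]}]
#   metrics.compute_error(anns, preds, database='wflw')
#   results = metrics.metrics()  # {'nme': ..., 'auc': ..., 'fr': ...}
# ---------------------------------------------------------------------------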