Spaces:
Running
Running
import os | |
import numpy as np | |
import json | |
from collections import OrderedDict | |
from scipy.integrate import simps | |
from spiga.data.loaders.dl_config import db_anns_path | |
from spiga.eval.benchmark.metrics.metrics import Metrics | |
class MetricsLandmarks(Metrics): | |
def __init__(self, name='landmarks'): | |
super().__init__(name) | |
self.db_info = None | |
self.nme_norm = "corners" | |
self.nme_thr = 8 | |
self.percentile = [90, 95, 99] | |
# Cumulative plot axis length | |
self.bins = 10000 | |
def compute_error(self, data_anns, data_pred, database, select_ids=None): | |
# Initialize global logs and variables of Computer Error function | |
self.init_ce(data_anns, data_pred, database) | |
self._update_lnd_param() | |
# Order data and compute nme | |
self.error['nme_per_img'] = [] | |
self.error['ne_per_img'] = OrderedDict() | |
self.error['ne_per_ldm'] = OrderedDict() | |
for img_id, anns in enumerate(data_anns): | |
# Init variables per img | |
pred = data_pred[img_id] | |
# Get select ids to compute | |
if select_ids is None: | |
selected_ldm = anns['ids'] | |
else: | |
selected_ldm = list(set(select_ids) & set(anns['ids'])) | |
norm = self._get_img_norm(anns) | |
for ldm_id in selected_ldm: | |
# Compute Normalize Error | |
anns_ldm = self._get_lnd_from_id(anns, ldm_id) | |
pred_ldm = self._get_lnd_from_id(pred, ldm_id) | |
ne = self._dist_l2(anns_ldm, pred_ldm)/norm * 100 | |
self.error['ne_per_img'].setdefault(img_id, []).append(ne) | |
self.error['ne_per_ldm'].setdefault(ldm_id, []).append(ne) | |
# NME per image | |
if self.database in ['merlrav']: | |
# LUVLI at MERLRAV divide by 68 despite the annotated landmarks in the image. | |
self.error['nme_per_img'].append(np.sum(self.error['ne_per_img'][img_id])/68) | |
else: | |
self.error['nme_per_img'].append(np.mean(self.error['ne_per_img'][img_id])) | |
# Cumulative NME | |
self.error['cumulative_nme'] = self._cumulative_error(self.error['nme_per_img'], bins=self.bins) | |
return self.error | |
def metrics(self): | |
# Initialize global logs and variables of Metrics function | |
self.init_metrics() | |
# Basic metrics (NME/NMPE/AUC/FR) for full dataset | |
nme, nmpe, auc, fr, _, _ = self._basic_metrics() | |
print('NME: %.3f' % nme) | |
self.metrics_log['nme'] = nme | |
for percent_id, percentile in enumerate(self.percentile): | |
print('NME_P%i: %.3f' % (percentile, nmpe[percent_id])) | |
self.metrics_log['nme_p%i' % percentile] = nmpe[percent_id] | |
self.metrics_log['nme_thr'] = self.nme_thr | |
self.metrics_log['nme_norm'] = self.nme_norm | |
print('AUC_%i: %.3f' % (self.nme_thr, auc)) | |
self.metrics_log['auc'] = auc | |
print('FR_%i: %.3f' % (self.nme_thr, fr)) | |
self.metrics_log['fr'] = fr | |
# Subset basic metrics | |
subsets = self.db_info['test_subsets'] | |
if self.data_type == 'test' and len(subsets) > 0: | |
self.metrics_log['subset'] = OrderedDict() | |
for subset, img_filter in subsets.items(): | |
self.metrics_log['subset'][subset] = OrderedDict() | |
nme, nmpe, auc, fr, _, _ = self._basic_metrics(img_select=img_filter) | |
print('> Landmarks subset: %s' % subset.upper()) | |
print('NME: %.3f' % nme) | |
self.metrics_log['subset'][subset]['nme'] = nme | |
for percent_id, percentile in enumerate(self.percentile): | |
print('NME_P%i: %.3f' % (percentile, nmpe[percent_id])) | |
self.metrics_log['subset'][subset]['nme_p%i' % percentile] = nmpe[percent_id] | |
print('AUC_%i: %.3f' % (self.nme_thr, auc)) | |
self.metrics_log['subset'][subset]['auc'] = auc | |
print('FR_%i: %.3f' % (self.nme_thr, fr)) | |
self.metrics_log['subset'][subset]['fr'] = fr | |
# NME/NPE per landmark | |
self.metrics_log['nme_per_ldm'] = OrderedDict() | |
for percentile in self.percentile: | |
self.metrics_log['npe%i_per_ldm' % percentile] = OrderedDict() | |
for k, v in self.error['ne_per_ldm'].items(): | |
self.metrics_log['nme_per_ldm'][k] = np.mean(v) | |
for percentile in self.percentile: | |
self.metrics_log['npe%i_per_ldm' % percentile][k] = np.percentile(v, percentile) | |
return self.metrics_log | |
def get_pimg_err(self, data_dict=None, img_select=None): | |
data = self.error['nme_per_img'] | |
if img_select is not None: | |
data = [data[img_id] for img_id in img_select] | |
name_dict = self.name + '/nme' | |
if data_dict is not None: | |
data_dict[name_dict] = data | |
else: | |
data_dict = data | |
return data_dict | |
def _update_lnd_param(self): | |
db_info_file = db_anns_path.format(database=self.database, file_name='db_info') | |
if os.path.exists(db_info_file): | |
with open(db_info_file) as jsonfile: | |
self.db_info = json.load(jsonfile) | |
norm_dict = self.db_info['norm'] | |
nme_norm, nme_thr = next(iter(norm_dict.items())) | |
print('Default landmarks configuration: \n %s: %i' % (nme_norm, nme_thr)) | |
answer = input("Change default config? (N/Y) >>> ") | |
if answer.lower() in ['yes', 'y']: | |
answer = input("Normalization options: %s >>> " % str(list(norm_dict.keys()))) | |
if answer in norm_dict.keys(): | |
nme_norm = answer | |
nme_thr = norm_dict[nme_norm] | |
else: | |
print("Option %s not available keep in default one: %s" % (answer, nme_norm)) | |
answer = input("Change threshold ->%s:%i ? (N/Y) >>> " % (nme_norm, nme_thr)) | |
if answer.lower() in ['yes', 'y']: | |
answer = input('NME threshold: >>> ') | |
nme_thr = float(answer) | |
else: | |
print("Keeping default threshold: %i" % nme_thr) | |
self.nme_norm = nme_norm | |
self.nme_thr = nme_thr | |
else: | |
raise ValueError('Database %s specifics not defined. Missing db_info.json' % self.database) | |
def _dist_l2(self, pointA, pointB): | |
return float(((pointA - pointB) ** 2).sum() ** 0.5) | |
def _get_lnd_from_id(self, anns, ids): | |
idx = anns['ids'].index(ids) | |
ref = np.array(anns['landmarks'][idx]) | |
return ref | |
def _get_img_norm(self, anns): | |
if self.nme_norm == 'pupils': | |
print('WARNING: Pupils norm only implemented for 68 landmark configuration') | |
left_eye = [7, 138, 139, 8, 141, 142] | |
right_eye = [11, 144, 145, 12, 147, 148] | |
refA = np.zeros(2) | |
refB = np.zeros(2) | |
for i in range(len(left_eye)): | |
refA += self._get_lnd_from_id(anns, left_eye[i]) | |
refB += self._get_lnd_from_id(anns, right_eye[i]) | |
refA = refA/len(left_eye) # Left | |
refB = refB/len(right_eye) # Right | |
elif self.nme_norm == 'corners': | |
refA = self._get_lnd_from_id(anns, 12) # Left | |
refB = self._get_lnd_from_id(anns, 7) # Right | |
elif self.nme_norm == 'diagonal': | |
refA = anns['bbox'][0:2] | |
refB = refA + anns['bbox'][2:4] | |
elif self.nme_norm == 'height': | |
return anns['bbox'][3] | |
elif self.nme_norm == 'lnd_bbox': | |
lnd = np.array(anns['landmarks']) | |
lnd_max = np.max(lnd, axis=0) | |
lnd_min = np.min(lnd, axis=0) | |
lnd_wh = lnd_max - lnd_min | |
return (lnd_wh[0]*lnd_wh[1])**0.5 | |
elif self.nme_norm == 'bbox': | |
return (anns['bbox'][2] * anns['bbox'][3]) ** 0.5 | |
else: | |
raise ValueError('Normalization %s not implemented' % self.nme_norm) | |
return self._dist_l2(refA, refB) | |
def _cumulative_error(self, error, bins=10000): | |
num_imgs, base = np.histogram(error, bins=bins) | |
cumulative = [x / float(len(error)) for x in np.cumsum(num_imgs)] | |
base = base[:bins] | |
cumulative, base = self._filter_cumulative(cumulative, base) | |
return [cumulative, base] | |
def _filter_cumulative(self, cumulative, base): | |
base = [x for x in base if (x < self.nme_thr)] | |
cumulative = cumulative[:len(base)] | |
return cumulative, base | |
def _basic_metrics(self, img_select=None): | |
data = self.error['nme_per_img'] | |
if img_select is not None: | |
data = [data[img_id] for img_id in img_select] | |
[cumulative, base] = self._cumulative_error(data, bins=self.bins) | |
else: | |
[cumulative, base] = self.error['cumulative_nme'] | |
# Normalize Mean Error across img | |
nme = np.mean(data) | |
# Normalize Mean Percentile Error across img | |
nmpe = [] | |
for percentile in self.percentile: | |
nmpe.append(np.percentile(data, percentile)) | |
# Area Under Curve and Failure Rate | |
auc, fr = self._auc_fr_metrics(cumulative, base) | |
return nme, nmpe, auc, fr, cumulative, base | |
def _auc_fr_metrics(self, cumulative, base): | |
if not base: | |
auc = 0. | |
fr = 100. | |
else: | |
auc = (simps(cumulative, x=base) / self.nme_thr) * 100.0 | |
if base[-1] < self.nme_thr and cumulative[-1] == 1: | |
auc += ((self.nme_thr - base[-1]) / self.nme_thr) * 100 | |
fr = (1 - cumulative[-1]) * 100.0 | |
return auc, fr | |