import os
import json
from collections import OrderedDict

import numpy as np

# scipy.integrate.simps was renamed to simpson in SciPy 1.6 and removed in 1.14
try:
    from scipy.integrate import simpson as simps
except ImportError:  # SciPy < 1.6
    from scipy.integrate import simps

from spiga.data.loaders.dl_config import db_anns_path
from spiga.eval.benchmark.metrics.metrics import Metrics


class MetricsLandmarks(Metrics):

    def __init__(self, name='landmarks'):
        super().__init__(name)
        self.db_info = None
        self.nme_norm = "corners"
        self.nme_thr = 8
        self.percentile = [90, 95, 99]
        # Cumulative plot axis length
        self.bins = 10000

    def compute_error(self, data_anns, data_pred, database, select_ids=None):
        # Initialize global logs and variables of the compute-error stage
        self.init_ce(data_anns, data_pred, database)
        self._update_lnd_param()

        # Order data and compute NME
        self.error['nme_per_img'] = []
        self.error['ne_per_img'] = OrderedDict()
        self.error['ne_per_ldm'] = OrderedDict()

        for img_id, anns in enumerate(data_anns):
            # Init variables per image
            pred = data_pred[img_id]

            # Get selected landmark ids to compute
            if select_ids is None:
                selected_ldm = anns['ids']
            else:
                selected_ldm = list(set(select_ids) & set(anns['ids']))

            norm = self._get_img_norm(anns)
            for ldm_id in selected_ldm:
                # Compute normalized error (NE) as a percentage of the norm
                anns_ldm = self._get_lnd_from_id(anns, ldm_id)
                pred_ldm = self._get_lnd_from_id(pred, ldm_id)
                ne = self._dist_l2(anns_ldm, pred_ldm) / norm * 100
                self.error['ne_per_img'].setdefault(img_id, []).append(ne)
                self.error['ne_per_ldm'].setdefault(ldm_id, []).append(ne)

            # NME per image
            if self.database in ['merlrav']:
                # LUVLi on MERL-RAV divides by 68 regardless of how many
                # landmarks are actually annotated in the image.
                self.error['nme_per_img'].append(np.sum(self.error['ne_per_img'][img_id]) / 68)
            else:
                self.error['nme_per_img'].append(np.mean(self.error['ne_per_img'][img_id]))

        # Cumulative NME
        self.error['cumulative_nme'] = self._cumulative_error(self.error['nme_per_img'], bins=self.bins)
        return self.error
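
    # Worked example (hypothetical numbers): with the default 'corners' norm
    # and an inter-ocular distance of 60 px, a landmark predicted 3 px away
    # from its annotation yields ne = 3 / 60 * 100 = 5.0, i.e. 5% of the
    # inter-ocular distance; the image NME is the mean of these values.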

    def metrics(self):
        # Initialize global logs and variables of the metrics stage
        self.init_metrics()

        # Basic metrics (NME/NMPE/AUC/FR) for the full dataset
        nme, nmpe, auc, fr, _, _ = self._basic_metrics()
        print('NME: %.3f' % nme)
        self.metrics_log['nme'] = nme
        for percent_id, percentile in enumerate(self.percentile):
            print('NME_P%i: %.3f' % (percentile, nmpe[percent_id]))
            self.metrics_log['nme_p%i' % percentile] = nmpe[percent_id]
        self.metrics_log['nme_thr'] = self.nme_thr
        self.metrics_log['nme_norm'] = self.nme_norm
        print('AUC_%i: %.3f' % (self.nme_thr, auc))
        self.metrics_log['auc'] = auc
        print('FR_%i: %.3f' % (self.nme_thr, fr))
        self.metrics_log['fr'] = fr

        # Basic metrics per subset
        subsets = self.db_info['test_subsets']
        if self.data_type == 'test' and len(subsets) > 0:
            self.metrics_log['subset'] = OrderedDict()
            for subset, img_filter in subsets.items():
                self.metrics_log['subset'][subset] = OrderedDict()
                nme, nmpe, auc, fr, _, _ = self._basic_metrics(img_select=img_filter)
                print('> Landmarks subset: %s' % subset.upper())
                print('NME: %.3f' % nme)
                self.metrics_log['subset'][subset]['nme'] = nme
                for percent_id, percentile in enumerate(self.percentile):
                    print('NME_P%i: %.3f' % (percentile, nmpe[percent_id]))
                    self.metrics_log['subset'][subset]['nme_p%i' % percentile] = nmpe[percent_id]
                print('AUC_%i: %.3f' % (self.nme_thr, auc))
                self.metrics_log['subset'][subset]['auc'] = auc
                print('FR_%i: %.3f' % (self.nme_thr, fr))
                self.metrics_log['subset'][subset]['fr'] = fr

        # NME/NPE per landmark
        self.metrics_log['nme_per_ldm'] = OrderedDict()
        for percentile in self.percentile:
            self.metrics_log['npe%i_per_ldm' % percentile] = OrderedDict()
        for k, v in self.error['ne_per_ldm'].items():
            self.metrics_log['nme_per_ldm'][k] = np.mean(v)
            for percentile in self.percentile:
                self.metrics_log['npe%i_per_ldm' % percentile][k] = np.percentile(v, percentile)

        return self.metrics_log

    def get_pimg_err(self, data_dict=None, img_select=None):
        data = self.error['nme_per_img']
        if img_select is not None:
            data = [data[img_id] for img_id in img_select]

        name_dict = self.name + '/nme'
        if data_dict is not None:
            data_dict[name_dict] = data
        else:
            data_dict = data
        return data_dict

    def _update_lnd_param(self):
        db_info_file = db_anns_path.format(database=self.database, file_name='db_info')
        if os.path.exists(db_info_file):
            with open(db_info_file) as jsonfile:
                self.db_info = json.load(jsonfile)

            # First entry of the norm dict is the database default,
            # e.g. {'corners': 8, 'bbox': 10} (hypothetical example)
            norm_dict = self.db_info['norm']
            nme_norm, nme_thr = next(iter(norm_dict.items()))
            print('Default landmarks configuration: \n %s: %i' % (nme_norm, nme_thr))

            answer = input("Change default config? (N/Y) >>> ")
            if answer.lower() in ['yes', 'y']:
                answer = input("Normalization options: %s >>> " % str(list(norm_dict.keys())))
                if answer in norm_dict.keys():
                    nme_norm = answer
                    nme_thr = norm_dict[nme_norm]
                else:
                    print("Option %s not available, keeping the default one: %s" % (answer, nme_norm))
                answer = input("Change threshold ->%s:%i ? (N/Y) >>> " % (nme_norm, nme_thr))
                if answer.lower() in ['yes', 'y']:
                    answer = input('NME threshold: >>> ')
                    nme_thr = float(answer)
                else:
                    print("Keeping default threshold: %i" % nme_thr)

            self.nme_norm = nme_norm
            self.nme_thr = nme_thr
        else:
            raise ValueError('Database %s specifics not defined. Missing db_info.json' % self.database)

    def _dist_l2(self, pointA, pointB):
        # Euclidean (L2) distance between two 2D points
        return float(((pointA - pointB) ** 2).sum() ** 0.5)

    def _get_lnd_from_id(self, anns, ldm_id):
        # Landmark coordinates associated with a given landmark id
        idx = anns['ids'].index(ldm_id)
        ref = np.array(anns['landmarks'][idx])
        return ref
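
    # Expected annotation layout (inferred from usage; hypothetical values):
    # anns = {'ids': [7, 12, ...], 'landmarks': [[x0, y0], [x1, y1], ...],
    #         'bbox': [x, y, width, height]}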

    def _get_img_norm(self, anns):
        if self.nme_norm == 'pupils':
            print('WARNING: Pupils norm is only implemented for the 68-landmark configuration')
            left_eye = [7, 138, 139, 8, 141, 142]
            right_eye = [11, 144, 145, 12, 147, 148]
            refA = np.zeros(2)
            refB = np.zeros(2)
            for i in range(len(left_eye)):
                refA += self._get_lnd_from_id(anns, left_eye[i])
                refB += self._get_lnd_from_id(anns, right_eye[i])
            refA = refA / len(left_eye)   # Left
            refB = refB / len(right_eye)  # Right
        elif self.nme_norm == 'corners':
            refA = self._get_lnd_from_id(anns, 12)  # Left
            refB = self._get_lnd_from_id(anns, 7)   # Right
        elif self.nme_norm == 'diagonal':
            # Bbox diagonal: top-left corner to bottom-right corner
            refA = np.array(anns['bbox'][0:2])
            refB = refA + np.array(anns['bbox'][2:4])
        elif self.nme_norm == 'height':
            return anns['bbox'][3]
        elif self.nme_norm == 'lnd_bbox':
            # Geometric mean of the landmark bounding-box width and height
            lnd = np.array(anns['landmarks'])
            lnd_max = np.max(lnd, axis=0)
            lnd_min = np.min(lnd, axis=0)
            lnd_wh = lnd_max - lnd_min
            return (lnd_wh[0] * lnd_wh[1]) ** 0.5
        elif self.nme_norm == 'bbox':
            return (anns['bbox'][2] * anns['bbox'][3]) ** 0.5
        else:
            raise ValueError('Normalization %s not implemented' % self.nme_norm)
        return self._dist_l2(refA, refB)

    def _cumulative_error(self, error, bins=10000):
        # Cumulative error distribution (CED): fraction of images whose NME
        # falls below each bin edge
        num_imgs, base = np.histogram(error, bins=bins)
        cumulative = [x / float(len(error)) for x in np.cumsum(num_imgs)]
        base = base[:bins]  # Keep the left edge of each bin
        cumulative, base = self._filter_cumulative(cumulative, base)
        return [cumulative, base]

    def _filter_cumulative(self, cumulative, base):
        # Truncate the CED curve at the failure threshold (nme_thr)
        base = [x for x in base if (x < self.nme_thr)]
        cumulative = cumulative[:len(base)]
        return cumulative, base
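
    # Worked example (hypothetical numbers): for per-image NMEs [2, 4, 9] and
    # nme_thr = 8, the truncated CED reaches 1/3 at NME = 2 and 2/3 at NME = 4,
    # and the point at 9 is cut off; hence FR_8 = (1 - 2/3) * 100 ~= 33.3,
    # while AUC_8 integrates the truncated curve, normalized by the threshold.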

    def _basic_metrics(self, img_select=None):
        data = self.error['nme_per_img']
        if img_select is not None:
            data = [data[img_id] for img_id in img_select]
            [cumulative, base] = self._cumulative_error(data, bins=self.bins)
        else:
            [cumulative, base] = self.error['cumulative_nme']

        # Normalized Mean Error across images
        nme = np.mean(data)
        # Normalized Mean Percentile Error across images
        nmpe = []
        for percentile in self.percentile:
            nmpe.append(np.percentile(data, percentile))
        # Area Under the Curve and Failure Rate
        auc, fr = self._auc_fr_metrics(cumulative, base)
        return nme, nmpe, auc, fr, cumulative, base

    def _auc_fr_metrics(self, cumulative, base):
        if not base:
            # Every image failed before the first bin edge
            auc = 0.
            fr = 100.
        else:
            # AUC of the CED curve up to nme_thr (Simpson integration),
            # expressed as a percentage of the maximum attainable area
            auc = (simps(cumulative, x=base) / self.nme_thr) * 100.0
            if base[-1] < self.nme_thr and cumulative[-1] == 1:
                # The curve saturated at 1 before the threshold; add the
                # remaining rectangular area up to nme_thr
                auc += ((self.nme_thr - base[-1]) / self.nme_thr) * 100
            fr = (1 - cumulative[-1]) * 100.0
        return auc, fr
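

if __name__ == '__main__':
    # Minimal usage sketch, not part of the original benchmark. It assumes a
    # database name whose db_info.json ships with SPIGA (here 'wflw' as a
    # hypothetical choice) and that the interactive prompt in
    # _update_lnd_param is answered 'N' to keep the default normalization.
    # The two faces below are fabricated; ids 7 and 12 are included so the
    # 'corners' norm can be resolved.
    anns = [{'ids': [7, 12],
             'landmarks': [[30.0, 40.0], [90.0, 40.0]],
             'bbox': [10.0, 10.0, 100.0, 100.0]},
            {'ids': [7, 12],
             'landmarks': [[35.0, 45.0], [95.0, 45.0]],
             'bbox': [15.0, 15.0, 100.0, 100.0]}]
    pred = [{'ids': [7, 12],
             'landmarks': [[32.0, 41.0], [89.0, 42.0]]},
            {'ids': [7, 12],
             'landmarks': [[35.0, 45.0], [95.0, 45.0]]}]

    meter = MetricsLandmarks()
    meter.compute_error(anns, pred, database='wflw')  # hypothetical database choice
    meter.metrics()                                   # prints NME/AUC/FR summaries
    print(meter.get_pimg_err())                       # per-image NME list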