Spaces:
Running
Running
import os | |
import json | |
import pkg_resources | |
from collections import OrderedDict | |
# Default data paths | |
db_img_path = pkg_resources.resource_filename('spiga', 'data/databases') | |
db_anns_path = pkg_resources.resource_filename('spiga', 'data/annotations') + "/{database}/{file_name}.json" | |
class AlignConfig: | |
def __init__(self, database_name, mode='train'): | |
# Dataset | |
self.database_name = database_name | |
self.working_mode = mode | |
self.database = None # Set at self._update_database() | |
self.anns_file = None # Set at self._update_database() | |
self.image_dir = None # Set at self._update_database() | |
self._update_database() | |
self.image_size = (256, 256) | |
self.ftmap_size = (256, 256) | |
# Dataloaders | |
self.ids = None # List of a subset if need it | |
self.shuffle = True # Shuffle samples | |
self.num_workers = 4 # Threads | |
# Posit | |
self.generate_pose = True # Generate pose parameters from landmarks | |
self.focal_ratio = 1.5 # Camera matrix focal length ratio | |
self.posit_max_iter = 100 # Refinement iterations | |
# Subset of robust ids in the 3D model to use in posit. | |
# 'None' to use all the available model landmarks. | |
self.posit_ids = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, | |
14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24] | |
# Data augmentation | |
# Control augmentations with the following list, crop to self.img_size is mandatory, check target_dist param. | |
if mode == 'train': | |
self.aug_names = ['flip', 'rotate_scale', 'occlusion', 'lighting', 'blur'] | |
else: | |
self.aug_names = [] | |
self.shuffle = False | |
# Flip | |
self.hflip_prob = 0.5 | |
# Rotation | |
self.angle_range = 45. | |
# Scale | |
self.scale_max = 0.15 | |
self.scale_min = -0.15 | |
# Translation | |
self.trl_ratio = 0.05 # Translation augmentation | |
# Crop target rescale | |
self.target_dist = 1.6 # Target distance zoom in/out around face. Default: 1. | |
# Occlusion | |
self.occluded_max_len = 0.4 | |
self.occluded_min_len = 0.1 | |
self.occluded_covar_ratio = 2.25**0.5 | |
# Lighting | |
self.hsv_range_min = [-0.5, -0.5, -0.5] | |
self.hsv_range_max = [0.5, 0.5, 0.5] | |
# Blur | |
self.blur_prob = 0.5 | |
self.blur_kernel_range = [0, 2] | |
# Heatmaps 2D | |
self.sigma2D = 1.5 | |
self.heatmap2D_norm = False | |
# Boundaries | |
self.sigmaBD = 1 | |
def update(self, params_dict): | |
state_dict = self.state_dict() | |
for k, v in params_dict.items(): | |
if k in state_dict or hasattr(self, k): | |
setattr(self, k, v) | |
else: | |
Warning('Unknown option: {}: {}'.format(k, v)) | |
self._update_database() | |
def state_dict(self, tojson=False): | |
state_dict = OrderedDict() | |
for k in self.__dict__.keys(): | |
if not k.startswith('_'): | |
if tojson and k in ['database']: | |
continue | |
state_dict[k] = getattr(self, k) | |
return state_dict | |
def _update_database(self): | |
self.database = DatabaseStruct(self.database_name) | |
self.anns_file = db_anns_path.format(database=self.database_name, file_name=self.working_mode) | |
self.image_dir = self._get_imgdb_path() | |
def _get_imgdb_path(self): | |
img_dir = None | |
if self.database_name in ['300wpublic', '300wprivate']: | |
img_dir = db_img_path + '/300w/' | |
elif self.database_name in ['aflw19', 'merlrav']: | |
img_dir = db_img_path + '/aflw/data/' | |
elif self.database_name in ['cofw', 'cofw68']: | |
img_dir = db_img_path + '/cofw/' | |
elif self.database_name in ['wflw']: | |
img_dir = db_img_path + '/wflw/' | |
return img_dir | |
def __str__(self): | |
state_dict = self.state_dict() | |
text = 'Dataloader {\n' | |
for k, v in state_dict.items(): | |
if isinstance(v, DatabaseStruct): | |
text += '\t{}: {}'.format(k, str(v).expandtabs(12)) | |
else: | |
text += '\t{}: {}\n'.format(k, v) | |
text += '\t}\n' | |
return text | |
class DatabaseStruct: | |
def __init__(self, database_name): | |
self.name = database_name | |
self.ldm_ids, self.ldm_flip_order, self.ldm_edges_matrix = self._get_database_specifics() | |
self.num_landmarks = len(self.ldm_ids) | |
self.num_edges = len(self.ldm_edges_matrix[0])-1 | |
self.fields = ['imgpath', 'bbox', 'headpose', 'ids', 'landmarks', 'visible'] | |
def _get_database_specifics(self): | |
'''Returns specifics ids and horizontal flip reorder''' | |
database_name = self.name | |
db_info_file = db_anns_path.format(database=database_name, file_name='db_info') | |
ldm_edges_matrix = None | |
if os.path.exists(db_info_file): | |
with open(db_info_file) as jsonfile: | |
db_info = json.load(jsonfile) | |
ldm_ids = db_info['ldm_ids'] | |
ldm_flip_order = db_info['ldm_flip_order'] | |
if 'ldm_edges_matrix' in db_info.keys(): | |
ldm_edges_matrix = db_info['ldm_edges_matrix'] | |
else: | |
raise ValueError('Database ' + database_name + 'specifics not defined. Missing db_info.json') | |
return ldm_ids, ldm_flip_order, ldm_edges_matrix | |
def state_dict(self): | |
state_dict = OrderedDict() | |
for k in self.__dict__.keys(): | |
if not k.startswith('_'): | |
state_dict[k] = getattr(self, k) | |
return state_dict | |
def __str__(self): | |
state_dict = self.state_dict() | |
text = 'Database {\n' | |
for k, v in state_dict.items(): | |
text += '\t{}: {}\n'.format(k, v) | |
text += '\t}\n' | |
return text | |