svjack's picture
Upload folder using huggingface_hub
d015578 verified
import os
import json
import pkg_resources
from collections import OrderedDict
# Default data paths
db_img_path = pkg_resources.resource_filename('spiga', 'data/databases')
db_anns_path = pkg_resources.resource_filename('spiga', 'data/annotations') + "/{database}/{file_name}.json"
class AlignConfig:
def __init__(self, database_name, mode='train'):
# Dataset
self.database_name = database_name
self.working_mode = mode
self.database = None # Set at self._update_database()
self.anns_file = None # Set at self._update_database()
self.image_dir = None # Set at self._update_database()
self._update_database()
self.image_size = (256, 256)
self.ftmap_size = (256, 256)
# Dataloaders
self.ids = None # List of a subset if need it
self.shuffle = True # Shuffle samples
self.num_workers = 4 # Threads
# Posit
self.generate_pose = True # Generate pose parameters from landmarks
self.focal_ratio = 1.5 # Camera matrix focal length ratio
self.posit_max_iter = 100 # Refinement iterations
# Subset of robust ids in the 3D model to use in posit.
# 'None' to use all the available model landmarks.
self.posit_ids = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,
14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24]
# Data augmentation
# Control augmentations with the following list, crop to self.img_size is mandatory, check target_dist param.
if mode == 'train':
self.aug_names = ['flip', 'rotate_scale', 'occlusion', 'lighting', 'blur']
else:
self.aug_names = []
self.shuffle = False
# Flip
self.hflip_prob = 0.5
# Rotation
self.angle_range = 45.
# Scale
self.scale_max = 0.15
self.scale_min = -0.15
# Translation
self.trl_ratio = 0.05 # Translation augmentation
# Crop target rescale
self.target_dist = 1.6 # Target distance zoom in/out around face. Default: 1.
# Occlusion
self.occluded_max_len = 0.4
self.occluded_min_len = 0.1
self.occluded_covar_ratio = 2.25**0.5
# Lighting
self.hsv_range_min = [-0.5, -0.5, -0.5]
self.hsv_range_max = [0.5, 0.5, 0.5]
# Blur
self.blur_prob = 0.5
self.blur_kernel_range = [0, 2]
# Heatmaps 2D
self.sigma2D = 1.5
self.heatmap2D_norm = False
# Boundaries
self.sigmaBD = 1
def update(self, params_dict):
state_dict = self.state_dict()
for k, v in params_dict.items():
if k in state_dict or hasattr(self, k):
setattr(self, k, v)
else:
Warning('Unknown option: {}: {}'.format(k, v))
self._update_database()
def state_dict(self, tojson=False):
state_dict = OrderedDict()
for k in self.__dict__.keys():
if not k.startswith('_'):
if tojson and k in ['database']:
continue
state_dict[k] = getattr(self, k)
return state_dict
def _update_database(self):
self.database = DatabaseStruct(self.database_name)
self.anns_file = db_anns_path.format(database=self.database_name, file_name=self.working_mode)
self.image_dir = self._get_imgdb_path()
def _get_imgdb_path(self):
img_dir = None
if self.database_name in ['300wpublic', '300wprivate']:
img_dir = db_img_path + '/300w/'
elif self.database_name in ['aflw19', 'merlrav']:
img_dir = db_img_path + '/aflw/data/'
elif self.database_name in ['cofw', 'cofw68']:
img_dir = db_img_path + '/cofw/'
elif self.database_name in ['wflw']:
img_dir = db_img_path + '/wflw/'
return img_dir
def __str__(self):
state_dict = self.state_dict()
text = 'Dataloader {\n'
for k, v in state_dict.items():
if isinstance(v, DatabaseStruct):
text += '\t{}: {}'.format(k, str(v).expandtabs(12))
else:
text += '\t{}: {}\n'.format(k, v)
text += '\t}\n'
return text
class DatabaseStruct:
def __init__(self, database_name):
self.name = database_name
self.ldm_ids, self.ldm_flip_order, self.ldm_edges_matrix = self._get_database_specifics()
self.num_landmarks = len(self.ldm_ids)
self.num_edges = len(self.ldm_edges_matrix[0])-1
self.fields = ['imgpath', 'bbox', 'headpose', 'ids', 'landmarks', 'visible']
def _get_database_specifics(self):
'''Returns specifics ids and horizontal flip reorder'''
database_name = self.name
db_info_file = db_anns_path.format(database=database_name, file_name='db_info')
ldm_edges_matrix = None
if os.path.exists(db_info_file):
with open(db_info_file) as jsonfile:
db_info = json.load(jsonfile)
ldm_ids = db_info['ldm_ids']
ldm_flip_order = db_info['ldm_flip_order']
if 'ldm_edges_matrix' in db_info.keys():
ldm_edges_matrix = db_info['ldm_edges_matrix']
else:
raise ValueError('Database ' + database_name + 'specifics not defined. Missing db_info.json')
return ldm_ids, ldm_flip_order, ldm_edges_matrix
def state_dict(self):
state_dict = OrderedDict()
for k in self.__dict__.keys():
if not k.startswith('_'):
state_dict[k] = getattr(self, k)
return state_dict
def __str__(self):
state_dict = self.state_dict()
text = 'Database {\n'
for k, v in state_dict.items():
text += '\t{}: {}\n'.format(k, v)
text += '\t}\n'
return text