Spaces:
Running
Running
File size: 5,944 Bytes
d015578 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 |
import json
import os
import warnings
from collections import OrderedDict

import pkg_resources
# Default data paths, resolved inside the installed 'spiga' package.
# db_anns_path is a str.format template: fill {database} and {file_name}.
db_img_path = pkg_resources.resource_filename('spiga', 'data/databases')
db_anns_path = '{}/{{database}}/{{file_name}}.json'.format(
    pkg_resources.resource_filename('spiga', 'data/annotations'))
class AlignConfig:
    """Configuration container for face-alignment dataloaders.

    Groups dataset paths, dataloader options, POSIT pose-generation settings,
    data-augmentation parameters and target (heatmap / boundary) settings.
    Every public attribute can be overridden with :meth:`update` and is
    exported, in definition order, by :meth:`state_dict`.

    Args:
        database_name: Database identifier (e.g. '300wpublic', 'wflw'); used
            to locate the annotation files and the image directory.
        mode: Working mode. 'train' enables augmentations and shuffling; any
            other value disables both. It is also the annotation file name
            (``<mode>.json``).
    """

    def __init__(self, database_name, mode='train'):
        # Dataset
        self.database_name = database_name
        self.working_mode = mode
        self.database = None   # Set at self._update_database()
        self.anns_file = None  # Set at self._update_database()
        self.image_dir = None  # Set at self._update_database()
        self._update_database()
        self.image_size = (256, 256)
        self.ftmap_size = (256, 256)

        # Dataloaders
        self.ids = None         # List of a subset if need it
        self.shuffle = True     # Shuffle samples
        self.num_workers = 4    # Threads

        # Posit
        self.generate_pose = True   # Generate pose parameters from landmarks
        self.focal_ratio = 1.5      # Camera matrix focal length ratio
        self.posit_max_iter = 100   # Refinement iterations
        # Subset of robust ids in the 3D model to use in posit.
        # 'None' to use all the available model landmarks.
        self.posit_ids = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,
                          14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24]

        # Data augmentation
        # Control augmentations with the following list. Cropping to
        # self.image_size is mandatory (see the target_dist param).
        if mode == 'train':
            self.aug_names = ['flip', 'rotate_scale', 'occlusion', 'lighting', 'blur']
        else:
            self.aug_names = []
            self.shuffle = False

        # Flip
        self.hflip_prob = 0.5
        # Rotation
        self.angle_range = 45.
        # Scale
        self.scale_max = 0.15
        self.scale_min = -0.15
        # Translation
        self.trl_ratio = 0.05   # Translation augmentation
        # Crop target rescale
        self.target_dist = 1.6  # Target distance zoom in/out around face. Default: 1.
        # Occlusion
        self.occluded_max_len = 0.4
        self.occluded_min_len = 0.1
        self.occluded_covar_ratio = 2.25**0.5
        # Lighting
        self.hsv_range_min = [-0.5, -0.5, -0.5]
        self.hsv_range_max = [0.5, 0.5, 0.5]
        # Blur
        self.blur_prob = 0.5
        self.blur_kernel_range = [0, 2]

        # Heatmaps 2D
        self.sigma2D = 1.5
        self.heatmap2D_norm = False

        # Boundaries
        self.sigmaBD = 1

    def update(self, params_dict):
        """Override configuration options from a mapping.

        Keys naming an existing attribute are assigned. Any other key emits a
        ``UserWarning`` and is ignored. Afterwards the database-derived
        attributes (``database``, ``anns_file``, ``image_dir``) are rebuilt
        from ``database_name`` / ``working_mode``. Values passed for those
        three keys are therefore overwritten.

        Args:
            params_dict: Mapping of option name -> new value.
        """
        for k, v in params_dict.items():
            if hasattr(self, k):
                setattr(self, k, v)
            else:
                # Previously `Warning(...)` only built an exception object and
                # dropped it, so typos in option names went unnoticed.
                warnings.warn('Unknown option: {}: {}'.format(k, v), stacklevel=2)
        self._update_database()

    def state_dict(self, tojson=False):
        """Return the public (non-underscore) attributes in definition order.

        Args:
            tojson: If True, omit the ``database`` entry. It holds a
                DatabaseStruct object, which ``json`` cannot serialize.

        Returns:
            OrderedDict mapping attribute name -> value.
        """
        state_dict = OrderedDict()
        for k in self.__dict__:
            if k.startswith('_'):
                continue
            if tojson and k == 'database':
                continue
            state_dict[k] = getattr(self, k)
        return state_dict

    def _update_database(self):
        """(Re)build database, annotation-file path and image dir from the current name/mode."""
        self.database = DatabaseStruct(self.database_name)
        self.anns_file = db_anns_path.format(database=self.database_name, file_name=self.working_mode)
        self.image_dir = self._get_imgdb_path()

    def _get_imgdb_path(self):
        """Return the image directory for ``self.database_name``, or None if unknown."""
        name = self.database_name
        if name in ('300wpublic', '300wprivate'):
            return db_img_path + '/300w/'
        if name in ('aflw19', 'merlrav'):
            return db_img_path + '/aflw/data/'
        if name in ('cofw', 'cofw68'):
            return db_img_path + '/cofw/'
        if name == 'wflw':
            return db_img_path + '/wflw/'
        return None

    def __str__(self):
        """Render the configuration as an indented ``Dataloader { ... }`` block."""
        lines = ['Dataloader {\n']
        for k, v in self.state_dict().items():
            if isinstance(v, DatabaseStruct):
                # DatabaseStruct's text already ends with a newline.
                lines.append('\t{}: {}'.format(k, str(v).expandtabs(12)))
            else:
                lines.append('\t{}: {}\n'.format(k, v))
        lines.append('\t}\n')
        return ''.join(lines)
class DatabaseStruct:
    """Landmark layout of a face-alignment database.

    Reads ``db_info.json`` for ``database_name`` from the annotations
    directory. Exposes the landmark ids, the horizontal-flip reordering and,
    when the file provides it, the landmark edges matrix.

    Args:
        database_name: Database folder name under the annotations directory.

    Raises:
        ValueError: If the database has no ``db_info.json`` file.
    """

    def __init__(self, database_name):
        self.name = database_name
        self.ldm_ids, self.ldm_flip_order, self.ldm_edges_matrix = self._get_database_specifics()
        self.num_landmarks = len(self.ldm_ids)
        # 'ldm_edges_matrix' is optional in db_info.json. Report zero edges
        # when it is absent/empty instead of failing on None[0].
        if self.ldm_edges_matrix:
            # NOTE(review): one column of each row is not counted as an edge;
            # confirm its meaning against the db_info.json producer.
            self.num_edges = len(self.ldm_edges_matrix[0]) - 1
        else:
            self.num_edges = 0
        self.fields = ['imgpath', 'bbox', 'headpose', 'ids', 'landmarks', 'visible']

    def _get_database_specifics(self):
        """Return ``(ldm_ids, ldm_flip_order, ldm_edges_matrix)`` from db_info.json.

        ``ldm_edges_matrix`` is None when the file does not define it.

        Raises:
            ValueError: If ``db_info.json`` is missing for this database.
        """
        database_name = self.name
        db_info_file = db_anns_path.format(database=database_name, file_name='db_info')
        if not os.path.exists(db_info_file):
            raise ValueError('Database {} specifics not defined. Missing db_info.json'.format(database_name))
        # JSON is UTF-8 by spec; don't depend on the platform locale.
        with open(db_info_file, encoding='utf-8') as jsonfile:
            db_info = json.load(jsonfile)
        return db_info['ldm_ids'], db_info['ldm_flip_order'], db_info.get('ldm_edges_matrix')

    def state_dict(self):
        """Return the public (non-underscore) attributes in definition order."""
        state_dict = OrderedDict()
        for k in self.__dict__:
            if not k.startswith('_'):
                state_dict[k] = getattr(self, k)
        return state_dict

    def __str__(self):
        """Render the database layout as an indented ``Database { ... }`` block."""
        lines = ['Database {\n']
        lines.extend('\t{}: {}\n'.format(k, v) for k, v in self.state_dict().items())
        lines.append('\t}\n')
        return ''.join(lines)
|