|
import os.path as osp
|
|
from .base import Base
|
|
|
|
|
|
class Alignment(Base):
    """
    Configuration object for the alignment task.

    ``__init__`` assigns default hyper-parameters, hands ``args`` to
    ``init_from_args`` (not defined in this class, so it comes from ``Base``),
    and afterwards derives dataset-specific landmark settings, the per-stack
    loss / metric lists and the working-directory and annotation paths.

    NOTE(review): ``self.image_dir``, ``self.annot_dir`` and ``self.id`` are
    read here but never assigned in this class -- presumably provided by
    ``Base.__init__`` or ``init_from_args``; confirm.
    """

    def __init__(self, args):
        """
        Build the configuration.

        Args:
            args: forwarded unchanged to ``self.init_from_args``. Its type is
                not visible here (presumably parsed command-line options that
                override the defaults below -- confirm against ``Base``).

        Raises:
            NotImplementedError: if ``data_definition`` is one of the known
                datasets but ``norm_type`` is not 'ocular', 'pupil' or
                'default'.
        """
        super(Alignment, self).__init__('alignment')

        # NOTE(review): machine-specific absolute path used as the default
        # checkpoint root.
        self.ckpt_dir = '/mnt/workspace/humanAIGC/project/STAR/weights'
        self.net = "stackedHGnet_v1"
        self.nstack = 4
        self.loader_type = "alignment"
        self.data_definition = "300W"
        self.test_file = "test.tsv"

        # image
        self.channels = 3
        self.width = 256
        self.height = 256
        self.means = (127.5, 127.5, 127.5)
        self.scale = 1 / 127.5
        self.aug_prob = 1.0

        # logging / validation
        self.display_iteration = 10
        self.val_epoch = 1
        self.valset = "test.tsv"
        self.norm_type = 'default'
        self.encoder_type = 'default'
        self.decoder_type = 'default'

        # optimisation
        # NOTE(review): milestones 350 and 450 lie beyond max_epoch = 260;
        # presumably LR-scheduler milestones -- confirm this is intended.
        self.milestones = [200, 350, 450]
        self.max_epoch = 260
        self.optimizer = "adam"
        self.learn_rate = 0.001
        self.weight_decay = 0.00001
        self.betas = [0.9, 0.999]
        self.gamma = 0.1

        # data loading
        self.batch_size = 32
        self.train_num_workers = 16
        self.val_batch_size = 32
        self.val_num_workers = 16
        self.test_batch_size = 16
        self.test_num_workers = 0

        # model switches
        self.ema = True
        self.add_coord = True
        self.use_AAM = True

        # loss
        self.loss_func = "STARLoss_v2"
        self.star_w = 1
        self.star_dist = 'smoothl1'

        # Everything above is a default; init_from_args runs before any
        # derived value is computed so the derivations below see its effect.
        self.init_from_args(args)

        # Dataset-specific settings. Each edge_info entry is a
        # (bool, landmark-index tuple) pair -- presumably (is_closed, indices);
        # the middle value of classes_num equals len(edge_info).
        # NOTE(review): an unrecognised data_definition falls through all
        # branches silently, leaving edge_info / classes_num / flip_mapping
        # and the nme_* indices unset.
        if self.data_definition == "COFW":
            self.edge_info = (
                (True, (0, 4, 2, 5)),
                (True, (1, 6, 3, 7)),
                (True, (8, 12, 10, 13)),
                (False, (9, 14, 11, 15)),
                (True, (18, 20, 19, 21)),
                (True, (22, 26, 23, 27)),
                (True, (22, 24, 23, 25)),
            )
            if self.norm_type == 'ocular':
                self.nme_left_index = 8
                self.nme_right_index = 9
            elif self.norm_type in ['pupil', 'default']:
                self.nme_left_index = 16
                self.nme_right_index = 17
            else:
                raise NotImplementedError(
                    "unsupported norm_type {!r} for data_definition {!r}".format(
                        self.norm_type, self.data_definition))
            self.classes_num = [29, 7, 29]
            self.crop_op = True
            self.flip_mapping = (
                [0, 1], [4, 6], [2, 3], [5, 7], [8, 9], [10, 11], [12, 14], [16, 17], [13, 15], [18, 19], [22, 23],
            )
            self.image_dir = osp.join(self.image_dir, 'COFW')

        elif self.data_definition in ("300W", "300VW"):
            # Both datasets share the same 68-point settings; only the image
            # sub-directory differs (selected at the end of this branch).
            self.edge_info = (
                (False, (0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16)),
                (False, (17, 18, 19, 20, 21)),
                (False, (22, 23, 24, 25, 26)),
                (False, (27, 28, 29, 30)),
                (False, (31, 32, 33, 34, 35)),
                (True, (36, 37, 38, 39, 40, 41)),
                (True, (42, 43, 44, 45, 46, 47)),
                (True, (48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59)),
                (True, (60, 61, 62, 63, 64, 65, 66, 67)),
            )
            if self.norm_type in ['ocular', 'default']:
                self.nme_left_index = 36
                self.nme_right_index = 45
            elif self.norm_type == 'pupil':
                # Lists of indices here, single ints in the branch above.
                self.nme_left_index = [36, 37, 38, 39, 40, 41]
                self.nme_right_index = [42, 43, 44, 45, 46, 47]
            else:
                raise NotImplementedError(
                    "unsupported norm_type {!r} for data_definition {!r}".format(
                        self.norm_type, self.data_definition))
            self.classes_num = [68, 9, 68]
            self.crop_op = True
            self.flip_mapping = (
                [0, 16], [1, 15], [2, 14], [3, 13], [4, 12], [5, 11], [6, 10], [7, 9],
                [17, 26], [18, 25], [19, 24], [20, 23], [21, 22],
                [31, 35], [32, 34],
                [36, 45], [37, 44], [38, 43], [39, 42], [40, 47], [41, 46],
                [48, 54], [49, 53], [50, 52], [61, 63], [60, 64], [67, 65], [58, 56], [59, 55],
            )
            if self.data_definition == "300W":
                image_subdir = '300W'
            else:
                image_subdir = '300VW_Dataset_2015_12_14'
            self.image_dir = osp.join(self.image_dir, image_subdir)

        elif self.data_definition == "WFLW":
            self.edge_info = (
                (False, (
                    0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26,
                    27,
                    28, 29, 30, 31, 32)),
                (True, (33, 34, 35, 36, 37, 38, 39, 40, 41)),
                (True, (42, 43, 44, 45, 46, 47, 48, 49, 50)),
                (False, (51, 52, 53, 54)),
                (False, (55, 56, 57, 58, 59)),
                (True, (60, 61, 62, 63, 64, 65, 66, 67)),
                (True, (68, 69, 70, 71, 72, 73, 74, 75)),
                (True, (76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87)),
                (True, (88, 89, 90, 91, 92, 93, 94, 95)),
            )
            if self.norm_type in ['ocular', 'default']:
                self.nme_left_index = 60
                self.nme_right_index = 72
            elif self.norm_type == 'pupil':
                self.nme_left_index = 96
                self.nme_right_index = 97
            else:
                raise NotImplementedError(
                    "unsupported norm_type {!r} for data_definition {!r}".format(
                        self.norm_type, self.data_definition))
            self.classes_num = [98, 9, 98]
            self.crop_op = True
            self.flip_mapping = (
                [0, 32], [1, 31], [2, 30], [3, 29], [4, 28], [5, 27], [6, 26], [7, 25], [8, 24], [9, 23], [10, 22],
                [11, 21], [12, 20], [13, 19], [14, 18], [15, 17],
                [33, 46], [34, 45], [35, 44], [36, 43], [37, 42], [38, 50], [39, 49], [40, 48], [41, 47],
                [60, 72], [61, 71], [62, 70], [63, 69], [64, 68], [65, 75], [66, 74], [67, 73],
                [55, 59], [56, 58],
                [76, 82], [77, 81], [78, 80], [87, 83], [86, 84],
                [88, 92], [89, 91], [95, 93], [96, 97]
            )
            self.image_dir = osp.join(self.image_dir, 'WFLW', 'WFLW_images')

        # With use_AAM every stack contributes three entries (loss_func plus
        # two "AWingLoss"), otherwise one; label_num equals the final length
        # of the three lists built below.
        self.label_num = self.nstack * 3 if self.use_AAM else self.nstack
        self.loss_weights, self.criterions, self.metrics = [], [], []
        for i in range(self.nstack):
            # Weight doubles per stack and reaches 1.0 on the last stack
            # (1/8, 1/4, 1/2, 1 for nstack = 4).
            factor = (2 ** i) / (2 ** (self.nstack - 1))
            if self.use_AAM:
                self.loss_weights += [factor * weight for weight in [1.0, 10.0, 10.0]]
                self.criterions += [self.loss_func, "AWingLoss", "AWingLoss"]
                self.metrics += ["NME", None, None]
            else:
                self.loss_weights += [factor * weight for weight in [1.0]]
                self.criterions += [self.loss_func, ]
                self.metrics += ["NME", ]

        # Position of the last stack's "NME" entry inside self.metrics.
        self.key_metric_index = (self.nstack - 1) * 3 if self.use_AAM else (self.nstack - 1)

        # Output locations: <ckpt_dir>/<data_definition>/<folder>/{model,log}
        self.folder = self.get_foldername()
        self.work_dir = osp.join(self.ckpt_dir, self.data_definition, self.folder)
        self.model_dir = osp.join(self.work_dir, 'model')
        self.log_dir = osp.join(self.work_dir, 'log')

        # Annotation files live under <annot_dir>/<data_definition>/; all three
        # splits read pictures from the same image_dir.
        self.train_tsv_file = osp.join(self.annot_dir, self.data_definition, "train.tsv")
        self.train_pic_dir = self.image_dir

        self.val_tsv_file = osp.join(self.annot_dir, self.data_definition, self.valset)
        self.val_pic_dir = self.image_dir

        self.test_tsv_file = osp.join(self.annot_dir, self.data_definition, self.test_file)
        self.test_pic_dir = self.image_dir

    def get_foldername(self):
        """
        Return the run-folder name assembled from the current settings.

        Layout:
        ``<data>_<height>x<width>_<optimizer>_ep<max_epoch>_lr<learn_rate>_bs<batch_size>``
        followed by ``_<loss_func>``, optional ``_<star_dist>_<star_w>``,
        optional ``_AAM``, optional ``_<valset minus its last 4 characters>``
        and finally ``_<id>``.
        """
        pieces = [
            '{}_{}x{}_{}_ep{}_lr{}_bs{}'.format(self.data_definition, self.height, self.width,
                                                self.optimizer, self.max_epoch, self.learn_rate, self.batch_size),
            '_{}'.format(self.loss_func),
        ]
        # NOTE(review): only the exact name 'STARLoss' adds the star_dist /
        # star_w suffix; the default 'STARLoss_v2' does not. Left unchanged so
        # existing folder names stay stable.
        if self.loss_func == 'STARLoss':
            pieces.append('_{}_{}'.format(self.star_dist, self.star_w))
        if self.use_AAM:
            pieces.append('_AAM')
        # The default validation set is omitted from the name; any other
        # value is appended with its last four characters removed.
        if self.valset != 'test.tsv':
            pieces.append('_{}'.format(self.valset[:-4]))
        pieces.append('_{}'.format(self.id))
        return ''.join(pieces)
|
|
|