Spaces:
Sleeping
Sleeping
import tensorflow as tf | |
from deep_heatmaps_model_fusion_net import DeepHeatmapsModel | |
import os | |
# Command-line configuration for training. Every flag defined here is read in
# main() and passed to the DeepHeatmapsModel constructor.
flags = tf.app.flags

# define paths
# save_model_path / save_sample_path / save_log_path are joined onto output_dir in main().
flags.DEFINE_string('output_dir', 'output', "directory for saving models, logs and samples")
flags.DEFINE_string('save_model_path', 'model', "directory for saving the model")
flags.DEFINE_string('save_sample_path', 'sample',
                    "directory for saving the sampled images, relevant if sample_to_log is False")
flags.DEFINE_string('save_log_path', 'logs', "directory for saving the log file")
# NOTE(review): this default contains a literal '~'. Because it is a default value,
# the shell never expands it, and os.path/open do not expand '~' themselves.
# Confirm that the consumer calls os.path.expanduser.
flags.DEFINE_string('img_path', '~/landmark_detection_datasets', "data directory")
flags.DEFINE_string('valid_data', 'full', 'validation set to use: full/common/challenging/test')
flags.DEFINE_string('train_crop_dir', 'crop_gt_margin_0.25', "directory of train images cropped to bb (+margin)")
flags.DEFINE_string('img_dir_ns', 'crop_gt_margin_0.25_ns', "directory of train imgs cropped to bb + style transfer")
flags.DEFINE_string('epoch_data_dir', 'epoch_data', "directory containing pre-augmented data for each epoch")
flags.DEFINE_bool('use_epoch_data', False, "use pre-augmented data")

# logging parameters
flags.DEFINE_integer('print_every', 100, "print losses to screen + log every X steps")
flags.DEFINE_integer('save_every', 20000, "save model every X steps")
flags.DEFINE_integer('sample_every', 5000, "sample heatmaps + landmark predictions every X steps")
flags.DEFINE_integer('sample_grid', 4, 'number of training images in sample')
flags.DEFINE_bool('sample_to_log', True, 'samples will be saved to tensorboard log')
flags.DEFINE_integer('valid_size', 20, 'number of validation images to run')
flags.DEFINE_integer('log_valid_every', 10, 'evaluate on valid set every X epochs')
flags.DEFINE_integer('debug_data_size', 20, 'subset data size to test in debug mode')
flags.DEFINE_bool('debug', False, 'run in debug mode - use subset of the data')

# pretrain parameters (for fine-tuning / resume training)
flags.DEFINE_string('pre_train_path', 'model/deep_heatmaps-40000', 'pretrained model path')
flags.DEFINE_bool('load_pretrain', False, "load pretrained weight?")
flags.DEFINE_bool('load_primary_only', False, 'fine-tuning using only primary network weights')

# input data parameters
flags.DEFINE_integer('image_size', 256, "image size")
flags.DEFINE_integer('c_dim', 3, "color channels")
flags.DEFINE_integer('num_landmarks', 68, "number of face landmarks")
flags.DEFINE_float('sigma', 6, "std for heatmap generation gaussian")
flags.DEFINE_integer('scale', 1, 'scale for image normalization 255/1/0')
# The help text deliberately contains no literal '%'. argparse-based flag parsers
# %-format help strings, and a bare '%' breaks `--help`. The default (0.25) is
# also a fraction, not a percentage.
flags.DEFINE_float('margin', 0.25, 'margin for face crops - fraction of bb size')
flags.DEFINE_string('bb_type', 'gt', "bb to use - 'gt':for ground truth / 'init':for face detector output")
flags.DEFINE_float('win_mult', 3.33335, 'gaussian filter size for approx maps: 2 * sigma * win_mult + 1')

# optimization parameters
flags.DEFINE_float('l_weight_primary', 1., 'primary loss weight')
flags.DEFINE_float('l_weight_fusion', 0., 'fusion loss weight')
flags.DEFINE_float('l_weight_upsample', 3., 'upsample loss weight')
flags.DEFINE_integer('train_iter', 60000, 'maximum training iterations')
flags.DEFINE_integer('batch_size', 6, "batch_size")
flags.DEFINE_float('learning_rate', 1e-4, "initial learning rate")
flags.DEFINE_bool('adam_optimizer', True, "use adam optimizer (if False momentum optimizer is used)")
flags.DEFINE_float('momentum', 0.95, "optimizer momentum (if adam_optimizer==False)")
# NOTE(review): with the defaults, train_iter (60000) < step (100000). If `step` is
# the decay interval, the learning rate is never decayed; confirm this is intended.
flags.DEFINE_integer('step', 100000, 'step for lr decay')
flags.DEFINE_float('gamma', 0.1, 'exponential base for lr decay')
flags.DEFINE_float('reg', 1e-5, 'scalar multiplier for weight decay (0 to disable)')
flags.DEFINE_string('weight_initializer', 'xavier', 'weight initializer: random_normal / xavier')
flags.DEFINE_float('weight_initializer_std', 0.01, 'std for random_normal weight initializer')
flags.DEFINE_float('bias_initializer', 0.0, 'constant value for bias initializer')

# augmentation parameters
flags.DEFINE_bool('augment_basic', True, "use basic augmentation?")
flags.DEFINE_bool('augment_texture', False, "use artistic texture augmentation?")
flags.DEFINE_float('p_texture', 0., 'probability of artistic texture augmentation')
flags.DEFINE_bool('augment_geom', False, "use artistic geometric augmentation?")
flags.DEFINE_float('p_geom', 0., 'probability of artistic geometric augmentation')

FLAGS = flags.FLAGS
# Create the top-level output directory at import time; main() creates the
# model/log/sample subdirectories inside it. os.makedirs (unlike os.mkdir) also
# creates missing parents, so a nested --output_dir such as runs/exp1 works.
# NOTE(review): this reads FLAGS before tf.app.run() has explicitly parsed argv,
# so it relies on the flags wrapper parsing lazily. Confirm this for the TF version in use.
if not os.path.exists(FLAGS.output_dir):
    os.makedirs(FLAGS.output_dir)
def _ensure_dir(path):
    """Create ``path`` and any missing parents unless it is already a directory.

    Tries the creation and inspects the failure instead of checking first, so a
    directory that appears between a check and the call does not raise. Any
    other failure (permissions, a regular file in the way, ...) is re-raised.
    """
    try:
        os.makedirs(path)
    except OSError:
        if not os.path.isdir(path):
            raise


def main(_):
    """Create the output directory layout from FLAGS, then build and train the model.

    Args:
        _: argv list passed in by ``tf.app.run``. It is unused because all
            configuration is read from ``FLAGS``.
    """
    # Output locations are interpreted relative to --output_dir.
    save_model_path = os.path.join(FLAGS.output_dir, FLAGS.save_model_path)
    save_sample_path = os.path.join(FLAGS.output_dir, FLAGS.save_sample_path)
    save_log_path = os.path.join(FLAGS.output_dir, FLAGS.save_log_path)

    # create directories (and missing parents) if they do not exist
    _ensure_dir(save_model_path)
    _ensure_dir(save_log_path)
    if not FLAGS.sample_to_log:
        # A sample directory is only needed when samples are not sent to the log.
        _ensure_dir(save_sample_path)

    # The default img_path starts with '~', which os.path / open() never expand
    # on their own. expanduser is a no-op for paths without a leading '~'.
    img_path = os.path.expanduser(FLAGS.img_path)

    model = DeepHeatmapsModel(
        mode='TRAIN', train_iter=FLAGS.train_iter, batch_size=FLAGS.batch_size, learning_rate=FLAGS.learning_rate,
        l_weight_primary=FLAGS.l_weight_primary, l_weight_fusion=FLAGS.l_weight_fusion,
        l_weight_upsample=FLAGS.l_weight_upsample, reg=FLAGS.reg, adam_optimizer=FLAGS.adam_optimizer,
        momentum=FLAGS.momentum, step=FLAGS.step, gamma=FLAGS.gamma,
        weight_initializer=FLAGS.weight_initializer, weight_initializer_std=FLAGS.weight_initializer_std,
        bias_initializer=FLAGS.bias_initializer, image_size=FLAGS.image_size, c_dim=FLAGS.c_dim,
        num_landmarks=FLAGS.num_landmarks, sigma=FLAGS.sigma, scale=FLAGS.scale, margin=FLAGS.margin,
        bb_type=FLAGS.bb_type, win_mult=FLAGS.win_mult, augment_basic=FLAGS.augment_basic,
        augment_texture=FLAGS.augment_texture, p_texture=FLAGS.p_texture, augment_geom=FLAGS.augment_geom,
        p_geom=FLAGS.p_geom, output_dir=FLAGS.output_dir, save_model_path=save_model_path,
        save_sample_path=save_sample_path, save_log_path=save_log_path, pre_train_path=FLAGS.pre_train_path,
        load_pretrain=FLAGS.load_pretrain, load_primary_only=FLAGS.load_primary_only,
        img_path=img_path, valid_data=FLAGS.valid_data, valid_size=FLAGS.valid_size,
        log_valid_every=FLAGS.log_valid_every, train_crop_dir=FLAGS.train_crop_dir, img_dir_ns=FLAGS.img_dir_ns,
        print_every=FLAGS.print_every, save_every=FLAGS.save_every, sample_every=FLAGS.sample_every,
        sample_grid=FLAGS.sample_grid, sample_to_log=FLAGS.sample_to_log, debug_data_size=FLAGS.debug_data_size,
        debug=FLAGS.debug, use_epoch_data=FLAGS.use_epoch_data, epoch_data_dir=FLAGS.epoch_data_dir)

    model.train()
if __name__ == '__main__':
    # Hand main to the TF app runner explicitly. Left implicit, tf.app.run would
    # look up this same function as sys.modules['__main__'].main.
    tf.app.run(main=main)