import json
import os.path as osp
import time

import torch
import numpy as np
from tqdm import tqdm

import torchvision.transforms as transforms
from torch.utils.data import DataLoader, DistributedSampler
import torch.optim as optim
import torch.optim.lr_scheduler as lr_scheduler
import torch.nn.functional as F

# private package
from external.landmark_detection.conf import *
from external.landmark_detection.lib.dataset import AlignmentDataset
from external.landmark_detection.lib.backbone import StackedHGNetV1
from external.landmark_detection.lib.loss import *
from external.landmark_detection.lib.metric import NME, FR_AUC
from external.landmark_detection.lib.utils import convert_secs2time
from external.landmark_detection.lib.utils import AverageMeter


def get_config(args):
    config = None
    config_name = args.config_name
    config = Alignment(args)
    return config


def get_dataset(config, tsv_file, image_dir, loader_type, is_train):
    dataset = None
    if loader_type == "alignment":
        dataset = AlignmentDataset(
            tsv_file,
            image_dir,
            transforms.Compose([transforms.ToTensor()]),
            config.width,
            config.height,
            config.channels,
            config.means,
            config.scale,
            config.classes_num,
            config.crop_op,
            config.aug_prob,
            config.edge_info,
            config.flip_mapping,
            is_train,
            encoder_type=config.encoder_type
        )
    else:
        assert False
    return dataset


def get_dataloader(config, data_type, world_rank=0, world_size=1):
    loader = None
    if data_type == "train":
        dataset = get_dataset(
            config,
            config.train_tsv_file,
            config.train_pic_dir,
            config.loader_type,
            is_train=True)
        if world_size > 1:
            sampler = DistributedSampler(
                dataset, rank=world_rank, num_replicas=world_size, shuffle=True)
            loader = DataLoader(
                dataset,
                sampler=sampler,
                batch_size=config.batch_size // world_size,
                num_workers=config.train_num_workers,
                pin_memory=True,
                drop_last=True)
        else:
            loader = DataLoader(
                dataset,
                batch_size=config.batch_size,
                shuffle=True,
                num_workers=config.train_num_workers)
    elif data_type == "val":
        dataset = get_dataset(
            config,
            config.val_tsv_file,
            config.val_pic_dir,
            config.loader_type,
            is_train=False)
        loader = DataLoader(
            dataset,
            shuffle=False,
            batch_size=config.val_batch_size,
            num_workers=config.val_num_workers)
    elif data_type == "test":
        dataset = get_dataset(
            config,
            config.test_tsv_file,
            config.test_pic_dir,
            config.loader_type,
            is_train=False)
        loader = DataLoader(
            dataset,
            shuffle=False,
            batch_size=config.test_batch_size,
            num_workers=config.test_num_workers)
    else:
        assert False
    return loader


def get_optimizer(config, net):
    params = net.parameters()
    optimizer = None
    if config.optimizer == "sgd":
        optimizer = optim.SGD(
            params,
            lr=config.learn_rate,
            momentum=config.momentum,
            weight_decay=config.weight_decay,
            nesterov=config.nesterov)
    elif config.optimizer == "adam":
        optimizer = optim.Adam(
            params,
            lr=config.learn_rate)
    elif config.optimizer == "rmsprop":
        optimizer = optim.RMSprop(
            params,
            lr=config.learn_rate,
            momentum=config.momentum,
            alpha=config.alpha,
            eps=config.epsilon,
            weight_decay=config.weight_decay)
    else:
        assert False
    return optimizer


def get_scheduler(config, optimizer):
    if config.scheduler == "MultiStepLR":
        scheduler = lr_scheduler.MultiStepLR(
            optimizer, milestones=config.milestones, gamma=config.gamma)
    else:
        assert False
    return scheduler


def get_net(config):
    net = None
    if config.net == "stackedHGnet_v1":
        net = StackedHGNetV1(config=config,
                             classes_num=config.classes_num,
                             edge_info=config.edge_info,
                             nstack=config.nstack,
                             add_coord=config.add_coord,
                             decoder_type=config.decoder_type)
    else:
        assert False
    return net
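
# ---------------------------------------------------------------------------
# Illustrative only (an assumption about the launch script, not part of this
# module): under torch.distributed, each rank would build its own train
# loader, with config.batch_size acting as the global batch size that
# get_dataloader() splits across ranks.
# ---------------------------------------------------------------------------
def _example_distributed_loader(config):
    """Sketch: per-rank train loader; assumes the process group is already
    initialised by the launcher."""
    import torch.distributed as dist
    return get_dataloader(config, "train",
                          world_rank=dist.get_rank(),
                          world_size=dist.get_world_size())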
def get_criterions(config):
    criterions = list()
    for k in range(config.label_num):
        if config.criterions[k] == "AWingLoss":
            criterion = AWingLoss()
        elif config.criterions[k] == "smoothl1":
            criterion = SmoothL1Loss()
        elif config.criterions[k] == "l1":
            criterion = F.l1_loss
        elif config.criterions[k] == 'l2':
            criterion = F.mse_loss
        elif config.criterions[k] == "STARLoss":
            criterion = STARLoss(dist=config.star_dist, w=config.star_w)
        elif config.criterions[k] == "STARLoss_v2":
            criterion = STARLoss_v2(dist=config.star_dist, w=config.star_w)
        else:
            assert False
        criterions.append(criterion)
    return criterions


def set_environment(config):
    if config.device_id >= 0:
        assert torch.cuda.is_available() and torch.cuda.device_count() > config.device_id
        torch.cuda.empty_cache()
        config.device = torch.device("cuda", config.device_id)
        config.use_gpu = True
    else:
        config.device = torch.device("cpu")
        config.use_gpu = False

    torch.set_default_dtype(torch.float32)
    torch.set_default_tensor_type(torch.FloatTensor)
    torch.set_flush_denormal(True)  # flush denormal (extremely small) values to zero
    torch.backends.cudnn.benchmark = True  # let cudnn auto-tune the best convolution algorithms for this hardware
    torch.autograd.set_detect_anomaly(True)


def forward(config, test_loader, net):
    # ave_metrics = [[0, 0] for i in range(config.label_num)]
    list_nmes = [[] for i in range(config.label_num)]
    metric_nme = NME(nme_left_index=config.nme_left_index, nme_right_index=config.nme_right_index)
    metric_fr_auc = FR_AUC(data_definition=config.data_definition)

    output_pd = None

    net = net.float().to(config.device)
    net.eval()
    dataset_size = len(test_loader.dataset)
    batch_size = test_loader.batch_size
    if config.logger is not None:
        config.logger.info("Forward process, Dataset size: %d, Batch size: %d" % (dataset_size, batch_size))

    for i, sample in enumerate(tqdm(test_loader)):
        input = sample["data"].float().to(config.device, non_blocking=True)

        labels = list()
        if isinstance(sample["label"], list):
            for label in sample["label"]:
                label = label.float().to(config.device, non_blocking=True)
                labels.append(label)
        else:
            label = sample["label"].float().to(config.device, non_blocking=True)
            for k in range(label.shape[1]):
                labels.append(label[:, k])
        labels = config.nstack * labels

        with torch.no_grad():
            output, heatmap, landmarks = net(input)

        # metrics
        for k in range(config.label_num):
            if config.metrics[k] is not None:
                list_nmes[k] += metric_nme.test(output[k], labels[k])

    metrics = [[np.mean(nmes), ] + metric_fr_auc.test(nmes) for nmes in list_nmes]

    return output_pd, metrics


def compute_loss(config, criterions, output, labels, heatmap=None, landmarks=None):
    batch_weight = 1.0
    sum_loss = 0
    losses = list()
    for k in range(config.label_num):
        if config.criterions[k] in ['smoothl1', 'l1', 'l2', 'WingLoss', 'AWingLoss']:
            loss = criterions[k](output[k], labels[k])
        elif config.criterions[k] in ["STARLoss", "STARLoss_v2"]:
            _k = int(k / 3) if config.use_AAM else k
            loss = criterions[k](heatmap[_k], labels[k])
        else:
            raise NotImplementedError
        loss = batch_weight * loss
        sum_loss += config.loss_weights[k] * loss
        loss = float(loss.data.cpu().item())
        losses.append(loss)
    return losses, sum_loss
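
# ---------------------------------------------------------------------------
# Sketch of how the evaluation pass above is typically consumed (an
# illustration, not the repository's actual test script). It assumes that
# FR_AUC.test() returns [FR, AUC], so each entry of `metrics` is a
# [NME, FR, AUC] triple, and that the last supervised output is the
# landmark head.
# ---------------------------------------------------------------------------
def _example_evaluate(config, net):
    val_loader = get_dataloader(config, "val")
    _, metrics = forward(config, val_loader, net)
    nme, fr, auc = metrics[-1]  # assumption: last output is the landmark head
    return nme, fr, auc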
def forward_backward(config, train_loader, net_module, net, net_ema, criterions, optimizer, epoch):
    train_model_time = AverageMeter()
    ave_losses = [0] * config.label_num

    net_module = net_module.float().to(config.device)
    net_module.train(True)
    dataset_size = len(train_loader.dataset)
    batch_size = config.batch_size  # train_loader.batch_size
    batch_num = max(dataset_size / max(batch_size, 1), 1)
    if config.logger is not None:
        config.logger.info(config.note)
        config.logger.info("Forward Backward process, Dataset size: %d, Batch size: %d" % (dataset_size, batch_size))

    iter_num = len(train_loader)
    epoch_start_time = time.time()
    if net_module != net:
        train_loader.sampler.set_epoch(epoch)
    for iter, sample in enumerate(train_loader):
        iter_start_time = time.time()

        # input
        input = sample["data"].float().to(config.device, non_blocking=True)

        # labels
        labels = list()
        if isinstance(sample["label"], list):
            for label in sample["label"]:
                label = label.float().to(config.device, non_blocking=True)
                labels.append(label)
        else:
            label = sample["label"].float().to(config.device, non_blocking=True)
            for k in range(label.shape[1]):
                labels.append(label[:, k])
        labels = config.nstack * labels

        # forward
        output, heatmaps, landmarks = net_module(input)

        # loss
        losses, sum_loss = compute_loss(config, criterions, output, labels, heatmaps, landmarks)
        ave_losses = list(map(sum, zip(ave_losses, losses)))

        # backward
        optimizer.zero_grad()
        with torch.autograd.detect_anomaly():
            sum_loss.backward()
        # torch.nn.utils.clip_grad_norm_(net_module.parameters(), 128.0)
        optimizer.step()

        if net_ema is not None:
            accumulate_net(net_ema, net, 0.5 ** (config.batch_size / 10000.0))
            # accumulate_net(net_ema, net, 0.5 ** (8 / 10000.0))

        # output
        train_model_time.update(time.time() - iter_start_time)
        last_time = convert_secs2time(train_model_time.avg * (iter_num - iter - 1), True)
        if iter % config.display_iteration == 0 or iter + 1 == len(train_loader):
            if config.logger is not None:
                losses_str = ' Average Loss: {:.6f}'.format(sum(losses) / len(losses))
                for k, loss in enumerate(losses):
                    losses_str += ', L{}: {:.3f}'.format(k, loss)
                config.logger.info(
                    ' -->>[{:03d}/{:03d}][{:03d}/{:03d}]'.format(epoch, config.max_epoch, iter, iter_num)
                    + last_time + losses_str)

    epoch_end_time = time.time()
    epoch_total_time = epoch_end_time - epoch_start_time
    epoch_load_data_time = epoch_total_time - train_model_time.sum
    if config.logger is not None:
        config.logger.info("Train/Epoch: %d/%d, Average total time cost per iteration in this epoch: %.6f" % (
            epoch, config.max_epoch, epoch_total_time / iter_num))
        config.logger.info("Train/Epoch: %d/%d, Average loading data time cost per iteration in this epoch: %.6f" % (
            epoch, config.max_epoch, epoch_load_data_time / iter_num))
        config.logger.info("Train/Epoch: %d/%d, Average training model time cost per iteration in this epoch: %.6f" % (
            epoch, config.max_epoch, train_model_time.avg))

    ave_losses = [loss / iter_num for loss in ave_losses]
    if config.logger is not None:
        config.logger.info("Train/Epoch: %d/%d, Average Loss in this epoch: %.6f" % (
            epoch, config.max_epoch, sum(ave_losses) / len(ave_losses)))
    for k, ave_loss in enumerate(ave_losses):
        if config.logger is not None:
            config.logger.info("Train/Loss%03d in this epoch: %.6f" % (k, ave_loss))


def accumulate_net(model1, model2, decay):
    """
    operation: model1 = model1 * decay + model2 * (1 - decay)
    """
    par1 = dict(model1.named_parameters())
    par2 = dict(model2.named_parameters())
    for k in par1.keys():
        par1[k].data.mul_(decay).add_(
            other=par2[k].data.to(par1[k].data.device),
            alpha=1 - decay)

    par1 = dict(model1.named_buffers())
    par2 = dict(model2.named_buffers())
    for k in par1.keys():
        if par1[k].data.is_floating_point():
            par1[k].data.mul_(decay).add_(
                other=par2[k].data.to(par1[k].data.device),
                alpha=1 - decay)
        else:
            par1[k].data = par2[k].data.to(par1[k].data.device)
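
# ---------------------------------------------------------------------------
# Minimal sanity sketch for accumulate_net (illustrative only): with two
# one-parameter modules, a single EMA step with decay d leaves
# p1 = d * p1 + (1 - d) * p2.
# ---------------------------------------------------------------------------
def _example_ema_step():
    m1 = torch.nn.Linear(1, 1, bias=False)
    m2 = torch.nn.Linear(1, 1, bias=False)
    with torch.no_grad():
        m1.weight.fill_(1.0)
        m2.weight.fill_(0.0)
    accumulate_net(m1, m2, decay=0.9)
    return float(m1.weight)  # 0.9 = 0.9 * 1.0 + 0.1 * 0.0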
def save_model(config, epoch, net, net_ema, optimizer, scheduler, pytorch_model_path):
    # save pytorch model
    state = {
        "net": net.state_dict(),
        "optimizer": optimizer.state_dict(),
        "scheduler": scheduler.state_dict(),
        "epoch": epoch
    }
    if config.ema:
        state["net_ema"] = net_ema.state_dict()

    torch.save(state, pytorch_model_path)
    if config.logger is not None:
        config.logger.info("Epoch: %d/%d, model saved in this epoch" % (epoch, config.max_epoch))
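

# ---------------------------------------------------------------------------
# Illustrative wiring of the helpers above (a sketch, not the repository's
# actual entry point). `args` is assumed to carry the fields Alignment
# expects (e.g. config_name, device_id), and `config.model_dir` is assumed
# here purely to form a checkpoint path.
# ---------------------------------------------------------------------------
def _example_train(args):
    config = get_config(args)
    set_environment(config)

    net = get_net(config).float().to(config.device)
    net_ema = None
    if config.ema:
        net_ema = get_net(config).float().to(config.device)
        accumulate_net(net_ema, net, 0)  # decay 0 copies net into the EMA model

    optimizer = get_optimizer(config, net)
    scheduler = get_scheduler(config, optimizer)
    criterions = get_criterions(config)

    train_loader = get_dataloader(config, "train")
    val_loader = get_dataloader(config, "val")

    for epoch in range(config.max_epoch):
        # single-process case: the same module serves as net_module and net
        forward_backward(config, train_loader, net, net, net_ema,
                         criterions, optimizer, epoch)
        scheduler.step()
        _, metrics = forward(config, val_loader, net_ema if net_ema is not None else net)
        save_model(config, epoch, net, net_ema, optimizer, scheduler,
                   osp.join(config.model_dir, "model_epoch_%d.pkl" % epoch))  # model_dir is an assumption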