import json
import os.path as osp
import time

import torch
import numpy as np
from tqdm import tqdm

import torchvision.transforms as transforms
from torch.utils.data import DataLoader, DistributedSampler
import torch.optim as optim
import torch.optim.lr_scheduler as lr_scheduler
import torch.nn.functional as F

from external.landmark_detection.conf import *
from external.landmark_detection.lib.dataset import AlignmentDataset
from external.landmark_detection.lib.backbone import StackedHGNetV1
from external.landmark_detection.lib.loss import *
from external.landmark_detection.lib.metric import NME, FR_AUC
from external.landmark_detection.lib.utils import convert_secs2time, AverageMeter


def get_config(args):
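    """Build the run configuration from parsed CLI args via the ``Alignment``
    config class (star-imported from ``external.landmark_detection.conf``)."""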
    config = Alignment(args)
    return config


def get_dataset(config, tsv_file, image_dir, loader_type, is_train):
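    """Construct an ``AlignmentDataset`` from a TSV annotation file and an
    image directory; "alignment" is the only supported loader type."""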
    dataset = None
    if loader_type == "alignment":
        dataset = AlignmentDataset(
            tsv_file,
            image_dir,
            transforms.Compose([transforms.ToTensor()]),
            config.width,
            config.height,
            config.channels,
            config.means,
            config.scale,
            config.classes_num,
            config.crop_op,
            config.aug_prob,
            config.edge_info,
            config.flip_mapping,
            is_train,
            encoder_type=config.encoder_type
        )
    else:
        assert False, "unsupported loader_type: %s" % loader_type
    return dataset


def get_dataloader(config, data_type, world_rank=0, world_size=1):
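    """Build a train/val/test DataLoader. For distributed training
    (``world_size > 1``) a ``DistributedSampler`` is used and the global
    batch size is split evenly across ranks."""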
    loader = None
    if data_type == "train":
        dataset = get_dataset(
            config,
            config.train_tsv_file,
            config.train_pic_dir,
            config.loader_type,
            is_train=True)
        if world_size > 1:
            sampler = DistributedSampler(dataset, rank=world_rank, num_replicas=world_size, shuffle=True)
            loader = DataLoader(dataset, sampler=sampler, batch_size=config.batch_size // world_size,
                                num_workers=config.train_num_workers, pin_memory=True, drop_last=True)
        else:
            loader = DataLoader(dataset, batch_size=config.batch_size, shuffle=True,
                                num_workers=config.train_num_workers)
    elif data_type == "val":
        dataset = get_dataset(
            config,
            config.val_tsv_file,
            config.val_pic_dir,
            config.loader_type,
            is_train=False)
        loader = DataLoader(dataset, shuffle=False, batch_size=config.val_batch_size,
                            num_workers=config.val_num_workers)
    elif data_type == "test":
        dataset = get_dataset(
            config,
            config.test_tsv_file,
            config.test_pic_dir,
            config.loader_type,
            is_train=False)
        loader = DataLoader(dataset, shuffle=False, batch_size=config.test_batch_size,
                            num_workers=config.test_num_workers)
    else:
        assert False, "unknown data_type: %s" % data_type
    return loader


def get_optimizer(config, net):
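    """Create the optimizer named by ``config.optimizer`` ("sgd", "adam" or
    "rmsprop") over all network parameters."""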
    params = net.parameters()

    optimizer = None
    if config.optimizer == "sgd":
        optimizer = optim.SGD(
            params,
            lr=config.learn_rate,
            momentum=config.momentum,
            weight_decay=config.weight_decay,
            nesterov=config.nesterov)
    elif config.optimizer == "adam":
        optimizer = optim.Adam(
            params,
            lr=config.learn_rate)
    elif config.optimizer == "rmsprop":
        optimizer = optim.RMSprop(
            params,
            lr=config.learn_rate,
            momentum=config.momentum,
            alpha=config.alpha,
            eps=config.epsilon,
            weight_decay=config.weight_decay)
    else:
        assert False, "unknown optimizer: %s" % config.optimizer
    return optimizer


def get_scheduler(config, optimizer):
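    """Create the learning-rate scheduler; only ``MultiStepLR`` is supported."""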
    if config.scheduler == "MultiStepLR":
        scheduler = lr_scheduler.MultiStepLR(optimizer, milestones=config.milestones, gamma=config.gamma)
    else:
        assert False, "unknown scheduler: %s" % config.scheduler
    return scheduler


def get_net(config):
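    """Instantiate the backbone; only the stacked-hourglass ``StackedHGNetV1``
    is currently supported."""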
    net = None
    if config.net == "stackedHGnet_v1":
        net = StackedHGNetV1(config=config,
                             classes_num=config.classes_num,
                             edge_info=config.edge_info,
                             nstack=config.nstack,
                             add_coord=config.add_coord,
                             decoder_type=config.decoder_type)
    else:
        assert False, "unknown net: %s" % config.net
    return net


def get_criterions(config):
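    """Build one loss per output label as named in ``config.criterions``
    (AWingLoss, SmoothL1, L1, L2, STARLoss, or STARLoss_v2)."""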
    criterions = list()
    for k in range(config.label_num):
        if config.criterions[k] == "AWingLoss":
            criterion = AWingLoss()
        elif config.criterions[k] == "smoothl1":
            criterion = SmoothL1Loss()
        elif config.criterions[k] == "l1":
            criterion = F.l1_loss
        elif config.criterions[k] == "l2":
            criterion = F.mse_loss
        elif config.criterions[k] == "STARLoss":
            criterion = STARLoss(dist=config.star_dist, w=config.star_w)
        elif config.criterions[k] == "STARLoss_v2":
            criterion = STARLoss_v2(dist=config.star_dist, w=config.star_w)
        else:
            assert False, "unknown criterion: %s" % config.criterions[k]
        criterions.append(criterion)
    return criterions


def set_environment(config):
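    """Select the compute device and set global torch flags. Note that
    ``set_detect_anomaly(True)`` aids debugging but slows training."""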
    if config.device_id >= 0:
        assert torch.cuda.is_available() and torch.cuda.device_count() > config.device_id
        torch.cuda.empty_cache()
        config.device = torch.device("cuda", config.device_id)
        config.use_gpu = True
    else:
        config.device = torch.device("cpu")
        config.use_gpu = False

    torch.set_default_dtype(torch.float32)
    torch.set_default_tensor_type(torch.FloatTensor)
    torch.set_flush_denormal(True)
    torch.backends.cudnn.benchmark = True
    torch.autograd.set_detect_anomaly(True)


def forward(config, test_loader, net):
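    """Run inference over ``test_loader`` and accumulate per-label NME and
    FR/AUC metrics. Returns ``(None, metrics)``; the first slot is an unused
    placeholder for per-sample predictions."""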
    list_nmes = [[] for _ in range(config.label_num)]
    metric_nme = NME(nme_left_index=config.nme_left_index, nme_right_index=config.nme_right_index)
    metric_fr_auc = FR_AUC(data_definition=config.data_definition)

    output_pd = None

    net = net.float().to(config.device)
    net.eval()
    dataset_size = len(test_loader.dataset)
    batch_size = test_loader.batch_size
    if config.logger is not None:
        config.logger.info("Forward process, Dataset size: %d, Batch size: %d" % (dataset_size, batch_size))
    for i, sample in enumerate(tqdm(test_loader)):
        input = sample["data"].float().to(config.device, non_blocking=True)
        labels = list()
        if isinstance(sample["label"], list):
            for label in sample["label"]:
                label = label.float().to(config.device, non_blocking=True)
                labels.append(label)
        else:
            label = sample["label"].float().to(config.device, non_blocking=True)
            for k in range(label.shape[1]):
                labels.append(label[:, k])
        labels = config.nstack * labels

        with torch.no_grad():
            output, heatmap, landmarks = net(input)

        for k in range(config.label_num):
            if config.metrics[k] is not None:
                list_nmes[k] += metric_nme.test(output[k], labels[k])

    metrics = [[np.mean(nmes), ] + metric_fr_auc.test(nmes) for nmes in list_nmes]

    return output_pd, metrics


def compute_loss(config, criterions, output, labels, heatmap=None, landmarks=None):
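    """Compute per-label losses and their weighted sum. Coordinate losses
    (l1/l2/smoothl1/wing variants) apply to decoded outputs, while STAR losses
    operate on the predicted heatmaps; when ``config.use_AAM`` is set, labels
    are grouped in threes per stack, so label ``k`` maps to heatmap ``k // 3``."""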
    batch_weight = 1.0
    sum_loss = 0
    losses = list()
    for k in range(config.label_num):
        if config.criterions[k] in ['smoothl1', 'l1', 'l2', 'WingLoss', 'AWingLoss']:
            loss = criterions[k](output[k], labels[k])
        elif config.criterions[k] in ["STARLoss", "STARLoss_v2"]:
            _k = k // 3 if config.use_AAM else k
            loss = criterions[k](heatmap[_k], labels[k])
        else:
            raise NotImplementedError(config.criterions[k])
        loss = batch_weight * loss
        sum_loss += config.loss_weights[k] * loss
        loss = float(loss.data.cpu().item())
        losses.append(loss)
    return losses, sum_loss


def forward_backward(config, train_loader, net_module, net, net_ema, criterions, optimizer, epoch):
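    """Train for one epoch. ``net_module`` is typically a distributed wrapper
    around ``net``; when the two differ, the distributed sampler is reshuffled
    per epoch. If ``net_ema`` is given, it is updated after every step as an
    exponential moving average of ``net`` (see ``accumulate_net``)."""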
    train_model_time = AverageMeter()
    ave_losses = [0] * config.label_num

    net_module = net_module.float().to(config.device)
    net_module.train(True)
    dataset_size = len(train_loader.dataset)
    batch_size = config.batch_size
    if config.logger is not None:
        config.logger.info(config.note)
        config.logger.info("Forward Backward process, Dataset size: %d, Batch size: %d" % (dataset_size, batch_size))

    iter_num = len(train_loader)
    epoch_start_time = time.time()
    if net_module != net:
        train_loader.sampler.set_epoch(epoch)
    for iter, sample in enumerate(train_loader):
        iter_start_time = time.time()

        input = sample["data"].float().to(config.device, non_blocking=True)

        labels = list()
        if isinstance(sample["label"], list):
            for label in sample["label"]:
                label = label.float().to(config.device, non_blocking=True)
                labels.append(label)
        else:
            label = sample["label"].float().to(config.device, non_blocking=True)
            for k in range(label.shape[1]):
                labels.append(label[:, k])
        labels = config.nstack * labels

        output, heatmaps, landmarks = net_module(input)

        losses, sum_loss = compute_loss(config, criterions, output, labels, heatmaps, landmarks)
        ave_losses = list(map(sum, zip(ave_losses, losses)))

        optimizer.zero_grad()
        with torch.autograd.detect_anomaly():
            sum_loss.backward()
        optimizer.step()

        if net_ema is not None:
            accumulate_net(net_ema, net, 0.5 ** (config.batch_size / 10000.0))

        train_model_time.update(time.time() - iter_start_time)
        last_time = convert_secs2time(train_model_time.avg * (iter_num - iter - 1), True)
        if iter % config.display_iteration == 0 or iter + 1 == len(train_loader):
            if config.logger is not None:
                losses_str = ' Average Loss: {:.6f}'.format(sum(losses) / len(losses))
                for k, loss in enumerate(losses):
                    losses_str += ', L{}: {:.3f}'.format(k, loss)
                config.logger.info(
                    ' -->>[{:03d}/{:03d}][{:03d}/{:03d}]'.format(epoch, config.max_epoch, iter, iter_num)
                    + last_time + losses_str)

    epoch_end_time = time.time()
    epoch_total_time = epoch_end_time - epoch_start_time
    epoch_load_data_time = epoch_total_time - train_model_time.sum
    if config.logger is not None:
        config.logger.info("Train/Epoch: %d/%d, Average total time cost per iteration in this epoch: %.6f" % (
            epoch, config.max_epoch, epoch_total_time / iter_num))
        config.logger.info("Train/Epoch: %d/%d, Average loading data time cost per iteration in this epoch: %.6f" % (
            epoch, config.max_epoch, epoch_load_data_time / iter_num))
        config.logger.info("Train/Epoch: %d/%d, Average training model time cost per iteration in this epoch: %.6f" % (
            epoch, config.max_epoch, train_model_time.avg))

    ave_losses = [loss / iter_num for loss in ave_losses]
    if config.logger is not None:
        config.logger.info("Train/Epoch: %d/%d, Average Loss in this epoch: %.6f" % (
            epoch, config.max_epoch, sum(ave_losses) / len(ave_losses)))
    for k, ave_loss in enumerate(ave_losses):
        if config.logger is not None:
            config.logger.info("Train/Loss%03d in this epoch: %.6f" % (k, ave_loss))


def accumulate_net(model1, model2, decay):
    """
    Exponential moving average update: model1 = model1 * decay + model2 * (1 - decay),
    applied to all parameters and floating-point buffers; integer buffers
    (e.g. BatchNorm's num_batches_tracked) are copied over directly.
    """
    par1 = dict(model1.named_parameters())
    par2 = dict(model2.named_parameters())
    for k in par1.keys():
        par1[k].data.mul_(decay).add_(
            other=par2[k].data.to(par1[k].data.device),
            alpha=1 - decay)

    par1 = dict(model1.named_buffers())
    par2 = dict(model2.named_buffers())
    for k in par1.keys():
        if par1[k].data.is_floating_point():
            par1[k].data.mul_(decay).add_(
                other=par2[k].data.to(par1[k].data.device),
                alpha=1 - decay)
        else:
            par1[k].data = par2[k].data.to(par1[k].data.device)


def save_model(config, epoch, net, net_ema, optimizer, scheduler, pytorch_model_path):
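    """Serialize net/optimizer/scheduler state (plus the EMA weights when
    ``config.ema`` is set) to ``pytorch_model_path``."""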
    state = {
        "net": net.state_dict(),
        "optimizer": optimizer.state_dict(),
        "scheduler": scheduler.state_dict(),
        "epoch": epoch
    }
    if config.ema:
        state["net_ema"] = net_ema.state_dict()

    torch.save(state, pytorch_model_path)
    if config.logger is not None:
        config.logger.info("Epoch: %d/%d, model saved in this epoch" % (epoch, config.max_epoch))