# NOTE(review): stray page text ("Spaces: Running on Zero", duplicated) preceded
# the imports; converted to a comment so the module parses.
import json | |
import os.path as osp | |
import time | |
import torch | |
import numpy as np | |
from tqdm import tqdm | |
import torchvision.transforms as transforms | |
from torch.utils.data import DataLoader, DistributedSampler | |
import torch.optim as optim | |
import torch.optim.lr_scheduler as lr_scheduler | |
import torch.nn.functional as F | |
# private package | |
from external.landmark_detection.conf import * | |
from external.landmark_detection.lib.dataset import AlignmentDataset | |
from external.landmark_detection.lib.backbone import StackedHGNetV1 | |
from external.landmark_detection.lib.loss import * | |
from external.landmark_detection.lib.metric import NME, FR_AUC | |
from external.landmark_detection.lib.utils import convert_secs2time | |
from external.landmark_detection.lib.utils import AverageMeter | |
def get_config(args):
    """Build the run configuration from parsed CLI arguments.

    Args:
        args: argparse-style namespace; passed straight through to Alignment.

    Returns:
        An Alignment configuration object.
    """
    # NOTE(review): the original also read args.config_name into an unused
    # local and pre-initialized config = None; both were dead code.
    return Alignment(args)
def get_dataset(config, tsv_file, image_dir, loader_type, is_train):
    """Instantiate the dataset for the given loader type.

    Args:
        config: run configuration supplying geometry/augmentation settings.
        tsv_file: annotation tsv path.
        image_dir: root directory of the images.
        loader_type: currently only "alignment" is supported.
        is_train: enables training-time augmentation inside the dataset.

    Returns:
        An AlignmentDataset instance.

    Raises:
        ValueError: if loader_type is not supported.
    """
    if loader_type == "alignment":
        return AlignmentDataset(
            tsv_file,
            image_dir,
            transforms.Compose([transforms.ToTensor()]),
            config.width,
            config.height,
            config.channels,
            config.means,
            config.scale,
            config.classes_num,
            config.crop_op,
            config.aug_prob,
            config.edge_info,
            config.flip_mapping,
            is_train,
            encoder_type=config.encoder_type
        )
    # `assert False` is stripped under `python -O` and would silently return
    # None here; raise explicitly instead.
    raise ValueError("Unsupported loader_type: %s" % loader_type)
def get_dataloader(config, data_type, world_rank=0, world_size=1):
    """Build the DataLoader for one split ("train", "val" or "test").

    Args:
        config: run configuration (tsv paths, batch sizes, worker counts).
        data_type: one of "train", "val", "test".
        world_rank: rank of this process for distributed training.
        world_size: total process count; > 1 enables DistributedSampler.

    Returns:
        A torch DataLoader over the requested split.

    Raises:
        ValueError: if data_type is not one of the supported splits.
    """
    if data_type == "train":
        dataset = get_dataset(
            config,
            config.train_tsv_file,
            config.train_pic_dir,
            config.loader_type,
            is_train=True)
        if world_size > 1:
            # Distributed: each rank gets a disjoint shard, and the global
            # batch size is kept by dividing config.batch_size across ranks.
            sampler = DistributedSampler(dataset, rank=world_rank, num_replicas=world_size, shuffle=True)
            loader = DataLoader(dataset, sampler=sampler, batch_size=config.batch_size // world_size,
                                num_workers=config.train_num_workers, pin_memory=True, drop_last=True)
        else:
            loader = DataLoader(dataset, batch_size=config.batch_size, shuffle=True,
                                num_workers=config.train_num_workers)
    elif data_type == "val":
        dataset = get_dataset(
            config,
            config.val_tsv_file,
            config.val_pic_dir,
            config.loader_type,
            is_train=False)
        loader = DataLoader(dataset, shuffle=False, batch_size=config.val_batch_size,
                            num_workers=config.val_num_workers)
    elif data_type == "test":
        dataset = get_dataset(
            config,
            config.test_tsv_file,
            config.test_pic_dir,
            config.loader_type,
            is_train=False)
        loader = DataLoader(dataset, shuffle=False, batch_size=config.test_batch_size,
                            num_workers=config.test_num_workers)
    else:
        # `assert False` is stripped under `python -O` and would silently
        # return None here; raise explicitly instead.
        raise ValueError("Unsupported data_type: %s" % data_type)
    return loader
def get_optimizer(config, net):
    """Create the optimizer named by config.optimizer for net's parameters.

    Supported values: "sgd", "adam", "rmsprop".

    Args:
        config: run configuration holding learn_rate and per-optimizer options.
        net: module whose parameters are optimized.

    Returns:
        A torch.optim optimizer instance.

    Raises:
        ValueError: for an unknown config.optimizer.
    """
    params = net.parameters()
    if config.optimizer == "sgd":
        return optim.SGD(
            params,
            lr=config.learn_rate,
            momentum=config.momentum,
            weight_decay=config.weight_decay,
            nesterov=config.nesterov)
    if config.optimizer == "adam":
        return optim.Adam(
            params,
            lr=config.learn_rate)
    if config.optimizer == "rmsprop":
        return optim.RMSprop(
            params,
            lr=config.learn_rate,
            momentum=config.momentum,
            alpha=config.alpha,
            eps=config.epsilon,
            weight_decay=config.weight_decay)
    # `assert False` is stripped under `python -O` and would silently return
    # None here; raise explicitly instead.
    raise ValueError("Unsupported optimizer: %s" % config.optimizer)
def get_scheduler(config, optimizer):
    """Create the LR scheduler named by config.scheduler.

    Args:
        config: run configuration holding scheduler, milestones and gamma.
        optimizer: the optimizer whose learning rate is scheduled.

    Returns:
        A torch.optim.lr_scheduler scheduler instance.

    Raises:
        ValueError: for an unknown config.scheduler (the original `assert
        False` fell through to an unbound `return scheduler` → NameError,
        and was stripped entirely under `python -O`).
    """
    if config.scheduler == "MultiStepLR":
        return lr_scheduler.MultiStepLR(optimizer, milestones=config.milestones, gamma=config.gamma)
    raise ValueError("Unsupported scheduler: %s" % config.scheduler)
def get_net(config):
    """Instantiate the network named by config.net.

    Args:
        config: run configuration; also forwarded to the network constructor.

    Returns:
        A StackedHGNetV1 instance.

    Raises:
        ValueError: for an unknown config.net.
    """
    if config.net == "stackedHGnet_v1":
        return StackedHGNetV1(config=config,
                              classes_num=config.classes_num,
                              edge_info=config.edge_info,
                              nstack=config.nstack,
                              add_coord=config.add_coord,
                              decoder_type=config.decoder_type)
    # `assert False` is stripped under `python -O` and would silently return
    # None here; raise explicitly instead.
    raise ValueError("Unsupported net: %s" % config.net)
def get_criterions(config):
    """Build one loss callable per output label from config.criterions.

    Args:
        config: run configuration; config.criterions[k] names the loss for
            label k, config.label_num is the number of labels, and the STAR
            losses read config.star_dist / config.star_w.

    Returns:
        A list of loss callables, one per label.

    Raises:
        ValueError: for an unknown criterion name.
    """
    criterions = list()
    for k in range(config.label_num):
        name = config.criterions[k]
        if name == "AWingLoss":
            criterion = AWingLoss()
        elif name == "smoothl1":
            criterion = SmoothL1Loss()
        elif name == "l1":
            criterion = F.l1_loss
        elif name == 'l2':
            criterion = F.mse_loss
        elif name == "STARLoss":
            criterion = STARLoss(dist=config.star_dist, w=config.star_w)
        elif name == "STARLoss_v2":
            criterion = STARLoss_v2(dist=config.star_dist, w=config.star_w)
        else:
            # `assert False` is stripped under `python -O` and would append a
            # stale/undefined criterion; raise explicitly instead.
            raise ValueError("Unsupported criterion: %s" % name)
        criterions.append(criterion)
    return criterions
def set_environment(config):
    """Select the compute device and set global torch flags.

    Side effects: writes config.device / config.use_gpu, and mutates global
    torch state (default dtype/tensor type, denormal flushing, cudnn
    benchmarking, autograd anomaly detection).
    """
    if config.device_id >= 0:
        # Fail fast if CUDA or the requested GPU index is unavailable.
        # NOTE(review): assert is stripped under `python -O`; a raise would be safer.
        assert torch.cuda.is_available() and torch.cuda.device_count() > config.device_id
        torch.cuda.empty_cache()
        config.device = torch.device("cuda", config.device_id)
        config.use_gpu = True
    else:
        # device_id < 0 means CPU-only execution.
        config.device = torch.device("cpu")
        config.use_gpu = False
    torch.set_default_dtype(torch.float32)
    # NOTE(review): torch.set_default_tensor_type is deprecated in recent
    # torch releases in favor of set_default_dtype/set_default_device — confirm
    # the pinned torch version before changing.
    torch.set_default_tensor_type(torch.FloatTensor)
    torch.set_flush_denormal(True)  # ignore extremely small value
    torch.backends.cudnn.benchmark = True  # This flag allows you to enable the inbuilt cudnn auto-tuner to find the best algorithm to use for your hardware.
    # NOTE(review): global anomaly detection adds per-op overhead on every
    # backward pass; usually reserved for debugging sessions.
    torch.autograd.set_detect_anomaly(True)
def forward(config, test_loader, net):
    """Evaluate net over test_loader and compute per-label NME-based metrics.

    Args:
        config: run configuration (device, label_num, nstack, metrics, ...).
        test_loader: DataLoader yielding dicts with "data" and "label".
        net: network to evaluate; moved to config.device and set to eval mode.

    Returns:
        (output_pd, metrics): output_pd is always None (kept for interface
        compatibility); metrics is one list per label produced from the
        collected NMEs (mean NME plus metric_fr_auc.test's values —
        presumably [FR, AUC]; verify against the FR_AUC implementation).
    """
    # ave_metrics = [[0, 0] for i in range(config.label_num)]
    list_nmes = [[] for i in range(config.label_num)]
    metric_nme = NME(nme_left_index=config.nme_left_index, nme_right_index=config.nme_right_index)
    metric_fr_auc = FR_AUC(data_definition=config.data_definition)
    # NOTE(review): output_pd is never assigned below; it is always returned as None.
    output_pd = None
    net = net.float().to(config.device)
    net.eval()
    dataset_size = len(test_loader.dataset)
    batch_size = test_loader.batch_size
    if config.logger is not None:
        config.logger.info("Forward process, Dataset size: %d, Batch size: %d" % (dataset_size, batch_size))
    for i, sample in enumerate(tqdm(test_loader)):
        # NOTE(review): `input` shadows the builtin.
        input = sample["data"].float().to(config.device, non_blocking=True)
        labels = list()
        if isinstance(sample["label"], list):
            # Labels already come as a list of per-label tensors.
            for label in sample["label"]:
                label = label.float().to(config.device, non_blocking=True)
                labels.append(label)
        else:
            # Single stacked tensor: split along dim 1 into per-label tensors.
            label = sample["label"].float().to(config.device, non_blocking=True)
            for k in range(label.shape[1]):
                labels.append(label[:, k])
        # Repeat the label list nstack times so indexing by k below lines up
        # with the per-stack outputs.
        labels = config.nstack * labels
        with torch.no_grad():
            output, heatmap, landmarks = net(input)
        # metrics
        for k in range(config.label_num):
            if config.metrics[k] is not None:
                # Accumulate per-sample NMEs for label k across the dataset.
                list_nmes[k] += metric_nme.test(output[k], labels[k])
    metrics = [[np.mean(nmes), ] + metric_fr_auc.test(nmes) for nmes in list_nmes]
    return output_pd, metrics
def compute_loss(config, criterions, output, labels, heatmap=None, landmarks=None):
    """Compute per-label losses and their weighted sum.

    Args:
        config: run configuration (label_num, criterions, loss_weights, use_AAM).
        criterions: per-label loss callables (see get_criterions).
        output: per-label network predictions.
        labels: per-label targets.
        heatmap: per-stack heatmaps, consumed by the STAR losses.
        landmarks: unused here; kept for interface compatibility.

    Returns:
        (losses, sum_loss): losses is a list of python floats, one per label;
        sum_loss is the weighted scalar tensor used for backprop.

    Raises:
        NotImplementedError: for a criterion name this function does not handle.
    """
    batch_weight = 1.0
    sum_loss = 0
    losses = list()
    for k in range(config.label_num):
        if config.criterions[k] in ['smoothl1', 'l1', 'l2', 'WingLoss', 'AWingLoss']:
            loss = criterions[k](output[k], labels[k])
        elif config.criterions[k] in ["STARLoss", "STARLoss_v2"]:
            # With AAM enabled, labels come in groups of 3 per stack — map
            # label index k back to its stack's heatmap.
            _k = int(k / 3) if config.use_AAM else k
            loss = criterions[k](heatmap[_k], labels[k])
        else:
            # BUG FIX: the original `assert NotImplementedError` only asserted
            # the truthiness of the exception class (always True), so unknown
            # criterions silently fell through with `loss` unset or stale.
            raise NotImplementedError("Unsupported criterion: %s" % config.criterions[k])
        loss = batch_weight * loss
        sum_loss += config.loss_weights[k] * loss
        loss = float(loss.data.cpu().item())
        losses.append(loss)
    return losses, sum_loss
def forward_backward(config, train_loader, net_module, net, net_ema, criterions, optimizer, epoch):
    """Run one training epoch: forward, loss, backward, optimizer step, logging.

    Args:
        config: run configuration (device, batch_size, nstack, logging, ...).
        train_loader: DataLoader over the training split.
        net_module: the module actually trained (e.g. a distributed wrapper,
            or the net itself in single-process training).
        net: the underlying network; source weights for the EMA update.
        net_ema: optional EMA copy of the net, updated after every step.
        criterions: per-label loss callables (see get_criterions).
        optimizer: optimizer stepping net_module's parameters.
        epoch: current epoch index, used for logging and the sampler seed.
    """
    train_model_time = AverageMeter()
    ave_losses = [0] * config.label_num
    net_module = net_module.float().to(config.device)
    net_module.train(True)
    dataset_size = len(train_loader.dataset)
    batch_size = config.batch_size # train_loader.batch_size
    # NOTE(review): batch_num is computed but never used below.
    batch_num = max(dataset_size / max(batch_size, 1), 1)
    if config.logger is not None:
        config.logger.info(config.note)
        config.logger.info("Forward Backward process, Dataset size: %d, Batch size: %d" % (dataset_size, batch_size))
    iter_num = len(train_loader)
    epoch_start_time = time.time()
    # When net_module wraps net (distributed training), the sampler must be
    # reseeded every epoch so each epoch shuffles differently.
    if net_module != net:
        train_loader.sampler.set_epoch(epoch)
    # NOTE(review): `iter` and `input` shadow builtins throughout this loop.
    for iter, sample in enumerate(train_loader):
        iter_start_time = time.time()
        # input
        input = sample["data"].float().to(config.device, non_blocking=True)
        # labels
        labels = list()
        if isinstance(sample["label"], list):
            # Labels already come as a list of per-label tensors.
            for label in sample["label"]:
                label = label.float().to(config.device, non_blocking=True)
                labels.append(label)
        else:
            # Single stacked tensor: split along dim 1 into per-label tensors.
            label = sample["label"].float().to(config.device, non_blocking=True)
            for k in range(label.shape[1]):
                labels.append(label[:, k])
        # Repeat the label list nstack times so it lines up with the
        # per-stack outputs consumed by compute_loss.
        labels = config.nstack * labels
        # forward
        output, heatmaps, landmarks = net_module(input)
        # loss
        losses, sum_loss = compute_loss(config, criterions, output, labels, heatmaps, landmarks)
        # Element-wise accumulate per-label losses for the epoch average.
        ave_losses = list(map(sum, zip(ave_losses, losses)))
        # backward
        optimizer.zero_grad()
        # NOTE(review): anomaly detection on every backward adds significant
        # overhead; usually reserved for debugging sessions.
        with torch.autograd.detect_anomaly():
            sum_loss.backward()
        # torch.nn.utils.clip_grad_norm_(net_module.parameters(), 128.0)
        optimizer.step()
        if net_ema is not None:
            # EMA decay depends on batch size (0.5 per 10k samples) —
            # presumably to keep the averaging horizon constant in samples.
            accumulate_net(net_ema, net, 0.5 ** (config.batch_size / 10000.0))
            # accumulate_net(net_ema, net, 0.5 ** (8 / 10000.0))
        # output
        train_model_time.update(time.time() - iter_start_time)
        # Estimated remaining time for this epoch, formatted by convert_secs2time.
        last_time = convert_secs2time(train_model_time.avg * (iter_num - iter - 1), True)
        if iter % config.display_iteration == 0 or iter + 1 == len(train_loader):
            if config.logger is not None:
                losses_str = ' Average Loss: {:.6f}'.format(sum(losses) / len(losses))
                for k, loss in enumerate(losses):
                    losses_str += ', L{}: {:.3f}'.format(k, loss)
                config.logger.info(
                    ' -->>[{:03d}/{:03d}][{:03d}/{:03d}]'.format(epoch, config.max_epoch, iter, iter_num) \
                    + last_time + losses_str)
    epoch_end_time = time.time()
    epoch_total_time = epoch_end_time - epoch_start_time
    # Time not spent in the train step is attributed to data loading.
    epoch_load_data_time = epoch_total_time - train_model_time.sum
    if config.logger is not None:
        config.logger.info("Train/Epoch: %d/%d, Average total time cost per iteration in this epoch: %.6f" % (
            epoch, config.max_epoch, epoch_total_time / iter_num))
        config.logger.info("Train/Epoch: %d/%d, Average loading data time cost per iteration in this epoch: %.6f" % (
            epoch, config.max_epoch, epoch_load_data_time / iter_num))
        config.logger.info("Train/Epoch: %d/%d, Average training model time cost per iteration in this epoch: %.6f" % (
            epoch, config.max_epoch, train_model_time.avg))
    # Convert summed per-label losses into per-iteration averages.
    ave_losses = [loss / iter_num for loss in ave_losses]
    if config.logger is not None:
        config.logger.info("Train/Epoch: %d/%d, Average Loss in this epoch: %.6f" % (
            epoch, config.max_epoch, sum(ave_losses) / len(ave_losses)))
        for k, ave_loss in enumerate(ave_losses):
            if config.logger is not None:
                config.logger.info("Train/Loss%03d in this epoch: %.6f" % (k, ave_loss))
def accumulate_net(model1, model2, decay):
    """In-place EMA update: model1 = decay * model1 + (1 - decay) * model2.

    Parameters are always blended; buffers are blended only when they are
    floating point, otherwise they are copied over from model2 (on model1's
    device in both cases).
    """
    source_params = dict(model2.named_parameters())
    for name, target in dict(model1.named_parameters()).items():
        incoming = source_params[name].data.to(target.data.device)
        target.data.mul_(decay).add_(other=incoming, alpha=1 - decay)
    source_bufs = dict(model2.named_buffers())
    for name, target in dict(model1.named_buffers()).items():
        incoming = source_bufs[name].data.to(target.data.device)
        if target.data.is_floating_point():
            target.data.mul_(decay).add_(other=incoming, alpha=1 - decay)
        else:
            # Integer buffers (e.g. counters) cannot be blended; take model2's.
            target.data = incoming
def save_model(config, epoch, net, net_ema, optimizer, scheduler, pytorch_model_path):
    """Write a training checkpoint to pytorch_model_path via torch.save.

    The checkpoint holds the net, optimizer and scheduler state dicts plus the
    epoch index; when config.ema is set, the EMA net's state dict is included
    under "net_ema". Logs a confirmation line if config.logger is set.
    """
    # save pytorch model
    checkpoint = {
        "net": net.state_dict(),
        "optimizer": optimizer.state_dict(),
        "scheduler": scheduler.state_dict(),
        "epoch": epoch,
    }
    if config.ema:
        checkpoint["net_ema"] = net_ema.state_dict()
    torch.save(checkpoint, pytorch_model_path)
    if config.logger is not None:
        config.logger.info("Epoch: %d/%d, model saved in this epoch" % (epoch, config.max_epoch))