Spaces:
Running
Running
File size: 7,484 Bytes
e3641b1 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 |
import os.path as osp
import torch
import torch.nn as nn
from vit_models.losses import JointsMSELoss
from vit_models.optimizer import LayerDecayOptimizer
from torch.nn.parallel import DataParallel, DistributedDataParallel
from torch.nn.utils import clip_grad_norm_
from torch.optim import AdamW
from torch.optim.lr_scheduler import LambdaLR, MultiStepLR
from torch.utils.data import DataLoader, Dataset
from torch.utils.data.distributed import DistributedSampler
from torch.cuda.amp import autocast, GradScaler
from tqdm import tqdm
from time import time
from vit_utils.dist_util import get_dist_info, init_dist
from vit_utils.logging import get_root_logger
@torch.no_grad()
def valid_model(model: nn.Module, dataloaders: DataLoader, criterion: nn.Module, cfg: dict) -> None:
total_loss = 0
total_metric = 0
model.eval()
for dataloader in dataloaders:
for batch_idx, batch in enumerate(dataloader):
images, targets, target_weights, __ = batch
images = images.to('cuda')
targets = targets.to('cuda')
target_weights = target_weights.to('cuda')
outputs = model(images)
loss = criterion(outputs, targets, target_weights)
total_loss += loss.item()
avg_loss = total_loss/(len(dataloader)*len(dataloaders))
return avg_loss
def train_model(model: nn.Module, datasets_train: Dataset, datasets_valid: Dataset, cfg: dict, distributed: bool, validate: bool, timestamp: str, meta: dict) -> None:
logger = get_root_logger()
# Prepare data loaders
datasets_train = datasets_train if isinstance(datasets_train, (list, tuple)) else [datasets_train]
datasets_valid = datasets_valid if isinstance(datasets_valid, (list, tuple)) else [datasets_valid]
if distributed:
samplers_train = [DistributedSampler(ds, num_replicas=len(cfg.gpu_ids), rank=torch.cuda.current_device(), shuffle=True, drop_last=False) for ds in datasets_train]
samplers_valid = [DistributedSampler(ds, num_replicas=len(cfg.gpu_ids), rank=torch.cuda.current_device(), shuffle=False, drop_last=False) for ds in datasets_valid]
else:
samplers_train = [None for ds in datasets_train]
samplers_valid = [None for ds in datasets_valid]
dataloaders_train = [DataLoader(ds, batch_size=cfg.data['samples_per_gpu'], shuffle=True, sampler=sampler, num_workers=cfg.data['workers_per_gpu'], pin_memory=False) for ds, sampler in zip(datasets_train, samplers_train)]
dataloaders_valid = [DataLoader(ds, batch_size=cfg.data['samples_per_gpu'], shuffle=False, sampler=sampler, num_workers=cfg.data['workers_per_gpu'], pin_memory=False) for ds, sampler in zip(datasets_valid, samplers_valid)]
# put model on gpus
if distributed:
find_unused_parameters = cfg.get('find_unused_parameters', False)
# Sets the `find_unused_parameters` parameter in
# torch.nn.parallel.DistributedDataParallel
model = DistributedDataParallel(
module=model,
device_ids=[torch.cuda.current_device()],
broadcast_buffers=False,
find_unused_parameters=find_unused_parameters)
else:
model = DataParallel(model, device_ids=cfg.gpu_ids)
# Loss function
criterion = JointsMSELoss(use_target_weight=cfg.model['keypoint_head']['loss_keypoint']['use_target_weight'])
# Optimizer
optimizer = AdamW(model.parameters(), lr=cfg.optimizer['lr'], betas=cfg.optimizer['betas'], weight_decay=cfg.optimizer['weight_decay'])
# Layer-wise learning rate decay
lr_mult = [cfg.optimizer['paramwise_cfg']['layer_decay_rate']] * cfg.optimizer['paramwise_cfg']['num_layers']
layerwise_optimizer = LayerDecayOptimizer(optimizer, lr_mult)
# Learning rate scheduler (MultiStepLR)
milestones = cfg.lr_config['step']
gamma = 0.1
scheduler = MultiStepLR(optimizer, milestones, gamma)
# Warm-up scheduler
num_warmup_steps = cfg.lr_config['warmup_iters'] # Number of warm-up steps
warmup_factor = cfg.lr_config['warmup_ratio'] # Initial learning rate = warmup_factor * learning_rate
warmup_scheduler = LambdaLR(
optimizer,
lr_lambda=lambda step: warmup_factor + (1.0 - warmup_factor) * step / num_warmup_steps
)
# AMP setting
if cfg.use_amp:
logger.info("Using Automatic Mixed Precision (AMP) training...")
# Create a GradScaler object for FP16 training
scaler = GradScaler()
# Logging config
total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
logger.info(f'''\n
#========= [Train Configs] =========#
# - Num GPUs: {len(cfg.gpu_ids)}
# - Batch size (per gpu): {cfg.data['samples_per_gpu']}
# - LR: {cfg.optimizer['lr']: .6f}
# - Num params: {total_params:,d}
# - AMP: {cfg.use_amp}
#===================================#
''')
global_step = 0
for dataloader in dataloaders_train:
for epoch in range(cfg.total_epochs):
model.train()
train_pbar = tqdm(dataloader)
total_loss = 0
tic = time()
for batch_idx, batch in enumerate(train_pbar):
layerwise_optimizer.zero_grad()
images, targets, target_weights, __ = batch
images = images.to('cuda')
targets = targets.to('cuda')
target_weights = target_weights.to('cuda')
if cfg.use_amp:
with autocast():
outputs = model(images)
loss = criterion(outputs, targets, target_weights)
scaler.scale(loss).backward()
clip_grad_norm_(model.parameters(), **cfg.optimizer_config['grad_clip'])
scaler.step(layerwise_optimizer)
scaler.update()
else:
outputs = model(images)
loss = criterion(outputs, targets, target_weights)
loss.backward()
clip_grad_norm_(model.parameters(), **cfg.optimizer_config['grad_clip'])
layerwise_optimizer.step()
if global_step < num_warmup_steps:
warmup_scheduler.step()
global_step += 1
total_loss += loss.item()
train_pbar.set_description(f"🏋️> Epoch [{str(epoch).zfill(3)}/{str(cfg.total_epochs).zfill(3)}] | Loss {loss.item():.4f} | LR {optimizer.param_groups[0]['lr']:.6f} | Step")
scheduler.step()
avg_loss_train = total_loss/len(dataloader)
logger.info(f"[Summary-train] Epoch [{str(epoch).zfill(3)}/{str(cfg.total_epochs).zfill(3)}] | Average Loss (train) {avg_loss_train:.4f} --- {time()-tic:.5f} sec. elapsed")
ckpt_name = f"epoch{str(epoch).zfill(3)}.pth"
ckpt_path = osp.join(cfg.work_dir, ckpt_name)
torch.save(model.module.state_dict(), ckpt_path)
# validation
if validate:
tic2 = time()
avg_loss_valid = valid_model(model, dataloaders_valid, criterion, cfg)
logger.info(f"[Summary-valid] Epoch [{str(epoch).zfill(3)}/{str(cfg.total_epochs).zfill(3)}] | Average Loss (valid) {avg_loss_valid:.4f} --- {time()-tic2:.5f} sec. elapsed")
|