from abc import abstractmethod
import os
import time
import json

import torch
import torch.distributed as dist
from torch.utils.data import DataLoader
import numpy as np

from torchvision import utils
from torch.utils.tensorboard import SummaryWriter

from .utils import *
from ..utils.general_utils import *
from ..utils.data_utils import recursive_to_device, cycle, ResumableSampler


class Trainer:
    """
    Base class for training.

    This class provides the data pipeline, the main training loop, progress
    printing, log flushing (text file + TensorBoard) and image snapshotting.
    Subclasses are expected to implement model setup (`init_models_and_more`),
    checkpointing (`load`, `save`, `finetune_from`), sampling
    (`run_snapshot`), EMA (`update_ema`), DDP checks (`check_ddp`), loss
    computation (`training_losses`) and the optimization step (`run_step`).

    NOTE(review): the class does not inherit from `abc.ABC`, so the
    `@abstractmethod` markers are not enforced at instantiation time.

    Args:
        models: Mapping of name -> model (iterated via `.items()` / `.values()`).
        dataset: Dataset consumed by a `DataLoader`. Optional hooks used here:
            `collate_fn` and `visualize_sample`; `value_range` is read when
            saving images.
        output_dir: Directory for checkpoints, samples, logs and TensorBoard.
        load_dir, step: When both are given, resume via `self.load(load_dir, step)`.
        max_steps: Number of training steps to run.
        batch_size: Global batch size across all processes.
        batch_size_per_gpu: Per-process batch size; takes precedence over
            `batch_size` when given.
        batch_split: Number of micro-batches each per-process batch is split
            into (defaults to 1).
        optimizer: Optimizer config stored for subclasses (defaults to `{}`).
        lr_scheduler, elastic: Configs stored for subclasses.
        grad_clip, fp16_mode, fp16_scale_growth, log_param_stats: Stored for
            subclasses.
        ema_rate: A float is wrapped in a one-element list; otherwise stored as-is.
        finetune_ckpt: Passed to `self.finetune_from` when not resuming.
        prefetch_data: If True, the next batch is moved to the device one step
            ahead of use.
        i_print, i_log, i_sample, i_save, i_ddpcheck: Step intervals for
            progress printing, log flushing, sampling, checkpointing and DDP
            checks (`i_ddpcheck=None` disables the periodic DDP check).
        **kwargs: Forwarded to `init_models_and_more` and `prepare_dataloader`.
    """
    def __init__(self,
        models,
        dataset,
        *,
        output_dir,
        load_dir,
        step,
        max_steps,
        batch_size=None,
        batch_size_per_gpu=None,
        batch_split=None,
        optimizer=None,
        lr_scheduler=None,
        elastic=None,
        grad_clip=None,
        ema_rate=0.9999,
        fp16_mode='inflat_all',
        fp16_scale_growth=1e-3,
        finetune_ckpt=None,
        log_param_stats=False,
        prefetch_data=True,
        i_print=1000,
        i_log=500,
        i_sample=10000,
        i_save=10000,
        i_ddpcheck=10000,
        **kwargs
    ):
        assert batch_size is not None or batch_size_per_gpu is not None, 'Either batch_size or batch_size_per_gpu must be specified.'

        self.models = models
        self.dataset = dataset
        self.batch_split = batch_split if batch_split is not None else 1
        self.max_steps = max_steps
        # Fresh dict per instance (avoids a shared mutable default argument).
        self.optimizer_config = optimizer if optimizer is not None else {}
        self.lr_scheduler_config = lr_scheduler
        self.elastic_controller_config = elastic
        self.grad_clip = grad_clip
        self.ema_rate = [ema_rate] if isinstance(ema_rate, float) else ema_rate
        self.fp16_mode = fp16_mode
        self.fp16_scale_growth = fp16_scale_growth
        self.log_param_stats = log_param_stats
        self.prefetch_data = prefetch_data
        if self.prefetch_data:
            self._data_prefetched = None

        self.output_dir = output_dir
        self.i_print = i_print
        self.i_log = i_log
        self.i_sample = i_sample
        self.i_save = i_save
        self.i_ddpcheck = i_ddpcheck

        if dist.is_initialized():
            # Multi-GPU params
            self.world_size = dist.get_world_size()
            self.rank = dist.get_rank()
            # max(..., 1) avoids a ZeroDivisionError on hosts without CUDA.
            self.local_rank = dist.get_rank() % max(torch.cuda.device_count(), 1)
            self.is_master = self.rank == 0
        else:
            # Single-GPU params
            self.world_size = 1
            self.rank = 0
            self.local_rank = 0
            self.is_master = True

        self.batch_size = batch_size if batch_size_per_gpu is None else batch_size_per_gpu * self.world_size
        self.batch_size_per_gpu = batch_size_per_gpu if batch_size_per_gpu is not None else batch_size // self.world_size
        assert self.batch_size % self.world_size == 0, 'Batch size must be divisible by the number of GPUs.'
        assert self.batch_size_per_gpu % self.batch_split == 0, 'Batch size per GPU must be divisible by batch split.'

        self.init_models_and_more(**kwargs)
        self.prepare_dataloader(**kwargs)

        # Load checkpoint
        self.step = 0
        if load_dir is not None and step is not None:
            self.load(load_dir, step)
        elif finetune_ckpt is not None:
            self.finetune_from(finetune_ckpt)

        if self.is_master:
            os.makedirs(os.path.join(self.output_dir, 'ckpts'), exist_ok=True)
            os.makedirs(os.path.join(self.output_dir, 'samples'), exist_ok=True)
            self.writer = SummaryWriter(os.path.join(self.output_dir, 'tb_logs'))

        if self.world_size > 1:
            self.check_ddp()

        if self.is_master:
            print('\n\nTrainer initialized.')
            print(self)

    @property
    def device(self):
        """
        Device of the models.

        Returns the `device` attribute of the first model that has one;
        otherwise the device of the first parameter of the first model.
        """
        for model in self.models.values():
            if hasattr(model, 'device'):
                return model.device
        return next(list(self.models.values())[0].parameters()).device

    @abstractmethod
    def init_models_and_more(self, **kwargs):
        """
        Initialize models and more.
        """
        pass

    def prepare_dataloader(self, **kwargs):
        """
        Prepare the training dataloader.

        Builds a shuffling `ResumableSampler` over the dataset, a `DataLoader`
        with per-process batch size, and `self.data_iterator = cycle(dataloader)`
        (presumably an endless iterator, since `run` calls `next` on it for
        every step).
        """
        self.data_sampler = ResumableSampler(
            self.dataset,
            shuffle=True,
        )
        # Split CPU cores evenly across visible GPUs. Guard against hosts with
        # no CUDA device (device_count() == 0) and against os.cpu_count()
        # returning None.
        num_gpus = max(torch.cuda.device_count(), 1)
        num_cpus = os.cpu_count() or 1
        self.dataloader = DataLoader(
            self.dataset,
            batch_size=self.batch_size_per_gpu,
            num_workers=int(np.ceil(num_cpus / num_gpus)),
            pin_memory=True,
            drop_last=True,
            persistent_workers=True,
            collate_fn=self.dataset.collate_fn if hasattr(self.dataset, 'collate_fn') else None,
            sampler=self.data_sampler,
        )
        self.data_iterator = cycle(self.dataloader)

    @abstractmethod
    def load(self, load_dir, step=0):
        """
        Load a checkpoint.
        Should be called by all processes.
        """
        pass

    @abstractmethod
    def save(self):
        """
        Save a checkpoint.
        Should be called only by the rank 0 process.
        """
        pass

    @abstractmethod
    def finetune_from(self, finetune_ckpt):
        """
        Finetune from a checkpoint.
        Should be called by all processes.
        """
        pass

    @abstractmethod
    def run_snapshot(self, num_samples, batch_size=4, verbose=False, **kwargs):
        """
        Run a snapshot of the model.

        `snapshot` expects a dict mapping name -> {'value': ..., 'type': ...}
        where type is 'sample', 'image' or 'number'.
        """
        pass

    @torch.no_grad()
    def visualize_sample(self, sample):
        """
        Convert a sample to an image.

        Delegates to `dataset.visualize_sample` when available; otherwise
        returns the sample unchanged. The result may be a tensor or a dict
        of name -> tensor.
        """
        if hasattr(self.dataset, 'visualize_sample'):
            return self.dataset.visualize_sample(sample)
        else:
            return sample

    @torch.no_grad()
    def snapshot_dataset(self, num_samples=100):
        """
        Sample images from the dataset and save them under `samples/`.

        Draws one shuffled batch of `num_samples` items, visualizes it and
        writes `dataset.jpg` (or `dataset_<key>.jpg` per visualization key).
        """
        dataloader = torch.utils.data.DataLoader(
            self.dataset,
            batch_size=num_samples,
            num_workers=0,
            shuffle=True,
            collate_fn=self.dataset.collate_fn if hasattr(self.dataset, 'collate_fn') else None,
        )
        data = next(iter(dataloader))
        data = recursive_to_device(data, self.device)
        vis = self.visualize_sample(data)
        if isinstance(vis, dict):
            save_cfg = [(f'dataset_{k}', v) for k, v in vis.items()]
        else:
            save_cfg = [('dataset', vis)]
        for name, image in save_cfg:
            utils.save_image(
                image,
                os.path.join(self.output_dir, 'samples', f'{name}.jpg'),
                nrow=int(np.sqrt(num_samples)),
                normalize=True,
                value_range=self.dataset.value_range,
            )

    @torch.no_grad()
    def snapshot(self, suffix=None, num_samples=64, batch_size=4, verbose=False):
        """
        Sample images from the model.
        NOTE: This function should be called by all processes, because it
        gathers per-process results with `dist.gather` when world_size > 1.

        Each process generates ceil(num_samples / world_size) samples; rank 0
        concatenates them, truncates to `num_samples` and saves one grid per
        key under `samples/<suffix>/`.
        """
        if self.is_master:
            print(f'\nSampling {num_samples} images...', end='')

        if suffix is None:
            suffix = f'step{self.step:07d}'

        # Assign tasks
        num_samples_per_process = int(np.ceil(num_samples / self.world_size))
        samples = self.run_snapshot(num_samples_per_process, batch_size=batch_size, verbose=verbose)

        # Preprocess images: turn raw 'sample' entries into 'image' entries.
        # Iterate over a snapshot of the keys because entries are added/removed.
        for key in list(samples.keys()):
            if samples[key]['type'] == 'sample':
                vis = self.visualize_sample(samples[key]['value'])
                if isinstance(vis, dict):
                    for k, v in vis.items():
                        samples[f'{key}_{k}'] = {'value': v, 'type': 'image'}
                    del samples[key]
                else:
                    samples[key] = {'value': vis, 'type': 'image'}

        # Gather results onto rank 0 (collective: every rank must reach this).
        if self.world_size > 1:
            for key in samples.keys():
                samples[key]['value'] = samples[key]['value'].contiguous()
                if self.is_master:
                    all_images = [torch.empty_like(samples[key]['value']) for _ in range(self.world_size)]
                else:
                    all_images = []
                dist.gather(samples[key]['value'], all_images, dst=0)
                if self.is_master:
                    samples[key]['value'] = torch.cat(all_images, dim=0)[:num_samples]

        # Save images
        if self.is_master:
            nrow = int(np.sqrt(num_samples))
            os.makedirs(os.path.join(self.output_dir, 'samples', suffix), exist_ok=True)
            for key in samples.keys():
                path = os.path.join(self.output_dir, 'samples', suffix, f'{key}_{suffix}.jpg')
                if samples[key]['type'] == 'image':
                    utils.save_image(
                        samples[key]['value'],
                        path,
                        nrow=nrow,
                        normalize=True,
                        value_range=self.dataset.value_range,
                    )
                elif samples[key]['type'] == 'number':
                    value = samples[key]['value']
                    vmin = value.min()
                    vmax = value.max()
                    # Min-max normalize to [0, 1]. A constant tensor would give
                    # 0/0 = NaN, so fall back to a unit span in that case.
                    span = vmax - vmin
                    span = torch.where(span > 0, span, torch.ones_like(span))
                    images = (value - vmin) / span
                    images = utils.make_grid(
                        images,
                        nrow=nrow,
                        normalize=False,
                    )
                    save_image_with_notes(
                        images,
                        path,
                        notes=f'{key} min: {vmin}, max: {vmax}',
                    )

        if self.is_master:
            print(' Done.')

    @abstractmethod
    def update_ema(self):
        """
        Update exponential moving average.
        Should only be called by the rank 0 process.
        """
        pass

    @abstractmethod
    def check_ddp(self):
        """
        Check if DDP is working properly.
        Should be called by all process.
        """
        pass

    @abstractmethod
    def training_losses(self, **mb_data):
        """
        Compute training losses for one micro-batch.
        """
        pass

    def load_data(self):
        """
        Load one training batch and split it into micro-batches.

        With `prefetch_data`, the batch returned now was moved to the device
        on the previous call, and the next batch is moved now.

        Returns:
            A list of micro-batches. A dict batch is split along dim 0 of
            each value into `batch_split` equal slices; a list batch is
            returned as-is.

        Raises:
            ValueError: if the batch is neither a dict nor a list.
        """
        if self.prefetch_data:
            if self._data_prefetched is None:
                self._data_prefetched = recursive_to_device(next(self.data_iterator), self.device, non_blocking=True)
            data = self._data_prefetched
            self._data_prefetched = recursive_to_device(next(self.data_iterator), self.device, non_blocking=True)
        else:
            data = recursive_to_device(next(self.data_iterator), self.device, non_blocking=True)

        # A dict batch is split into `batch_split` micro-batches along dim 0
        # (the batch size is read from the first value).
        if isinstance(data, dict):
            if self.batch_split == 1:
                data_list = [data]
            else:
                batch_size = list(data.values())[0].shape[0]
                data_list = [
                    {k: v[i * batch_size // self.batch_split:(i + 1) * batch_size // self.batch_split] for k, v in data.items()}
                    for i in range(self.batch_split)
                ]
        elif isinstance(data, list):
            data_list = data
        else:
            raise ValueError('Data must be a dict or a list of dicts.')

        return data_list

    @abstractmethod
    def run_step(self, data_list):
        """
        Run a training step.
        """
        pass

    def run(self):
        """
        Run training.

        Must be called by all processes: sampling snapshots and DDP checks are
        collective. Rank 0 additionally prints progress, buffers per-step
        logs, flushes them every `i_log` steps and saves checkpoints every
        `i_save` steps.
        """
        if self.is_master:
            print('\nStarting training...')
            self.snapshot_dataset()
        if self.step == 0:
            self.snapshot(suffix='init')
        else:  # resume
            self.snapshot(suffix=f'resume_step{self.step:07d}')

        log = []
        time_last_print = 0.0
        time_elapsed = 0.0
        while self.step < self.max_steps:
            time_start = time.time()

            data_list = self.load_data()
            step_log = self.run_step(data_list)

            time_end = time.time()
            time_elapsed += time_end - time_start

            self.step += 1

            # Print progress
            if self.is_master and self.step % self.i_print == 0:
                speed = self.i_print / (time_elapsed - time_last_print) * 3600
                columns = [
                    f'Step: {self.step}/{self.max_steps} ({self.step / self.max_steps * 100:.2f}%)',
                    f'Elapsed: {time_elapsed / 3600:.2f} h',
                    f'Speed: {speed:.2f} steps/h',
                    f'ETA: {(self.max_steps - self.step) / speed:.2f} h',
                ]
                print(' | '.join([c.ljust(25) for c in columns]), flush=True)
                time_last_print = time_elapsed

            # Check ddp
            if self.world_size > 1 and self.i_ddpcheck is not None and self.step % self.i_ddpcheck == 0:
                self.check_ddp()

            # Sample images (collective: all ranks participate)
            if self.step % self.i_sample == 0:
                self.snapshot()

            if self.is_master:
                log.append((self.step, {}))

                # Log time
                log[-1][1]['time'] = {
                    'step': time_end - time_start,
                    'elapsed': time_elapsed,
                }

                # Log losses
                if step_log is not None:
                    log[-1][1].update(step_log)

                # Log scale (`scaler` / `log_scale` are expected to be set by subclasses)
                if self.fp16_mode == 'amp':
                    log[-1][1]['scale'] = self.scaler.get_scale()
                elif self.fp16_mode == 'inflat_all':
                    log[-1][1]['log_scale'] = self.log_scale

                # Save log
                if self.step % self.i_log == 0:
                    ## save to log file
                    log_str = '\n'.join([
                        f'{log_step}: {json.dumps(entry)}' for log_step, entry in log
                    ])
                    with open(os.path.join(self.output_dir, 'log.txt'), 'a') as log_file:
                        log_file.write(log_str + '\n')

                    # show with tensorboard: skip entries containing NaN,
                    # reduce the rest with np.mean, flatten nested keys with '/'
                    log_show = [entry for _, entry in log if not dict_any(entry, lambda x: np.isnan(x))]
                    log_show = dict_reduce(log_show, lambda x: np.mean(x))
                    log_show = dict_flatten(log_show, sep='/')
                    for key, value in log_show.items():
                        self.writer.add_scalar(key, value, self.step)
                    log = []

                # Save checkpoint (save() is documented as rank-0 only)
                if self.step % self.i_save == 0:
                    self.save()

        # snapshot() uses dist.gather when world_size > 1, so every rank must
        # call it; restricting it to rank 0 would hang multi-GPU runs.
        self.snapshot(suffix='final')
        if self.is_master:
            self.writer.close()
            print('Training finished.')

    def profile(self, wait=2, warmup=3, active=5):
        """
        Profile the training loop.

        Runs `wait + warmup + active` training steps (including data loading)
        under `torch.profiler` and writes a TensorBoard trace to
        `<output_dir>/profile`.
        """
        with torch.profiler.profile(
            schedule=torch.profiler.schedule(wait=wait, warmup=warmup, active=active, repeat=1),
            on_trace_ready=torch.profiler.tensorboard_trace_handler(os.path.join(self.output_dir, 'profile')),
            profile_memory=True,
            with_stack=True,
        ) as prof:
            for _ in range(wait + warmup + active):
                # run_step requires a batch; the original call omitted it.
                self.run_step(self.load_data())
                prof.step()