import torch
from torch.optim.lr_scheduler import StepLR, ExponentialLR


class GradualWarmupScheduler(torch.optim.lr_scheduler._LRScheduler):
    """Gradually warm up (increase) the learning rate in the optimizer.

    Proposed in 'Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour'.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        multiplier: target learning rate = base lr * multiplier if multiplier > 1.0.
            If multiplier == 1.0, the lr starts from 0 and ends up at the base lr.
        total_epoch: the target learning rate is reached gradually at total_epoch.
        after_scheduler: after total_epoch, use this scheduler (e.g. ReduceLROnPlateau).
    """

    def __init__(self, optimizer, multiplier, total_epoch, after_scheduler=None):
        self.multiplier = multiplier
        if self.multiplier < 1.:
            raise ValueError('multiplier should be greater than or equal to 1.')
        self.total_epoch = total_epoch
        self.after_scheduler = after_scheduler
        self.finished = False
        super(GradualWarmupScheduler, self).__init__(optimizer)

    def get_lr(self):
        if self.last_epoch > self.total_epoch:
            if self.after_scheduler:
                if not self.finished:
                    # Hand over to after_scheduler with the warmed-up base lrs.
                    self.after_scheduler.base_lrs = [base_lr * self.multiplier for base_lr in self.base_lrs]
                    self.finished = True
                return self.after_scheduler.get_last_lr()
            return [base_lr * self.multiplier for base_lr in self.base_lrs]

        if self.multiplier == 1.0:
            # Linear warmup from 0 up to base_lr.
            return [base_lr * (float(self.last_epoch) / self.total_epoch) for base_lr in self.base_lrs]
        else:
            # Linear warmup from base_lr up to base_lr * multiplier.
            return [base_lr * ((self.multiplier - 1.) * self.last_epoch / self.total_epoch + 1.) for base_lr in self.base_lrs]

    def step(self, epoch=None, metrics=None):
        if self.finished and self.after_scheduler:
            if epoch is None:
                self.after_scheduler.step()
            else:
                # Shift the epoch so the after_scheduler starts counting from 0.
                self.after_scheduler.step(epoch - self.total_epoch)
            self._last_lr = self.after_scheduler.get_last_lr()
        else:
            return super(GradualWarmupScheduler, self).step(epoch)


if __name__ == '__main__':
    model = [torch.nn.Parameter(torch.randn(2, 2, requires_grad=True))]
    optim = torch.optim.Adam(model, 0.0002)

    # scheduler_warmup is chained with scheduler_steplr
    scheduler_steplr = StepLR(optim, step_size=80, gamma=0.1)
    scheduler_warmup = GradualWarmupScheduler(optim, multiplier=2, total_epoch=10, after_scheduler=scheduler_steplr)

    # this zero gradient update is needed to avoid a warning message, issue #8.
    optim.zero_grad()
    optim.step()

    for epoch in range(1, 20):
        scheduler_warmup.step(epoch)
        print(epoch, optim.param_groups[0]['lr'])
        optim.step()
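
    # --- Minimal additional sketch (not part of the original demo): the same warmup
    # wrapper chained with the already-imported ExponentialLR instead of StepLR.
    # The gamma=0.95 decay and total_epoch=5 warmup length are illustrative assumptions.
    model2 = [torch.nn.Parameter(torch.randn(2, 2, requires_grad=True))]
    optim2 = torch.optim.Adam(model2, 0.0002)
    scheduler_exp = ExponentialLR(optim2, gamma=0.95)
    scheduler_warmup2 = GradualWarmupScheduler(optim2, multiplier=2, total_epoch=5, after_scheduler=scheduler_exp)

    for epoch in range(1, 20):
        # Calling optimizer.step() before scheduler.step() follows the current PyTorch
        # convention and avoids the ordering warning without a dummy update.
        optim2.step()
        scheduler_warmup2.step()
        print(epoch, optim2.param_groups[0]['lr'])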