import torch
from torch.optim.lr_scheduler import StepLR
class GradualWarmupScheduler(torch.optim.lr_scheduler._LRScheduler):
    """Gradually warms up (increases) the learning rate in the optimizer.
    Proposed in 'Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour'.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        multiplier: target learning rate = base lr * multiplier if multiplier > 1.0.
            If multiplier = 1.0, the lr starts from 0 and ends up at the base lr.
        total_epoch: the target learning rate is reached gradually at total_epoch.
        after_scheduler: after total_epoch, use this scheduler (e.g. StepLR).
    """
    def __init__(self, optimizer, multiplier, total_epoch, after_scheduler=None):
        self.multiplier = multiplier
        if self.multiplier < 1.:
            raise ValueError('multiplier should be greater than or equal to 1.')
        self.total_epoch = total_epoch
        self.after_scheduler = after_scheduler
        self.finished = False
        super(GradualWarmupScheduler, self).__init__(optimizer)
    def get_lr(self):
        if self.last_epoch > self.total_epoch:
            # Warmup is over: hand control to the wrapped scheduler, if any.
            if self.after_scheduler:
                if not self.finished:
                    # Re-base the wrapped scheduler on the post-warmup learning rate.
                    self.after_scheduler.base_lrs = [base_lr * self.multiplier for base_lr in self.base_lrs]
                    self.finished = True
                return self.after_scheduler.get_last_lr()
            return [base_lr * self.multiplier for base_lr in self.base_lrs]

        if self.multiplier == 1.0:
            # Linear warmup from 0 up to base_lr.
            return [base_lr * (float(self.last_epoch) / self.total_epoch) for base_lr in self.base_lrs]
        else:
            # Linear warmup from base_lr up to base_lr * multiplier.
            return [base_lr * ((self.multiplier - 1.) * self.last_epoch / self.total_epoch + 1.) for base_lr in self.base_lrs]
    def step(self, epoch=None, metrics=None):
        if self.finished and self.after_scheduler:
            if epoch is None:
                self.after_scheduler.step(None)
            else:
                # Shift the epoch so the wrapped scheduler counts from 0 after warmup.
                self.after_scheduler.step(epoch - self.total_epoch)
            self._last_lr = self.after_scheduler.get_last_lr()
        else:
            return super(GradualWarmupScheduler, self).step(epoch)
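
# A quick sanity check of the warmup formula above (a minimal worked example, not
# part of the original code; the numbers assume multiplier=2, total_epoch=10 and a
# base lr of 0.0002, matching the demo below):
#   epoch 1:  0.0002 * ((2 - 1) * 1 / 10 + 1)  = 0.00022
#   epoch 5:  0.0002 * ((2 - 1) * 5 / 10 + 1)  = 0.00030
#   epoch 10: 0.0002 * ((2 - 1) * 10 / 10 + 1) = 0.00040  (warmup done, StepLR takes over)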
if __name__ == '__main__':
    model = [torch.nn.Parameter(torch.randn(2, 2, requires_grad=True))]
    optim = torch.optim.Adam(model, 0.0002)

    # scheduler_warmup is chained with scheduler_steplr
    scheduler_steplr = StepLR(optim, step_size=80, gamma=0.1)
    scheduler_warmup = GradualWarmupScheduler(optim, multiplier=2, total_epoch=10, after_scheduler=scheduler_steplr)

    # this zero gradient update is needed to avoid a warning message, issue #8.
    optim.zero_grad()
    optim.step()

    for epoch in range(1, 20):
        scheduler_warmup.step(epoch)
        print(epoch, optim.param_groups[0]['lr'])
        optim.step()
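
# A minimal sketch (not from the original code) of wrapping a cosine schedule
# instead of StepLR. The function name `cosine_warmup_demo` and the
# hyperparameters are illustrative assumptions, not part of the library.
def cosine_warmup_demo():
    from torch.optim.lr_scheduler import CosineAnnealingLR

    params = [torch.nn.Parameter(torch.randn(2, 2, requires_grad=True))]
    optimizer = torch.optim.Adam(params, 0.001)

    # 5 warmup epochs up to 2 * 0.001, then cosine decay over the remaining epochs.
    cosine = CosineAnnealingLR(optimizer, T_max=95)
    warmup = GradualWarmupScheduler(optimizer, multiplier=2, total_epoch=5, after_scheduler=cosine)

    optimizer.zero_grad()
    optimizer.step()
    for epoch in range(1, 100):
        warmup.step(epoch)
        print(epoch, optimizer.param_groups[0]['lr'])
        optimizer.step()

# Call cosine_warmup_demo() manually to inspect the schedule; it is not run on import.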