# RIP-AV-su-lab / AV / Tools / warmup.py
import torch
from torch.optim.lr_scheduler import StepLR
class GradualWarmupScheduler(torch.optim.lr_scheduler._LRScheduler):
    """Gradually warm up (increase) the learning rate in the wrapped optimizer.

    Proposed in 'Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour'.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        multiplier: Target learning rate = base_lr * multiplier if multiplier > 1.0.
            If multiplier == 1.0, the learning rate starts from 0 and ends at base_lr.
        total_epoch: The target learning rate is reached gradually at total_epoch.
        after_scheduler: Scheduler to use after total_epoch (e.g. ReduceLROnPlateau).
    """
    def __init__(self, optimizer, multiplier, total_epoch, after_scheduler=None):
        self.multiplier = multiplier
        if self.multiplier < 1.:
            raise ValueError('multiplier should be greater than or equal to 1.')
        self.total_epoch = total_epoch
        self.after_scheduler = after_scheduler
        self.finished = False
        super(GradualWarmupScheduler, self).__init__(optimizer)
    def get_lr(self):
        if self.last_epoch > self.total_epoch:
            # Warmup finished: hand over to after_scheduler, rescaling its base
            # learning rates to the warmed-up target once.
            if self.after_scheduler:
                if not self.finished:
                    self.after_scheduler.base_lrs = [base_lr * self.multiplier for base_lr in self.base_lrs]
                    self.finished = True
                return self.after_scheduler.get_last_lr()
            return [base_lr * self.multiplier for base_lr in self.base_lrs]

        if self.multiplier == 1.0:
            # Ramp linearly from 0 up to base_lr.
            return [base_lr * (float(self.last_epoch) / self.total_epoch) for base_lr in self.base_lrs]
        else:
            # Ramp linearly from base_lr up to base_lr * multiplier.
            return [base_lr * ((self.multiplier - 1.) * self.last_epoch / self.total_epoch + 1.) for base_lr in self.base_lrs]
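    # Illustrative note (not from the original source): with multiplier=1.0 and
    # total_epoch=10, the learning rate at epoch 3 is 0.3 * base_lr and reaches
    # base_lr exactly at epoch 10.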
    def step(self, epoch=None, metrics=None):
        if self.finished and self.after_scheduler:
            # Warmup is done: delegate to after_scheduler. The epoch and metrics
            # arguments are not forwarded; metrics is accepted only for
            # interface compatibility and is unused here.
            self.after_scheduler.step()
            self._last_lr = self.after_scheduler.get_last_lr()
        else:
            return super(GradualWarmupScheduler, self).step(epoch)
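

# A minimal usage sketch (added for illustration, not part of the original
# upload): any epoch-based scheduler, e.g. CosineAnnealingLR, can be wrapped
# the same way StepLR is wrapped in the demo below. The function name and
# hyperparameters here are hypothetical.
def _cosine_warmup_example(parameters, base_lr=1e-3, warmup_epochs=5, max_epochs=100):
    optimizer = torch.optim.SGD(parameters, lr=base_lr, momentum=0.9)
    cosine = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=max_epochs - warmup_epochs)
    # multiplier=1 warms up from lr=0 to base_lr, then cosine decay takes over.
    warmup = GradualWarmupScheduler(optimizer, multiplier=1, total_epoch=warmup_epochs, after_scheduler=cosine)
    return optimizer, warmup
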
if __name__ == '__main__':
    model = [torch.nn.Parameter(torch.randn(2, 2, requires_grad=True))]
    optim = torch.optim.Adam(model, 0.0002)

    # scheduler_warmup is chained with scheduler_steplr
    scheduler_steplr = StepLR(optim, step_size=80, gamma=0.1)
    scheduler_warmup = GradualWarmupScheduler(optim, multiplier=2, total_epoch=10, after_scheduler=scheduler_steplr)

    # this zero-gradient update is needed to avoid a warning message, issue #8
    optim.zero_grad()
    optim.step()

    for epoch in range(1, 20):
        scheduler_warmup.step(epoch)
        print(epoch, optim.param_groups[0]['lr'])

        optim.step()
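    # Expected warmup behaviour (illustrative): with base_lr=0.0002, multiplier=2
    # and total_epoch=10, the printed lr rises linearly from 0.00022 at epoch 1
    # to 0.0004 at epoch 10, after which scheduler_steplr takes over.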