"""
Learning-rate schedulers

Author: Xiaoyang Wu ([email protected])

Please cite our work if the code is helpful to you.
"""
|
|
|
import torch.optim.lr_scheduler as lr_scheduler |
|
from .registry import Registry |
|
|
|
# Registry of learning-rate scheduler classes. Each scheduler class in this
# module registers itself with ``@SCHEDULERS.register_module()``, and
# ``build_scheduler`` hands a config to ``SCHEDULERS.build`` (presumably a
# lookup by the config's type name -- confirm against ``.registry``).
SCHEDULERS = Registry("schedulers")
|
|
|
|
|
@SCHEDULERS.register_module()
class MultiStepLR(lr_scheduler.MultiStepLR):
    """``torch.optim.lr_scheduler.MultiStepLR`` with milestones given as fractions.

    Each entry of ``milestones`` is a fraction of ``total_steps``. It is
    converted to an absolute (integer) step index before being handed to
    PyTorch.

    Args:
        optimizer: Wrapped optimizer.
        milestones: Iterable of fractions of ``total_steps`` at which the
            learning rate is multiplied by ``gamma``.
        total_steps: Total number of scheduler steps the fractions refer to.
        gamma: Multiplicative decay factor applied at each milestone.
        last_epoch: Index of the last step; ``-1`` starts a fresh schedule.
        verbose: Forwarded to PyTorch only when truthy (see note in body).
    """

    def __init__(
        self,
        optimizer,
        milestones,
        total_steps,
        gamma=0.1,
        last_epoch=-1,
        verbose=False,
    ):
        # PyTorch's MultiStepLR decays only when the current integer step is
        # *exactly* one of the milestones (membership test on a Counter).
        # A raw float milestone can therefore never fire: either it is
        # fractional, or it carries float noise (0.29 * 100 ==
        # 28.999999999999996). Round to the nearest integer step instead.
        step_milestones = [int(round(rate * total_steps)) for rate in milestones]
        # On PyTorch >= 2.2, passing ``verbose`` explicitly (even as False)
        # emits a deprecation warning, so forward it only when it is enabled.
        verbose_kwargs = {"verbose": verbose} if verbose else {}
        super().__init__(
            optimizer=optimizer,
            milestones=step_milestones,
            gamma=gamma,
            last_epoch=last_epoch,
            **verbose_kwargs,
        )
|
|
|
|
|
@SCHEDULERS.register_module()
class MultiStepWithWarmupLR(lr_scheduler.LambdaLR):
    """Step decay at fractional milestones, preceded by a linear warm-up.

    At step ``s`` the base learning rate is multiplied by
    ``warmup(s) * gamma ** k``, where ``k`` is the number of milestones with
    ``s >= rate * total_steps``. While ``s <= warmup_rate * total_steps``,
    ``warmup(s)`` rises linearly from ``warmup_scale`` (at ``s == 0``) to
    ``1.0``. Afterwards it is ``1.0``.

    Args:
        optimizer: Wrapped optimizer.
        milestones: Iterable of fractions of ``total_steps`` at which the
            learning rate is multiplied by ``gamma`` (any order).
        total_steps: Total number of scheduler steps the fractions refer to.
        gamma: Multiplicative decay factor applied at each milestone.
        warmup_rate: Fraction of ``total_steps`` spent warming up. ``0``
            disables the warm-up.
        warmup_scale: Multiplier at step 0 of the warm-up.
        last_epoch: Index of the last step; ``-1`` starts a fresh schedule.
        verbose: Forwarded to PyTorch only when truthy (see note in body).
    """

    def __init__(
        self,
        optimizer,
        milestones,
        total_steps,
        gamma=0.1,
        warmup_rate=0.05,
        warmup_scale=1e-6,
        last_epoch=-1,
        verbose=False,
    ):
        milestone_steps = [rate * total_steps for rate in milestones]
        warmup_steps = warmup_rate * total_steps

        def multi_step_with_warmup(s):
            # Apply gamma once per reached milestone. Scanning every milestone
            # (rather than stopping at the first unreached one) keeps this
            # correct even if the milestones are not sorted.
            factor = 1.0
            for milestone in milestone_steps:
                if s >= milestone:
                    factor *= gamma

            # The ``warmup_steps > 0`` guard avoids a 0/0 ZeroDivisionError
            # at s == 0 when warm-up is disabled (warmup_rate == 0). LambdaLR
            # evaluates the lambda at construction time.
            if warmup_steps > 0 and s <= warmup_steps:
                warmup_coefficient = 1 - (1 - s / warmup_steps) * (1 - warmup_scale)
            else:
                warmup_coefficient = 1.0
            return warmup_coefficient * factor

        # On PyTorch >= 2.2, passing ``verbose`` explicitly (even as False)
        # emits a deprecation warning, so forward it only when it is enabled.
        verbose_kwargs = {"verbose": verbose} if verbose else {}
        super().__init__(
            optimizer=optimizer,
            lr_lambda=multi_step_with_warmup,
            last_epoch=last_epoch,
            **verbose_kwargs,
        )
|
|
|
|
|
@SCHEDULERS.register_module()
class PolyLR(lr_scheduler.LambdaLR):
    """Polynomial decay of the learning rate.

    At step ``s`` the base learning rate is multiplied by
    ``max(0, 1 - s / (total_steps + 1)) ** power``.

    Args:
        optimizer: Wrapped optimizer.
        total_steps: Number of steps over which the multiplier decays.
        power: Exponent of the polynomial.
        last_epoch: Index of the last step; ``-1`` starts a fresh schedule.
        verbose: Forwarded to PyTorch only when truthy (see note in body).
    """

    def __init__(self, optimizer, total_steps, power=0.9, last_epoch=-1, verbose=False):
        def poly(s):
            # Clamp the base at zero. If the scheduler is stepped past
            # ``total_steps + 1``, the base turns negative, and in Python a
            # negative float raised to a fractional power is a complex number.
            return max(0.0, 1 - s / (total_steps + 1)) ** power

        # On PyTorch >= 2.2, passing ``verbose`` explicitly (even as False)
        # emits a deprecation warning, so forward it only when it is enabled.
        verbose_kwargs = {"verbose": verbose} if verbose else {}
        super().__init__(
            optimizer=optimizer,
            lr_lambda=poly,
            last_epoch=last_epoch,
            **verbose_kwargs,
        )
|
|
|
|
|
@SCHEDULERS.register_module()
class ExpLR(lr_scheduler.LambdaLR):
    """Exponential decay normalised to the schedule length.

    At step ``s`` the base learning rate is multiplied by
    ``gamma ** (s / total_steps)``, so the multiplier equals ``gamma`` when
    ``s == total_steps``.

    Args:
        optimizer: Wrapped optimizer.
        total_steps: Schedule length used to normalise the exponent.
        gamma: Multiplier reached at ``s == total_steps``.
        last_epoch: Index of the last step; ``-1`` starts a fresh schedule.
        verbose: Forwarded to PyTorch only when truthy (see note in body).
    """

    def __init__(self, optimizer, total_steps, gamma=0.9, last_epoch=-1, verbose=False):
        def exponential(s):
            return gamma ** (s / total_steps)

        # On PyTorch >= 2.2, passing ``verbose`` explicitly (even as False)
        # emits a deprecation warning, so forward it only when it is enabled.
        verbose_kwargs = {"verbose": verbose} if verbose else {}
        super().__init__(
            optimizer=optimizer,
            lr_lambda=exponential,
            last_epoch=last_epoch,
            **verbose_kwargs,
        )
|
|
|
|
|
@SCHEDULERS.register_module()
class CosineAnnealingLR(lr_scheduler.CosineAnnealingLR):
    """``torch.optim.lr_scheduler.CosineAnnealingLR`` parameterised by ``total_steps``.

    ``total_steps`` is passed to PyTorch as ``T_max``.

    Args:
        optimizer: Wrapped optimizer.
        total_steps: Number of steps of the cosine half-period (``T_max``).
        eta_min: Minimum learning rate.
        last_epoch: Index of the last step; ``-1`` starts a fresh schedule.
        verbose: Forwarded to PyTorch only when truthy (see note in body).
    """

    def __init__(self, optimizer, total_steps, eta_min=0, last_epoch=-1, verbose=False):
        # On PyTorch >= 2.2, passing ``verbose`` explicitly (even as False)
        # emits a deprecation warning, so forward it only when it is enabled.
        verbose_kwargs = {"verbose": verbose} if verbose else {}
        super().__init__(
            optimizer=optimizer,
            T_max=total_steps,
            eta_min=eta_min,
            last_epoch=last_epoch,
            **verbose_kwargs,
        )
|
|
|
|
|
@SCHEDULERS.register_module()
class OneCycleLR(lr_scheduler.OneCycleLR):
    r"""``torch.optim.lr_scheduler.OneCycleLR`` driven only by ``total_steps``.

    The ``epochs`` / ``steps_per_epoch`` arguments of the PyTorch class are
    not exposed, so the cycle length must be given through ``total_steps``.
    If ``total_steps`` is left as ``None``, PyTorch raises ``ValueError``
    because no cycle length is defined.

    All other arguments are passed through unchanged; see the PyTorch
    documentation for their meaning.
    ``verbose`` is forwarded only when truthy (see note in body).
    """

    def __init__(
        self,
        optimizer,
        max_lr,
        total_steps=None,
        pct_start=0.3,
        anneal_strategy="cos",
        cycle_momentum=True,
        base_momentum=0.85,
        max_momentum=0.95,
        div_factor=25.0,
        final_div_factor=1e4,
        three_phase=False,
        last_epoch=-1,
        verbose=False,
    ):
        # On PyTorch >= 2.2, passing ``verbose`` explicitly (even as False)
        # emits a deprecation warning, so forward it only when it is enabled.
        verbose_kwargs = {"verbose": verbose} if verbose else {}
        super().__init__(
            optimizer=optimizer,
            max_lr=max_lr,
            total_steps=total_steps,
            pct_start=pct_start,
            anneal_strategy=anneal_strategy,
            cycle_momentum=cycle_momentum,
            base_momentum=base_momentum,
            max_momentum=max_momentum,
            div_factor=div_factor,
            final_div_factor=final_div_factor,
            three_phase=three_phase,
            last_epoch=last_epoch,
            **verbose_kwargs,
        )
|
|
|
|
|
def build_scheduler(cfg, optimizer):
    """Build a registered scheduler from ``cfg`` for ``optimizer``.

    Note that ``cfg`` is modified in place: its ``optimizer`` attribute is
    set to ``optimizer`` before the config is handed to ``SCHEDULERS.build``.

    Args:
        cfg: Scheduler config that supports attribute assignment (presumably
            a ConfigDict-like object naming a registered class -- confirm
            against ``Registry.build``).
        optimizer: Optimizer the scheduler will drive.

    Returns:
        Whatever ``SCHEDULERS.build`` produces for the updated config.
    """
    setattr(cfg, "optimizer", optimizer)
    scheduler = SCHEDULERS.build(cfg=cfg)
    return scheduler
|
|