kairunwen's picture
Update Code
57746f1
"""
Scheduler
Author: Xiaoyang Wu ([email protected])
Please cite our work if the code is helpful to you.
"""
import torch.optim.lr_scheduler as lr_scheduler
from .registry import Registry
# Registry of scheduler classes: each class below adds itself through
# ``@SCHEDULERS.register_module()`` and ``build_scheduler`` instantiates from it.
SCHEDULERS = Registry("schedulers")
@SCHEDULERS.register_module()
class MultiStepLR(lr_scheduler.MultiStepLR):
    """Multi-step decay whose milestones are fractions of ``total_steps``.

    Args:
        optimizer: Optimizer forwarded to the parent scheduler.
        milestones: Iterable of fractions; each one is multiplied by
            ``total_steps`` and rounded to the nearest whole step index.
        total_steps: Number of scheduler steps the fractions are relative to.
        gamma: Forwarded unchanged to the parent scheduler.
        last_epoch: Forwarded unchanged to the parent scheduler.
        verbose: Forwarded unchanged to the parent scheduler.
    """

    def __init__(
        self,
        optimizer,
        milestones,
        total_steps,
        gamma=0.1,
        last_epoch=-1,
        verbose=False,
    ):
        # The parent scheduler looks milestones up by its integer step
        # counter. A fractional product (e.g. 0.25 * 1002 = 250.5, or a value
        # that is off by floating-point round-off) would never be hit and the
        # decay would be silently skipped, so snap each milestone to an int.
        step_milestones = [int(round(rate * total_steps)) for rate in milestones]
        super().__init__(
            optimizer=optimizer,
            milestones=step_milestones,
            gamma=gamma,
            last_epoch=last_epoch,
            verbose=verbose,
        )
@SCHEDULERS.register_module()
class MultiStepWithWarmupLR(lr_scheduler.LambdaLR):
    """Multi-step decay combined with a linear warm-up, built on ``LambdaLR``.

    The multiplier handed to ``LambdaLR`` is the product of two terms:

    * a step factor: ``gamma`` is applied once for every leading milestone
      (in the order given) that the current step has reached;
    * a warm-up coefficient that rises linearly from ``warmup_scale`` at
      step 0 to 1 at ``warmup_rate * total_steps``, and is 1 afterwards.

    Args:
        optimizer: Optimizer forwarded to ``LambdaLR``.
        milestones: Iterable of fractions of ``total_steps``; the loop stops
            at the first milestone not yet reached, so they are expected in
            increasing order.
        total_steps: Number of scheduler steps the fractions are relative to.
        gamma: Multiplicative decay applied per reached milestone.
        warmup_rate: Fraction of ``total_steps`` spent warming up. A value
            that yields no warm-up steps (e.g. ``0``) disables warm-up.
        warmup_scale: Warm-up coefficient at step 0.
        last_epoch: Forwarded unchanged to ``LambdaLR``.
        verbose: Forwarded unchanged to ``LambdaLR``.
    """

    def __init__(
        self,
        optimizer,
        milestones,
        total_steps,
        gamma=0.1,
        warmup_rate=0.05,
        warmup_scale=1e-6,
        last_epoch=-1,
        verbose=False,
    ):
        milestones = [rate * total_steps for rate in milestones]
        warmup_steps = warmup_rate * total_steps

        def multi_step_with_warmup(s):
            factor = 1.0
            for milestone in milestones:
                if s < milestone:
                    break
                factor *= gamma
            # Guard on warmup_steps > 0: without it, warmup_rate == 0 (or
            # total_steps == 0) makes step 0 enter this branch and divide by
            # zero while the scheduler is being constructed.
            if warmup_steps > 0 and s <= warmup_steps:
                warmup_coefficient = 1 - (1 - s / warmup_rate / total_steps) * (
                    1 - warmup_scale
                )
            else:
                warmup_coefficient = 1.0
            return warmup_coefficient * factor

        super().__init__(
            optimizer=optimizer,
            lr_lambda=multi_step_with_warmup,
            last_epoch=last_epoch,
            verbose=verbose,
        )
@SCHEDULERS.register_module()
class PolyLR(lr_scheduler.LambdaLR):
    """Polynomial decay: multiplier ``(1 - s / (total_steps + 1)) ** power``.

    Args:
        optimizer: Optimizer forwarded to ``LambdaLR``.
        total_steps: Step count used in the decay denominator (plus one).
        power: Exponent applied to the remaining fraction.
        last_epoch: Forwarded unchanged to ``LambdaLR``.
        verbose: Forwarded unchanged to ``LambdaLR``.
    """

    def __init__(self, optimizer, total_steps, power=0.9, last_epoch=-1, verbose=False):
        def poly_decay(step):
            # NOTE(review): once step exceeds total_steps + 1 the base turns
            # negative, and a fractional power of a negative float is complex
            # in Python -- confirm callers never step that far.
            remaining = 1 - step / (total_steps + 1)
            return remaining**power

        super().__init__(
            optimizer=optimizer,
            lr_lambda=poly_decay,
            last_epoch=last_epoch,
            verbose=verbose,
        )
@SCHEDULERS.register_module()
class ExpLR(lr_scheduler.LambdaLR):
    """Exponential decay: multiplier ``gamma ** (s / total_steps)``.

    The multiplier is 1 at step 0 and equals ``gamma`` at ``s == total_steps``.

    Args:
        optimizer: Optimizer forwarded to ``LambdaLR``.
        total_steps: Step count at which the multiplier reaches ``gamma``.
        gamma: Base of the exponential decay.
        last_epoch: Forwarded unchanged to ``LambdaLR``.
        verbose: Forwarded unchanged to ``LambdaLR``.
    """

    def __init__(self, optimizer, total_steps, gamma=0.9, last_epoch=-1, verbose=False):
        def exp_decay(step):
            progress = step / total_steps
            return gamma**progress

        super().__init__(
            optimizer=optimizer,
            lr_lambda=exp_decay,
            last_epoch=last_epoch,
            verbose=verbose,
        )
@SCHEDULERS.register_module()
class CosineAnnealingLR(lr_scheduler.CosineAnnealingLR):
    """Cosine annealing configured through ``total_steps``.

    Args:
        optimizer: Optimizer forwarded to the parent scheduler.
        total_steps: Passed to the parent scheduler as ``T_max``.
        eta_min: Forwarded unchanged to the parent scheduler.
        last_epoch: Forwarded unchanged to the parent scheduler.
        verbose: Forwarded unchanged to the parent scheduler.
    """

    def __init__(self, optimizer, total_steps, eta_min=0, last_epoch=-1, verbose=False):
        # Only the name changes: this wrapper's ``total_steps`` is the
        # parent's ``T_max``.
        parent_kwargs = dict(
            optimizer=optimizer,
            T_max=total_steps,
            eta_min=eta_min,
            last_epoch=last_epoch,
            verbose=verbose,
        )
        super().__init__(**parent_kwargs)
@SCHEDULERS.register_module()
class OneCycleLR(lr_scheduler.OneCycleLR):
    """``torch.optim.lr_scheduler.OneCycleLR`` driven by ``total_steps`` only.

    The constructor accepts the schedule length solely through
    ``total_steps``; every argument listed here is forwarded to the parent
    scheduler under the same name.
    """

    def __init__(
        self,
        optimizer,
        max_lr,
        total_steps=None,
        pct_start=0.3,
        anneal_strategy="cos",
        cycle_momentum=True,
        base_momentum=0.85,
        max_momentum=0.95,
        div_factor=25.0,
        final_div_factor=1e4,
        three_phase=False,
        last_epoch=-1,
        verbose=False,
    ):
        # Collect the pass-through arguments once, then hand them to the
        # parent by keyword so the mapping stays explicit.
        parent_kwargs = dict(
            optimizer=optimizer,
            max_lr=max_lr,
            total_steps=total_steps,
            pct_start=pct_start,
            anneal_strategy=anneal_strategy,
            cycle_momentum=cycle_momentum,
            base_momentum=base_momentum,
            max_momentum=max_momentum,
            div_factor=div_factor,
            final_div_factor=final_div_factor,
            three_phase=three_phase,
            last_epoch=last_epoch,
            verbose=verbose,
        )
        super().__init__(**parent_kwargs)
def build_scheduler(cfg, optimizer):
    """Instantiate the scheduler described by ``cfg`` for ``optimizer``.

    ``optimizer`` is stored on ``cfg`` as the ``optimizer`` attribute (the
    passed-in ``cfg`` object is modified), and the result of
    ``SCHEDULERS.build(cfg=cfg)`` is returned.
    """
    cfg.optimizer = optimizer
    scheduler = SCHEDULERS.build(cfg=cfg)
    return scheduler