D-FINE / src /optim /warmup.py
developer0hye's picture
Upload 76 files
e85fecb verified
"""
Copied from RT-DETR (https://github.com/lyuwenyu/RT-DETR)
Copyright(c) 2023 lyuwenyu. All Rights Reserved.
"""
from torch.optim.lr_scheduler import LRScheduler
from ..core import register
class Warmup(object):
    """Base class for learning-rate warmup wrappers around an LR scheduler.

    On construction the current ``"lr"`` of every parameter group of
    ``lr_scheduler.optimizer`` is recorded as the value to reach once warmup
    ends, and one ``step()`` is taken immediately. While
    ``last_step < warmup_duration``, each ``step()`` overwrites every group's
    ``"lr"`` with ``get_warmup_factor(last_step) * recorded_lr``. After that,
    ``step()`` only advances the counter and leaves learning rates alone.

    Subclasses must implement :meth:`get_warmup_factor`.
    """

    def __init__(
        self, lr_scheduler: LRScheduler, warmup_duration: int, last_step: int = -1
    ) -> None:
        # Attribute assignment order is deliberate: state_dict() walks
        # __dict__, so this order is the order of the saved keys.
        self.lr_scheduler = lr_scheduler
        groups = lr_scheduler.optimizer.param_groups
        self.warmup_end_values = [group["lr"] for group in groups]
        self.last_step = last_step
        self.warmup_duration = warmup_duration
        # Apply the first warmup factor right away (advances last_step by one).
        self.step()

    def state_dict(self):
        """Return the instance attributes as a dict, without ``lr_scheduler``."""
        snapshot = dict(vars(self))
        snapshot.pop("lr_scheduler", None)
        return snapshot

    def load_state_dict(self, state_dict):
        """Overwrite instance attributes with the entries of ``state_dict``."""
        vars(self).update(state_dict)

    def get_warmup_factor(self, step, **kwargs):
        """Return the multiplier applied to the recorded learning rates at ``step``.

        Must be overridden by subclasses.
        """
        raise NotImplementedError

    def step(
        self,
    ):
        """Advance the warmup counter; while still warming up, rescale each group's lr."""
        self.last_step += 1
        if self.last_step < self.warmup_duration:
            scale = self.get_warmup_factor(self.last_step)
            param_groups = self.lr_scheduler.optimizer.param_groups
            # Indexed access (not zip) so a length mismatch between the
            # optimizer's groups and the recorded values still raises.
            for idx, group in enumerate(param_groups):
                group["lr"] = scale * self.warmup_end_values[idx]

    def finished(
        self,
    ):
        """Return ``True`` once ``last_step`` has reached ``warmup_duration``."""
        return bool(self.last_step >= self.warmup_duration)
@register()
class LinearWarmup(Warmup):
    """Warmup whose factor grows linearly up to ``1.0``.

    At warmup step ``s`` the factor is ``(s + 1) / warmup_duration``, capped
    at ``1.0``. With the default ``last_step=-1``, the first factor (applied
    during construction) is ``1 / warmup_duration``. The recorded learning
    rates are reached exactly on step ``warmup_duration - 1``.
    """

    def __init__(
        self, lr_scheduler: LRScheduler, warmup_duration: int, last_step: int = -1
    ) -> None:
        # Thin pass-through to Warmup.__init__, kept explicit so the class's
        # own __init__ signature is what ``register`` sees.
        # NOTE(review): confirm whether ..core.register relies on this.
        super().__init__(lr_scheduler, warmup_duration, last_step)

    def get_warmup_factor(self, step, **kwargs):
        """Return the linear warmup multiplier for ``step``.

        Accepts ``**kwargs`` so the signature matches the base-class
        ``get_warmup_factor(self, step, **kwargs)``. A non-positive
        ``warmup_duration`` means there is no warmup, so the full learning
        rate (factor ``1.0``) is returned. This avoids a ZeroDivisionError
        or a negative factor.
        """
        if self.warmup_duration <= 0:
            return 1.0
        return min(1.0, (step + 1) / self.warmup_duration)