Spaces:
Running
on
Zero
Running
on
Zero
"""
Copied from RT-DETR (https://github.com/lyuwenyu/RT-DETR)
Copyright(c) 2023 lyuwenyu. All Rights Reserved.
"""
from torch.optim.lr_scheduler import LRScheduler | |
from ..core import register | |
class Warmup(object):
    """Base class for a learning-rate warmup wrapped around an ``LRScheduler``.

    On construction, the current ``lr`` of every param group of
    ``lr_scheduler.optimizer`` is recorded as that group's warmup end value.
    ``step()`` is then called once, so the first warmup factor is applied
    immediately.

    Each later ``step()`` advances ``last_step`` by one and sets every group's
    ``lr`` to ``get_warmup_factor(last_step) * end_value``. Once ``last_step``
    reaches ``warmup_duration``, the learning rates are no longer modified.

    Subclasses must implement :meth:`get_warmup_factor`.

    Args:
        lr_scheduler: Scheduler whose ``optimizer.param_groups`` are rescaled.
        warmup_duration: Number of warmup steps.
        last_step: Index of the last completed step; ``-1`` starts from scratch.
    """

    def __init__(
        self, lr_scheduler: LRScheduler, warmup_duration: int, last_step: int = -1
    ) -> None:
        self.lr_scheduler = lr_scheduler
        # Target learning rates the warmup ramps toward, one per param group.
        self.warmup_end_values = [pg["lr"] for pg in lr_scheduler.optimizer.param_groups]
        self.last_step = last_step
        self.warmup_duration = warmup_duration
        # Apply the factor for the first step right away.
        self.step()

    def state_dict(self):
        """Return this object's attributes, excluding the ``lr_scheduler`` reference."""
        return {k: v for k, v in self.__dict__.items() if k != "lr_scheduler"}

    def load_state_dict(self, state_dict):
        """Restore attributes from a dict such as one returned by :meth:`state_dict`.

        Any ``"lr_scheduler"`` entry is ignored so the live scheduler reference
        held by this object is never overwritten, mirroring :meth:`state_dict`.
        """
        self.__dict__.update(
            {k: v for k, v in state_dict.items() if k != "lr_scheduler"}
        )

    def get_warmup_factor(self, step, **kwargs):
        """Return the multiplier applied to each end ``lr`` at warmup ``step``.

        Must be overridden by subclasses.
        """
        raise NotImplementedError

    def step(self):
        """Advance one warmup step and rescale each param group's ``lr``.

        Does nothing (leaves the learning rates as they are) once
        ``last_step`` has reached ``warmup_duration``.
        """
        self.last_step += 1
        if self.last_step >= self.warmup_duration:
            return
        factor = self.get_warmup_factor(self.last_step)
        for i, pg in enumerate(self.lr_scheduler.optimizer.param_groups):
            pg["lr"] = factor * self.warmup_end_values[i]

    def finished(self):
        """Return ``True`` once ``last_step`` has reached ``warmup_duration``."""
        return self.last_step >= self.warmup_duration
class LinearWarmup(Warmup):
    """Warmup whose factor rises linearly up to ``1.0``.

    At 0-based warmup step ``s`` the factor is ``(s + 1) / warmup_duration``,
    capped at ``1.0``. The last warmup step (``s = warmup_duration - 1``)
    therefore restores the full recorded learning rate.
    """

    def __init__(
        self, lr_scheduler: LRScheduler, warmup_duration: int, last_step: int = -1
    ) -> None:
        super().__init__(lr_scheduler, warmup_duration, last_step)

    def get_warmup_factor(self, step, **kwargs):
        """Return ``min(1.0, (step + 1) / warmup_duration)``.

        A non-positive ``warmup_duration`` means there is no warmup, so ``1.0``
        is returned. This avoids a ``ZeroDivisionError`` or a negative factor.
        ``**kwargs`` is accepted (and ignored) to match the base-class signature.
        """
        if self.warmup_duration <= 0:
            return 1.0
        return min(1.0, (step + 1) / self.warmup_duration)