"""
Copied from RT-DETR (https://github.com/lyuwenyu/RT-DETR)
Copyright(c) 2023 lyuwenyu. All Rights Reserved.
"""
from torch.optim.lr_scheduler import LRScheduler
from ..core import register
class Warmup(object):
def __init__(
self, lr_scheduler: LRScheduler, warmup_duration: int, last_step: int = -1
) -> None:
self.lr_scheduler = lr_scheduler
self.warmup_end_values = [pg["lr"] for pg in lr_scheduler.optimizer.param_groups]
self.last_step = last_step
self.warmup_duration = warmup_duration
self.step()
def state_dict(self):
return {k: v for k, v in self.__dict__.items() if k != "lr_scheduler"}
def load_state_dict(self, state_dict):
self.__dict__.update(state_dict)
def get_warmup_factor(self, step, **kwargs):
raise NotImplementedError
def step(
self,
):
self.last_step += 1
if self.last_step >= self.warmup_duration:
return
factor = self.get_warmup_factor(self.last_step)
for i, pg in enumerate(self.lr_scheduler.optimizer.param_groups):
pg["lr"] = factor * self.warmup_end_values[i]
def finished(
self,
):
if self.last_step >= self.warmup_duration:
return True
return False
@register()
class LinearWarmup(Warmup):
def __init__(
self, lr_scheduler: LRScheduler, warmup_duration: int, last_step: int = -1
) -> None:
super().__init__(lr_scheduler, warmup_duration, last_step)
def get_warmup_factor(self, step):
return min(1.0, (step + 1) / self.warmup_duration)
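

# --- Usage sketch (not part of the original RT-DETR file) --------------------
# A minimal illustration of how the warmup wrapper can be driven, assuming a
# hypothetical training loop; the optimizer, scheduler, and step counts below
# are placeholders, not values from the source. Because of the relative import
# above, this only runs inside the surrounding package (e.g. via `python -m`).
if __name__ == "__main__":
    import torch
    from torch.optim.lr_scheduler import MultiStepLR

    model = torch.nn.Linear(16, 4)
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
    lr_scheduler = MultiStepLR(optimizer, milestones=[60, 80], gamma=0.1)

    # Ramp every param group's LR from lr/500 up to its configured value over
    # the first 500 calls to step() (the constructor already performs one).
    warmup = LinearWarmup(lr_scheduler, warmup_duration=500)

    for it in range(1000):
        # ... forward / backward / optimizer.step() would go here ...
        warmup.step()  # becomes a no-op once warmup.finished() is True
        if warmup.finished() and (it + 1) % 100 == 0:
            lr_scheduler.step()  # hand LR control back to the wrapped scheduler

    print(optimizer.param_groups[0]["lr"])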