|
|
|
import math |
|
from typing import Optional |
|
|
|
import torch |
|
import torch.nn as nn |
|
from mmengine.model import ExponentialMovingAverage |
|
from torch import Tensor |
|
|
|
from mmdet.registry import MODELS |
|
|
|
|
|
@MODELS.register_module()
class ExpMomentumEMA(ExponentialMovingAverage):
    """EMA of a model whose momentum decays exponentially with the step
    count. This is the strategy used in YOLOX.

    The averaged parameters are updated as
    `averaged_param = (1 - m) * averaged_param + m * source_param`, where the
    effective momentum `m` is
    `(1 - momentum) * exp(-(1 + steps) / gamma) + momentum`. `m` is therefore
    large for small `steps` and anneals towards `momentum` as `steps` grows.

    Args:
        model (nn.Module): The model to be averaged.
        momentum (float): The base momentum that the effective momentum
            anneals towards. Defaults to 0.0002.
        gamma (int): Controls how fast the effective momentum anneals from
            its larger early value down to ``momentum``. Must be greater
            than 0. Defaults to 2000.
        interval (int): Interval between two updates. Defaults to 1.
        device (torch.device, optional): If provided, the averaged model will
            be stored on the :attr:`device`. Defaults to None.
        update_buffers (bool): if True, it will compute running averages for
            both the parameters and the buffers of the model. Defaults to
            False.
    """

    def __init__(self,
                 model: nn.Module,
                 momentum: float = 0.0002,
                 gamma: int = 2000,
                 interval=1,
                 device: Optional[torch.device] = None,
                 update_buffers: bool = False) -> None:
        # Everything except ``gamma`` is handled by the parent EMA class.
        super().__init__(
            model=model,
            momentum=momentum,
            interval=interval,
            device=device,
            update_buffers=update_buffers)

        # ``gamma`` is the divisor in the exponent of ``avg_func``.
        assert gamma > 0, f'gamma must be greater than 0, but got {gamma}'
        self.gamma = gamma

    def avg_func(self, averaged_param: Tensor, source_param: Tensor,
                 steps: int) -> None:
        """Blend ``source_param`` into ``averaged_param`` in place using the
        exponentially annealed momentum.

        Args:
            averaged_param (Tensor): The averaged parameters; modified in
                place.
            source_param (Tensor): The source parameters.
            steps (int): The number of times the parameters have been
                updated.
        """
        # Exponential term: shrinks towards 0 as ``steps`` grows.
        decay = math.exp(-float(1 + steps) / self.gamma)
        # Effective momentum for this step; approaches ``self.momentum``.
        momentum = (1 - self.momentum) * decay + self.momentum

        # averaged = (1 - momentum) * averaged + momentum * source, in place.
        averaged_param.mul_(1 - momentum)
        averaged_param.add_(source_param, alpha=momentum)
|
|