import torch
import torch.nn as nn


class EMA:
    """Maintain an exponential moving average (EMA) of a model's parameters.

    The EMA model is a shadow copy of the training model whose weights are
    updated as ema = beta * ema + (1 - beta) * current after each step.
    """

    def __init__(self, beta: float):
        # Decay rate; values close to 1.0 make the average move slowly.
        self.beta = beta
        # Number of EMA updates performed so far.
        self.step = 0

    def update_model_average(self, ema_model: nn.Module, current_model: nn.Module) -> None:
        # Blend each EMA parameter toward the corresponding current parameter.
        for current_params, ema_params in zip(current_model.parameters(), ema_model.parameters()):
            old_weight, up_weight = ema_params.data, current_params.data
            ema_params.data = self.update_average(old_weight, up_weight)

    def update_average(self, old: torch.Tensor | None, new: torch.Tensor) -> torch.Tensor:
        if old is None:
            return new
        return old * self.beta + (1 - self.beta) * new

    def step_ema(self, ema_model: nn.Module, model: nn.Module, step_start_ema: int = 2000) -> None:
        # During warmup, keep the EMA model identical to the training model;
        # averaging only begins once step_start_ema steps have passed.
        if self.step < step_start_ema:
            self.reset_parameters(ema_model, model)
            self.step += 1
            return
        self.update_model_average(ema_model, model)
        self.step += 1

    def copy_to(self, ema_model: nn.Module, model: nn.Module) -> None:
        # Overwrite the training model's weights with the EMA weights.
        model.load_state_dict(ema_model.state_dict())

    def reset_parameters(self, ema_model: nn.Module, model: nn.Module) -> None:
        # Re-initialize the EMA model from the current training model.
        ema_model.load_state_dict(model.state_dict())
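

# A minimal usage sketch, not part of the original class: it assumes a toy
# nn.Linear model and a dummy training loop; the model, optimizer, loss, and
# hyperparameters below are illustrative placeholders.
if __name__ == "__main__":
    import copy

    model = nn.Linear(4, 2)
    ema = EMA(beta=0.995)
    # The shadow model starts as an exact copy and is never trained directly.
    ema_model = copy.deepcopy(model).eval().requires_grad_(False)

    optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
    for _ in range(10):
        x = torch.randn(8, 4)
        loss = model(x).pow(2).mean()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # Sync the shadow weights after each optimizer step; before step 5
        # this simply resets ema_model to match model.
        ema.step_ema(ema_model, model, step_start_ema=5)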