import torch
import torch.nn as nn


class EMA:
    """Maintains an exponential moving average (EMA) of a model's parameters."""

    def __init__(self, beta: float):
        self.beta = beta
        self.step = 0

    def update_model_average(self, ema_model: nn.Module, current_model: nn.Module) -> None:
        # Blend every EMA parameter toward the corresponding current parameter.
        for current_params, ema_params in zip(current_model.parameters(), ema_model.parameters()):
            old_weight, up_weight = ema_params.data, current_params.data
            ema_params.data = self.update_average(old_weight, up_weight)

    def update_average(self, old: torch.Tensor | None, new: torch.Tensor) -> torch.Tensor:
        if old is None:
            return new
        return old * self.beta + (1 - self.beta) * new

    def step_ema(self, ema_model: nn.Module, model: nn.Module, step_start_ema: int = 2000) -> None:
        # Before step_start_ema, keep the EMA model in sync with the raw model
        # so that early, noisy weights do not dominate the average.
        if self.step < step_start_ema:
            self.reset_parameters(ema_model, model)
            self.step += 1
            return
        self.update_model_average(ema_model, model)
        self.step += 1

    def copy_to(self, ema_model: nn.Module, model: nn.Module) -> None:
        # Load the averaged weights into the live model (e.g. for evaluation or sampling).
        model.load_state_dict(ema_model.state_dict())

    def reset_parameters(self, ema_model: nn.Module, model: nn.Module) -> None:
        # Re-initialize the EMA model from the current model weights.
        ema_model.load_state_dict(model.state_dict())
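
# --- Minimal usage sketch (illustrative, not part of the original class) ---
# Shows how EMA would plug into a training loop, assuming the EMA copy is a
# frozen deep copy of the live model and is updated once per optimizer step.
# The toy model, data, and hyperparameters below are placeholders.
if __name__ == "__main__":
    import copy

    model = nn.Linear(10, 1)  # stand-in for a real network
    ema_model = copy.deepcopy(model).eval().requires_grad_(False)

    ema = EMA(beta=0.995)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

    for _ in range(3000):
        x = torch.randn(32, 10)
        y = torch.randn(32, 1)
        loss = nn.functional.mse_loss(model(x), y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # EMA weights track the raw weights; before step_start_ema they are
        # simply reset to the raw weights (see step_ema above).
        ema.step_ema(ema_model, model)

    # For evaluation, load the averaged weights back into the live model.
    ema.copy_to(ema_model, model)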