# Hugging Face file-viewer residue (not part of the module):
# uploader: vfontech · commit message: "Uploading the app" · commit 587665f (verified) · 1.26 kB
import torch
import torch.nn as nn
class EMA:
def __init__(self, beta: float):
super().__init__()
self.beta = beta
self.step = 0
def update_model_average(self, ema_model: nn.Module, current_model: nn.Module) -> None:
for current_params, ema_model in zip(current_model.parameters(), ema_model.parameters()):
old_weight, up_weight = ema_model.data, current_params.data
ema_model.data = self.update_average(old_weight, up_weight)
def update_average(self, old: torch.Tensor | None, new: torch.Tensor) -> torch.Tensor:
if old is None:
return new
return old * self.beta + (1 - self.beta) * new
def step_ema(self, ema_model: nn.Module, model: nn.Module, step_start_ema: int = 2000) -> None:
if self.step < step_start_ema:
self.reset_parameters(ema_model, model)
self.step += 1
return
self.update_model_average(ema_model, model)
self.step += 1
def copy_to(self, ema_model: nn.Module, model: nn.Module) -> None:
model.load_state_dict(ema_model.state_dict())
def reset_parameters(self, ema_model: nn.Module, model: nn.Module) -> None:
ema_model.load_state_dict(model.state_dict())