File size: 1,260 Bytes
587665f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
import torch
import torch.nn as nn

class EMA:
    def __init__(self, beta: float):
        super().__init__()
        self.beta = beta
        self.step = 0

    def update_model_average(self, ema_model: nn.Module, current_model: nn.Module) -> None:
        for current_params, ema_model in zip(current_model.parameters(), ema_model.parameters()):
            old_weight, up_weight = ema_model.data, current_params.data
            ema_model.data = self.update_average(old_weight, up_weight)

    def update_average(self, old: torch.Tensor | None, new: torch.Tensor) -> torch.Tensor:
        if old is None:
            return new
        return old * self.beta + (1 - self.beta) * new

    def step_ema(self, ema_model: nn.Module, model: nn.Module, step_start_ema: int = 2000) -> None:
        if self.step < step_start_ema:
            self.reset_parameters(ema_model, model)
            self.step += 1
            return
        self.update_model_average(ema_model, model)
        self.step += 1

    def copy_to(self, ema_model: nn.Module, model: nn.Module) -> None:
        model.load_state_dict(ema_model.state_dict())

    def reset_parameters(self, ema_model: nn.Module, model: nn.Module) -> None:
        ema_model.load_state_dict(model.state_dict())