File size: 593 Bytes
5eff22e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
import torch.nn as nn
import torch
import torch.nn.functional as F


class ScaledL2Norm(nn.Module):
    """L2-normalize the input along dim 1, then apply a learnable per-channel scale.

    Args:
        in_channels: Number of entries in the learnable ``scale`` vector; it
            must match the size of dim 1 of the tensors passed to ``forward``.
        initial_scale: Value every entry of ``scale`` is set to by
            ``reset_parameters``.
    """

    def __init__(self, in_channels, initial_scale):
        super(ScaledL2Norm, self).__init__()
        self.in_channels = in_channels
        # Allocated uninitialized; reset_parameters() fills it right below.
        self.scale = nn.Parameter(torch.empty(in_channels))
        self.initial_scale = initial_scale
        self.reset_parameters()

    def forward(self, x):
        """Return ``x`` normalized to unit L2 norm over dim 1, times ``scale``.

        ``x`` may have any rank >= 2 with the channel axis at dim 1. The
        result has the same shape as ``x``.
        """
        # Reshape scale to (1, C, 1, ..., 1) so it broadcasts over the batch
        # axis and every trailing axis, whatever the rank of x. For 4-D input
        # this is the same (1, C, 1, 1) view as three chained unsqueeze calls.
        trailing_ones = [1] * (x.dim() - 2)
        scale = self.scale.view(1, -1, *trailing_ones)
        return F.normalize(x, p=2, dim=1) * scale

    def reset_parameters(self):
        """Set every entry of ``scale`` to ``initial_scale`` in place."""
        # nn.init.constant_ fills under no_grad, avoiding the legacy `.data`
        # access that bypasses autograd's in-place version tracking.
        nn.init.constant_(self.scale, self.initial_scale)