# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
"""Activation functions."""

import torch
import torch.nn as nn
import torch.nn.functional as F


class SiLU(nn.Module):
    """Applies the Sigmoid-weighted Linear Unit (SiLU) activation function, also known as Swish."""

    @staticmethod
    def forward(x):
        """
        Applies the Sigmoid-weighted Linear Unit (SiLU) activation function.

        https://arxiv.org/pdf/1606.08415.pdf.
        """
        return x * torch.sigmoid(x)


class Hardswish(nn.Module):
    """Applies the Hardswish activation function, which is efficient for mobile and embedded devices."""

    @staticmethod
    def forward(x):
        """
        Applies the Hardswish activation function, compatible with TorchScript, CoreML, and ONNX.

        Equivalent to x * F.hardsigmoid(x).
        """
        return x * F.hardtanh(x + 3, 0.0, 6.0) / 6.0  # for TorchScript, CoreML and ONNX
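

# NOTE: Illustrative self-check, not part of the original module. It verifies that the
# export-friendly hardtanh formulation above matches PyTorch's reference hardsigmoid
# and nn.Hardswish implementations (an assumption worth confirming per torch version).
def _check_hardswish_equivalence():
    """Confirm Hardswish.forward(x) matches x * F.hardsigmoid(x) and nn.Hardswish()(x)."""
    x = torch.randn(4, 8)
    y = Hardswish.forward(x)
    assert torch.allclose(y, x * F.hardsigmoid(x), atol=1e-6)
    assert torch.allclose(y, nn.Hardswish()(x), atol=1e-6)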


class Mish(nn.Module):
    """Mish activation https://github.com/digantamisra98/Mish."""

    @staticmethod
    def forward(x):
        """Applies the Mish activation function, a smooth alternative to ReLU."""
        return x * F.softplus(x).tanh()


class MemoryEfficientMish(nn.Module):
    """Efficiently applies the Mish activation function using custom autograd for reduced memory usage."""

    class F(torch.autograd.Function):
        """Implements a custom autograd function for memory-efficient Mish activation."""

        @staticmethod
        def forward(ctx, x):
            """Applies the Mish activation function, a smooth ReLU alternative, to the input tensor `x`."""
            ctx.save_for_backward(x)
            return x.mul(torch.tanh(F.softplus(x)))  # x * tanh(ln(1 + exp(x)))

        @staticmethod
        def backward(ctx, grad_output):
            """Computes the gradient of the Mish activation function with respect to input `x`."""
            x = ctx.saved_tensors[0]
            sx = torch.sigmoid(x)
            fx = F.softplus(x).tanh()
            return grad_output * (fx + x * sx * (1 - fx * fx))

    def forward(self, x):
        """Applies the Mish activation function to the input tensor `x`."""
        return self.F.apply(x)
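

# NOTE: Illustrative self-check, not part of the original module. The custom autograd
# Function above saves only `x` and recomputes sigmoid/softplus in backward(); this sketch
# confirms its values and gradients agree with the plain Mish implementation.
def _check_memory_efficient_mish():
    """Confirm MemoryEfficientMish matches Mish in both forward values and gradients."""
    x1 = torch.randn(4, 8, requires_grad=True)
    x2 = x1.detach().clone().requires_grad_(True)
    y1, y2 = Mish()(x1), MemoryEfficientMish()(x2)
    assert torch.allclose(y1, y2, atol=1e-6)
    y1.sum().backward()
    y2.sum().backward()
    assert torch.allclose(x1.grad, x2.grad, atol=1e-5)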


class FReLU(nn.Module):
    """FReLU activation https://arxiv.org/abs/2007.11824."""

    def __init__(self, c1, k=3):  # ch_in, kernel
        """Initializes FReLU activation with channel `c1` and kernel size `k`."""
        super().__init__()
        self.conv = nn.Conv2d(c1, c1, k, 1, k // 2, groups=c1, bias=False)  # depthwise conv; padding k//2 preserves H, W
        self.bn = nn.BatchNorm2d(c1)

    def forward(self, x):
        """
        Applies FReLU activation with max operation between input and BN-convolved input.

        https://arxiv.org/abs/2007.11824
        """
        return torch.max(x, self.bn(self.conv(x)))
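

# NOTE: Illustrative self-check, not part of the original module. FReLU replaces the scalar
# threshold in max(x, 0) with a spatial condition from a depthwise conv + BN, so the output
# keeps the input shape and is never elementwise smaller than the input.
def _check_frelu():
    """Confirm FReLU preserves input shape and satisfies out >= x elementwise."""
    m = FReLU(c1=16).eval()  # eval() so BatchNorm uses running statistics
    x = torch.randn(2, 16, 32, 32)
    y = m(x)
    assert y.shape == x.shape
    assert torch.all(y >= x)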


class AconC(nn.Module):
    """
    ACON activation (activate or not) function.

    AconC: (p1*x-p2*x) * sigmoid(beta*(p1*x-p2*x)) + p2*x, beta is a learnable parameter
    See "Activate or Not: Learning Customized Activation" https://arxiv.org/pdf/2009.04759.pdf.
    """

    def __init__(self, c1):
        """Initializes AconC with learnable parameters p1, p2, and beta for channel-wise activation control."""
        super().__init__()
        self.p1 = nn.Parameter(torch.randn(1, c1, 1, 1))
        self.p2 = nn.Parameter(torch.randn(1, c1, 1, 1))
        self.beta = nn.Parameter(torch.ones(1, c1, 1, 1))

    def forward(self, x):
        """Applies AconC activation function with learnable parameters for channel-wise control on input tensor x."""
        dpx = (self.p1 - self.p2) * x
        return dpx * torch.sigmoid(self.beta * dpx) + self.p2 * x
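

# NOTE: Illustrative self-check, not part of the original module. With p1=1, p2=0 and beta=1,
# the AconC formula (p1-p2)*x*sigmoid(beta*(p1-p2)*x) + p2*x reduces to x*sigmoid(x), i.e. SiLU.
def _check_aconc_reduces_to_silu():
    """Confirm AconC equals SiLU when p1=1, p2=0 and beta=1."""
    m = AconC(c1=8)
    with torch.no_grad():
        m.p1.fill_(1.0)
        m.p2.fill_(0.0)
        m.beta.fill_(1.0)
    x = torch.randn(4, 8, 16, 16)
    assert torch.allclose(m(x), x * torch.sigmoid(x), atol=1e-6)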


class MetaAconC(nn.Module):
    """
    ACON activation (activate or not) function.

    MetaAconC: (p1*x-p2*x) * sigmoid(beta*(p1*x-p2*x)) + p2*x, beta is generated by a small network
    See "Activate or Not: Learning Customized Activation" https://arxiv.org/pdf/2009.04759.pdf.
    """

    def __init__(self, c1, k=1, s=1, r=16):
        """Initializes MetaAconC with params: channel_in (c1), kernel size (k=1), stride (s=1), reduction (r=16)."""
        super().__init__()
        c2 = max(r, c1 // r)
        self.p1 = nn.Parameter(torch.randn(1, c1, 1, 1))
        self.p2 = nn.Parameter(torch.randn(1, c1, 1, 1))
        self.fc1 = nn.Conv2d(c1, c2, k, s, bias=True)
        self.fc2 = nn.Conv2d(c2, c1, k, s, bias=True)
        # self.bn1 = nn.BatchNorm2d(c2)
        # self.bn2 = nn.BatchNorm2d(c1)

    def forward(self, x):
        """Applies a forward pass transforming input `x` using learnable parameters and sigmoid activation."""
        y = x.mean(dim=2, keepdim=True).mean(dim=3, keepdim=True)
        # batch-size 1 bug/instabilities https://github.com/ultralytics/yolov5/issues/2891
        # beta = torch.sigmoid(self.bn2(self.fc2(self.bn1(self.fc1(y)))))  # bug/unstable
        beta = torch.sigmoid(self.fc2(self.fc1(y)))  # bug patch: BN layers removed
        dpx = (self.p1 - self.p2) * x
        return dpx * torch.sigmoid(beta * dpx) + self.p2 * x
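

# NOTE: Illustrative usage sketch, not part of the original module. MetaAconC derives a
# per-sample, per-channel beta of shape (N, C, 1, 1) from globally pooled features via
# fc1/fc2, so the activation preserves the input feature-map shape.
def _check_meta_aconc_shapes():
    """Confirm MetaAconC preserves the shape of a typical (N, C, H, W) feature map."""
    m = MetaAconC(c1=64)
    x = torch.randn(2, 64, 20, 20)
    assert m(x).shape == x.shape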