File size: 1,812 Bytes
b6c45cb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
from functools import partial
import torch
from torch import nn


class Swish(nn.Module):
    """Swish activation module: multiplies the input by its own sigmoid."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        """Return ``x * sigmoid(x)`` for the input tensor ``x``."""
        gate = torch.sigmoid(x)
        return x * gate


def linear():
    """Build the "linear" activation, i.e. an ``nn.Identity`` layer."""
    identity = nn.Identity()
    return identity


def relu():
    """Build a new ``nn.ReLU`` layer with default arguments."""
    activation = nn.ReLU()
    return activation


def prelu():
    """Build a new ``nn.PReLU`` layer with default arguments."""
    activation = nn.PReLU()
    return activation


def leaky_relu():
    """Build a new ``nn.LeakyReLU`` layer with default arguments."""
    activation = nn.LeakyReLU()
    return activation


def sigmoid():
    """Build a new ``nn.Sigmoid`` layer."""
    activation = nn.Sigmoid()
    return activation


def softmax(dim=None):
    """Build a new ``nn.Softmax`` layer.

    Args:
        dim: Forwarded unchanged as the ``dim`` argument of ``nn.Softmax``.
    """
    activation = nn.Softmax(dim=dim)
    return activation


def tanh():
    """Build a new ``nn.Tanh`` layer."""
    activation = nn.Tanh()
    return activation


def gelu():
    """Build a new ``nn.GELU`` layer with default arguments."""
    activation = nn.GELU()
    return activation


def swish():
    """Build a new instance of this module's ``Swish`` activation."""
    activation = Swish()
    return activation


def register_activation(custom_act):
    """Register a custom activation, gettable with `activation.get`.

    The activation is stored in this module's global namespace under its
    own ``__name__``.

    Args:
        custom_act: Custom activation function to register.

    Raises:
        ValueError: If ``custom_act.__name__`` or its lowercase form is
            already a name in this module's globals.
    """
    name = custom_act.__name__
    namespace = globals()
    # Reject both the exact name and its lowercase form, so a registration
    # cannot shadow anything already defined at module level.
    if name in namespace or name.lower() in namespace:
        raise ValueError(f"Activation {custom_act.__name__} already exists. Choose another name.")
    namespace[name] = custom_act


def get(identifier):
    """Resolve an activation identifier.

    Args:
        identifier (str or Callable or None): the activation identifier.
            ``None`` and callables are returned unchanged. A string is
            looked up by name in this module's global namespace.

    Returns:
        Callable or None: ``None`` if ``identifier`` is ``None``, the
        identifier itself if it is callable, otherwise the callable found
        under that name in this module's globals.

    Raises:
        ValueError: If ``identifier`` is neither ``None``, a callable nor a
            string, or if the string does not name a callable in this
            module's globals.
    """
    if identifier is None:
        return None
    if callable(identifier):
        return identifier
    if isinstance(identifier, str):
        cls = globals().get(identifier)
        # The module namespace also holds imported modules (``torch``,
        # ``nn``) and dunder attributes such as ``__name__``; those are not
        # activations, so only callable entries are accepted. A missing
        # name yields None, which is not callable either.
        if callable(cls):
            return cls
    raise ValueError("Could not interpret activation identifier: " + str(identifier))