import torch
from torch import nn


class ClassifierHead(nn.Module):
    """Basically a fancy MLP: 3-layer classifier head with GELU, LayerNorm, and Skip Connections."""
    def __init__(self, hidden_size, num_labels, dropout_prob):
        super().__init__()
        # Layer 1
        self.dense1 = nn.Linear(hidden_size, hidden_size)
        self.norm1 = nn.LayerNorm(hidden_size)
        self.activation = nn.GELU()
        self.dropout1 = nn.Dropout(dropout_prob)

        # Layer 2
        self.dense2 = nn.Linear(hidden_size, hidden_size)
        self.norm2 = nn.LayerNorm(hidden_size)
        self.dropout2 = nn.Dropout(dropout_prob)

        # Output Layer
        self.out_proj = nn.Linear(hidden_size, num_labels)

    def forward(self, features):
        # Layer 1
        identity1 = features
        x = self.norm1(features)
        x = self.dense1(x)
        x = self.activation(x)
        x = self.dropout1(x)
        x = x + identity1  # skip connection

        # Layer 2
        identity2 = x
        x = self.norm2(x)
        x = self.dense2(x)
        x = self.activation(x)
        x = self.dropout2(x)
        x = x + identity2  # skip connection

        # Output Layer
        logits = self.out_proj(x)
        return logits
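

# A minimal usage sketch for ClassifierHead (illustrative only; the sizes are
# assumptions, e.g. hidden_size=768 to match a BERT-base encoder):
#
#   head = ClassifierHead(hidden_size=768, num_labels=2, dropout_prob=0.1)
#   pooled = torch.randn(8, 768)   # e.g. the [CLS] hidden state
#   logits = head(pooled)          # -> shape (8, 2)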


class ConcatClassifierHead(nn.Module):
    """
    An enhanced classifier head designed for concatenated CLS + Mean Pooling input.
    Includes an initial projection layer before the standard enhanced block.
    """
    def __init__(self, input_size, hidden_size, num_labels, dropout_prob):
        super().__init__()
        # Initial projection from concatenated size (2*hidden) down to hidden_size
        self.initial_projection = nn.Linear(input_size, hidden_size)
        self.initial_norm = nn.LayerNorm(hidden_size) # Norm after projection
        self.initial_activation = nn.GELU()
        self.initial_dropout = nn.Dropout(dropout_prob)

        # Layer 1
        self.dense1 = nn.Linear(hidden_size, hidden_size)
        self.norm1 = nn.LayerNorm(hidden_size)
        self.activation = nn.GELU()
        self.dropout1 = nn.Dropout(dropout_prob)

        # Layer 2
        self.dense2 = nn.Linear(hidden_size, hidden_size)
        self.norm2 = nn.LayerNorm(hidden_size)
        self.dropout2 = nn.Dropout(dropout_prob)

        # Output Layer
        self.out_proj = nn.Linear(hidden_size, num_labels)

    def forward(self, features):
        # Initial Projection Step
        x = self.initial_projection(features)
        x = self.initial_norm(x)
        x = self.initial_activation(x)
        x = self.initial_dropout(x)
        # x should now be of shape (batch_size, hidden_size)

        # Layer 1 + Skip
        identity1 = x  # skip connection starts after the initial projection
        x_res = self.norm1(x)
        x_res = self.dense1(x_res)
        x_res = self.activation(x_res)
        x_res = self.dropout1(x_res)
        x = identity1 + x_res  # skip connection

        # Layer 2 + Skip
        identity2 = x
        x_res = self.norm2(x)
        x_res = self.dense2(x_res)
        x_res = self.activation(x_res)
        x_res = self.dropout2(x_res)
        x = identity2 + x_res  # skip connection

        # Output Layer
        logits = self.out_proj(x)
        return logits
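

# A minimal usage sketch for ConcatClassifierHead (illustrative assumptions):
# input_size is expected to be 2 * hidden_size, i.e. the [CLS] state
# concatenated with the mean-pooled token states.
#
#   head = ConcatClassifierHead(input_size=2 * 768, hidden_size=768,
#                               num_labels=2, dropout_prob=0.1)
#   cls_vec, mean_vec = torch.randn(8, 768), torch.randn(8, 768)
#   logits = head(torch.cat([cls_vec, mean_vec], dim=-1))  # -> shape (8, 2)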


# ExpansionClassifierHead currently not used
class ExpansionClassifierHead(nn.Module):
    """
    A classifier head using FFN-style expansion (input -> 4*hidden -> hidden -> labels).
    Takes concatenated CLS + Mean Pooled features as input.
    """
    def __init__(self, input_size, hidden_size, num_labels, dropout_prob):
        super().__init__()
        intermediate_size = hidden_size * 4 # FFN expansion factor

        # Layer 1 (Expansion)
        self.norm1 = nn.LayerNorm(input_size)
        self.dense1 = nn.Linear(input_size, intermediate_size)
        self.activation = nn.GELU()
        self.dropout1 = nn.Dropout(dropout_prob)

        # Layer 2 (Projection back down)
        self.norm2 = nn.LayerNorm(intermediate_size)
        self.dense2 = nn.Linear(intermediate_size, hidden_size)
        # The shared GELU and dropout2 are applied after this projection in forward()
        self.dropout2 = nn.Dropout(dropout_prob)

        # Output Layer
        self.out_proj = nn.Linear(hidden_size, num_labels)

    def forward(self, features):
        # Layer 1
        x = self.norm1(features)
        x = self.dense1(x)
        x = self.activation(x)
        x = self.dropout1(x)

        # Layer 2
        x = self.norm2(x)
        x = self.dense2(x)
        x = self.activation(x)
        x = self.dropout2(x)

        # Output Layer
        logits = self.out_proj(x)
        return logits
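

if __name__ == "__main__":
    # Quick shape smoke test for all three heads. The sizes below are
    # illustrative assumptions (hidden=768, 3 labels, batch of 4), not values
    # taken from any particular training config.
    hidden, labels, batch = 768, 3, 4

    plain = ClassifierHead(hidden, labels, dropout_prob=0.1)
    concat = ConcatClassifierHead(2 * hidden, hidden, labels, dropout_prob=0.1)
    expand = ExpansionClassifierHead(2 * hidden, hidden, labels, dropout_prob=0.1)

    single = torch.randn(batch, hidden)       # e.g. a pooled [CLS] vector
    paired = torch.randn(batch, 2 * hidden)   # e.g. [CLS] + mean-pooling concat

    assert plain(single).shape == (batch, labels)
    assert concat(paired).shape == (batch, labels)
    assert expand(paired).shape == (batch, labels)
    print("All heads produce (batch, num_labels) logits.")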