from torch import nn
import torch


class ClassifierHead(nn.Module):
    """Basically a fancy MLP: 3-layer classifier head with GELU, LayerNorm, and skip connections."""

    def __init__(self, hidden_size, num_labels, dropout_prob):
        super().__init__()
        # Layer 1
        self.dense1 = nn.Linear(hidden_size, hidden_size)
        self.norm1 = nn.LayerNorm(hidden_size)
        self.activation = nn.GELU()
        self.dropout1 = nn.Dropout(dropout_prob)
        # Layer 2
        self.dense2 = nn.Linear(hidden_size, hidden_size)
        self.norm2 = nn.LayerNorm(hidden_size)
        self.dropout2 = nn.Dropout(dropout_prob)
        # Output Layer
        self.out_proj = nn.Linear(hidden_size, num_labels)

    def forward(self, features):
        # Layer 1
        identity1 = features
        x = self.norm1(features)
        x = self.dense1(x)
        x = self.activation(x)
        x = self.dropout1(x)
        x = x + identity1  # skip connection
        # Layer 2
        identity2 = x
        x = self.norm2(x)
        x = self.dense2(x)
        x = self.activation(x)
        x = self.dropout2(x)
        x = x + identity2  # skip connection
        # Output Layer
        logits = self.out_proj(x)
        return logits
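
# Example usage for ClassifierHead (a minimal sketch; the 768-dim hidden size,
# batch of 8, and 3 labels below are illustrative assumptions, not values taken
# from any particular training setup):
#   head = ClassifierHead(hidden_size=768, num_labels=3, dropout_prob=0.1)
#   logits = head(torch.randn(8, 768))  # pooled features -> logits of shape (8, 3)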


class ConcatClassifierHead(nn.Module):
    """
    An enhanced classifier head designed for concatenated CLS + Mean Pooling input.
    Includes an initial projection layer before the standard enhanced block.
    """

    def __init__(self, input_size, hidden_size, num_labels, dropout_prob):
        super().__init__()
        # Initial projection from concatenated size (2*hidden) down to hidden_size
        self.initial_projection = nn.Linear(input_size, hidden_size)
        self.initial_norm = nn.LayerNorm(hidden_size)  # Norm after projection
        self.initial_activation = nn.GELU()
        self.initial_dropout = nn.Dropout(dropout_prob)
        # Layer 1
        self.dense1 = nn.Linear(hidden_size, hidden_size)
        self.norm1 = nn.LayerNorm(hidden_size)
        self.activation = nn.GELU()
        self.dropout1 = nn.Dropout(dropout_prob)
        # Layer 2
        self.dense2 = nn.Linear(hidden_size, hidden_size)
        self.norm2 = nn.LayerNorm(hidden_size)
        self.dropout2 = nn.Dropout(dropout_prob)
        # Output Layer
        self.out_proj = nn.Linear(hidden_size, num_labels)

    def forward(self, features):
        # Initial Projection Step
        x = self.initial_projection(features)
        x = self.initial_norm(x)
        x = self.initial_activation(x)
        x = self.initial_dropout(x)
        # x should now be of shape (batch_size, hidden_size)
        # Layer 1 + Skip
        identity1 = x  # Skip connection starts after initial projection
        x_res = self.norm1(x)
        x_res = self.dense1(x_res)
        x_res = self.activation(x_res)
        x_res = self.dropout1(x_res)
        x = identity1 + x_res  # skip connection
        # Layer 2 + Skip
        identity2 = x
        x_res = self.norm2(x)
        x_res = self.dense2(x_res)
        x_res = self.activation(x_res)
        x_res = self.dropout2(x_res)
        x = identity2 + x_res  # skip connection
        # Output Layer
        logits = self.out_proj(x)
        return logits
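
# Example usage for ConcatClassifierHead (a sketch; assumes CLS and mean-pooled
# vectors concatenated to width 2 * hidden_size, with an illustrative hidden_size of 768):
#   head = ConcatClassifierHead(input_size=1536, hidden_size=768, num_labels=3, dropout_prob=0.1)
#   logits = head(torch.randn(8, 1536))  # -> logits of shape (8, 3)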


# ExpansionClassifierHead currently not used
class ExpansionClassifierHead(nn.Module):
    """
    A classifier head using FFN-style expansion (input -> 4*hidden -> hidden -> labels).
    Takes concatenated CLS + Mean Pooled features as input.
    """

    def __init__(self, input_size, hidden_size, num_labels, dropout_prob):
        super().__init__()
        intermediate_size = hidden_size * 4  # FFN expansion factor
        # Layer 1 (Expansion)
        self.norm1 = nn.LayerNorm(input_size)
        self.dense1 = nn.Linear(input_size, intermediate_size)
        self.activation = nn.GELU()
        self.dropout1 = nn.Dropout(dropout_prob)
        # Layer 2 (Projection back down)
        self.norm2 = nn.LayerNorm(intermediate_size)
        self.dense2 = nn.Linear(intermediate_size, hidden_size)
        # Activation and Dropout applied after projection
        self.dropout2 = nn.Dropout(dropout_prob)
        # Output Layer
        self.out_proj = nn.Linear(hidden_size, num_labels)

    def forward(self, features):
        # Layer 1
        x = self.norm1(features)
        x = self.dense1(x)
        x = self.activation(x)
        x = self.dropout1(x)
        # Layer 2
        x = self.norm2(x)
        x = self.dense2(x)
        x = self.activation(x)
        x = self.dropout2(x)
        # Output Layer
        logits = self.out_proj(x)
        return logits
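

# Minimal smoke test sketching how the three heads are wired up. The hidden
# size (768), batch size (4), label count (3), and dropout (0.1) are
# illustrative assumptions, not values from any particular training run.
if __name__ == "__main__":
    batch_size, hidden_size, num_labels = 4, 768, 3

    # ClassifierHead consumes a single pooled vector per example (e.g. [CLS]).
    pooled = torch.randn(batch_size, hidden_size)
    head = ClassifierHead(hidden_size, num_labels, dropout_prob=0.1)
    print(head(pooled).shape)  # torch.Size([4, 3])

    # The concat/expansion heads consume CLS + mean-pooled features,
    # i.e. an input of width 2 * hidden_size.
    concat = torch.cat([pooled, torch.randn(batch_size, hidden_size)], dim=-1)
    concat_head = ConcatClassifierHead(2 * hidden_size, hidden_size, num_labels, dropout_prob=0.1)
    print(concat_head(concat).shape)  # torch.Size([4, 3])

    expansion_head = ExpansionClassifierHead(2 * hidden_size, hidden_size, num_labels, dropout_prob=0.1)
    print(expansion_head(concat).shape)  # torch.Size([4, 3])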