import torch
from torch import nn


class ClassifierHead(nn.Module):
    """Basically a fancy MLP: 3-layer classifier head with GELU, LayerNorm, and Skip Connections."""

    def __init__(self, hidden_size, num_labels, dropout_prob):
        super().__init__()
        # Layer 1
        self.dense1 = nn.Linear(hidden_size, hidden_size)
        self.norm1 = nn.LayerNorm(hidden_size)
        self.activation = nn.GELU()
        self.dropout1 = nn.Dropout(dropout_prob)
        # Layer 2
        self.dense2 = nn.Linear(hidden_size, hidden_size)
        self.norm2 = nn.LayerNorm(hidden_size)
        self.dropout2 = nn.Dropout(dropout_prob)
        # Output Layer
        self.out_proj = nn.Linear(hidden_size, num_labels)

    def forward(self, features):
        # Layer 1
        identity1 = features
        x = self.norm1(features)
        x = self.dense1(x)
        x = self.activation(x)
        x = self.dropout1(x)
        x = x + identity1  # skip connection
        # Layer 2
        identity2 = x
        x = self.norm2(x)
        x = self.dense2(x)
        x = self.activation(x)
        x = self.dropout2(x)
        x = x + identity2  # skip connection
        # Output Layer
        logits = self.out_proj(x)
        return logits
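
# Illustrative usage note for ClassifierHead (not part of the original module;
# the hidden size, batch size, and label count below are assumed example values):
#
#     head = ClassifierHead(hidden_size=768, num_labels=3, dropout_prob=0.1)
#     logits = head(torch.randn(8, 768))  # pooled (batch, hidden) -> (8, 3)
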

class ConcatClassifierHead(nn.Module):
    """
    An enhanced classifier head designed for concatenated CLS + Mean Pooling input.
    Includes an initial projection layer before the standard enhanced block.
    """

    def __init__(self, input_size, hidden_size, num_labels, dropout_prob):
        super().__init__()
        # Initial projection from concatenated size (2*hidden) down to hidden_size
        self.initial_projection = nn.Linear(input_size, hidden_size)
        self.initial_norm = nn.LayerNorm(hidden_size)  # Norm after projection
        self.initial_activation = nn.GELU()
        self.initial_dropout = nn.Dropout(dropout_prob)
        # Layer 1
        self.dense1 = nn.Linear(hidden_size, hidden_size)
        self.norm1 = nn.LayerNorm(hidden_size)
        self.activation = nn.GELU()
        self.dropout1 = nn.Dropout(dropout_prob)
        # Layer 2
        self.dense2 = nn.Linear(hidden_size, hidden_size)
        self.norm2 = nn.LayerNorm(hidden_size)
        self.dropout2 = nn.Dropout(dropout_prob)
        # Output Layer
        self.out_proj = nn.Linear(hidden_size, num_labels)

    def forward(self, features):
        # Initial Projection Step
        x = self.initial_projection(features)
        x = self.initial_norm(x)
        x = self.initial_activation(x)
        x = self.initial_dropout(x)
        # x should now be of shape (batch_size, hidden_size)
        # Layer 1 + Skip
        identity1 = x  # Skip connection starts after initial projection
        x_res = self.norm1(x)
        x_res = self.dense1(x_res)
        x_res = self.activation(x_res)
        x_res = self.dropout1(x_res)
        x = identity1 + x_res  # skip connection
        # Layer 2 + Skip
        identity2 = x
        x_res = self.norm2(x)
        x_res = self.dense2(x_res)
        x_res = self.activation(x_res)
        x_res = self.dropout2(x_res)
        x = identity2 + x_res  # skip connection
        # Output Layer
        logits = self.out_proj(x)
        return logits
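
# Illustrative usage note for ConcatClassifierHead (not part of the original
# module): input_size is the concatenated CLS + mean-pooled width, typically
# 2 * hidden_size; the concrete sizes below are assumed example values:
#
#     head = ConcatClassifierHead(input_size=1536, hidden_size=768,
#                                 num_labels=3, dropout_prob=0.1)
#     logits = head(torch.randn(8, 1536))  # -> shape (8, 3)
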

# ExpansionClassifierHead currently not used
class ExpansionClassifierHead(nn.Module):
    """
    A classifier head using FFN-style expansion (input -> 4*hidden -> hidden -> labels).
    Takes concatenated CLS + Mean Pooled features as input.
    """

    def __init__(self, input_size, hidden_size, num_labels, dropout_prob):
        super().__init__()
        intermediate_size = hidden_size * 4  # FFN expansion factor
        # Layer 1 (Expansion)
        self.norm1 = nn.LayerNorm(input_size)
        self.dense1 = nn.Linear(input_size, intermediate_size)
        self.activation = nn.GELU()
        self.dropout1 = nn.Dropout(dropout_prob)
        # Layer 2 (Projection back down)
        self.norm2 = nn.LayerNorm(intermediate_size)
        self.dense2 = nn.Linear(intermediate_size, hidden_size)
        # Activation and Dropout applied after projection
        self.dropout2 = nn.Dropout(dropout_prob)
        # Output Layer
        self.out_proj = nn.Linear(hidden_size, num_labels)

    def forward(self, features):
        # Layer 1
        x = self.norm1(features)
        x = self.dense1(x)
        x = self.activation(x)
        x = self.dropout1(x)
        # Layer 2
        x = self.norm2(x)
        x = self.dense2(x)
        x = self.activation(x)
        x = self.dropout2(x)
        # Output Layer
        logits = self.out_proj(x)
        return logits
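
# Minimal smoke test added for illustration; it is a sketch, and the batch size,
# hidden size, and label count are assumed example values, not values taken from
# this repo. Run the file directly to check that each head produces logits of the
# expected shape.
if __name__ == "__main__":
    batch_size, hidden_size, num_labels = 8, 768, 3

    # Dummy pooled features: a CLS embedding and a mean-pooled embedding per example
    cls_emb = torch.randn(batch_size, hidden_size)
    mean_emb = torch.randn(batch_size, hidden_size)
    concat = torch.cat([cls_emb, mean_emb], dim=-1)  # (batch, 2 * hidden)

    # ClassifierHead consumes a single pooled vector per example
    head = ClassifierHead(hidden_size, num_labels, dropout_prob=0.1)
    print(head(cls_emb).shape)  # torch.Size([8, 3])

    # ConcatClassifierHead and ExpansionClassifierHead consume the concatenated vector
    concat_head = ConcatClassifierHead(2 * hidden_size, hidden_size, num_labels, dropout_prob=0.1)
    print(concat_head(concat).shape)  # torch.Size([8, 3])

    expansion_head = ExpansionClassifierHead(2 * hidden_size, hidden_size, num_labels, dropout_prob=0.1)
    print(expansion_head(concat).shape)  # torch.Size([8, 3])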