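"""ModernBERT-based sentiment model: a ModernBERT encoder with a configurable pooling
strategy ('cls', 'mean', 'cls_mean_concat', 'weighted_layer', 'cls_weighted_concat')
and a config-selected loss function (SentimentWeightedLoss or SentimentFocalLoss)."""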
import torch
import torch.nn.functional as F
from torch import nn
from transformers import ModernBertModel, ModernBertPreTrainedModel
from transformers.modeling_outputs import SequenceClassifierOutput

from classifiers import ClassifierHead, ConcatClassifierHead
from train_utils import SentimentWeightedLoss, SentimentFocalLoss


class ModernBertForSentiment(ModernBertPreTrainedModel):
    """ModernBERT encoder with a dynamically configurable classification head and pooling strategy."""

    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.bert = ModernBertModel(config)  # Base ModernBERT encoder; config may have output_hidden_states=True

        # Store pooling strategy from config.
        self.pooling_strategy = getattr(config, 'pooling_strategy', 'mean')
        self.num_weighted_layers = getattr(config, 'num_weighted_layers', 4)

        if self.pooling_strategy in ['weighted_layer', 'cls_weighted_concat'] and not config.output_hidden_states:
            # This check is more of an assertion; train.py should set output_hidden_states=True.
            raise ValueError(
                "output_hidden_states must be True in the model config for weighted_layer pooling."
            )

        # Initialize weights for weighted layer pooling.
        if self.pooling_strategy in ['weighted_layer', 'cls_weighted_concat']:
            # num_weighted_layers specifies how many *top* encoder layers to use.
            # If num_weighted_layers is e.g. 4, we use the last 4 layers.
            self.layer_weights = nn.Parameter(torch.ones(self.num_weighted_layers) / self.num_weighted_layers)

        # Determine classifier input size and choose the head.
        classifier_input_size = config.hidden_size
        if self.pooling_strategy in ['cls_mean_concat', 'cls_weighted_concat']:
            classifier_input_size = config.hidden_size * 2

        # Dropout for features fed into the classifier head.
        classifier_dropout_prob = (
            config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
        )
        self.features_dropout = nn.Dropout(classifier_dropout_prob)

        # Select the appropriate classifier head based on the input feature dimension.
        if classifier_input_size == config.hidden_size:
            self.classifier = ClassifierHead(
                hidden_size=config.hidden_size,  # input size for ClassifierHead is just hidden_size
                num_labels=config.num_labels,
                dropout_prob=classifier_dropout_prob
            )
        elif classifier_input_size == config.hidden_size * 2:
            self.classifier = ConcatClassifierHead(
                input_size=config.hidden_size * 2,
                hidden_size=config.hidden_size,  # internal hidden size of the head
                num_labels=config.num_labels,
                dropout_prob=classifier_dropout_prob
            )
        else:
            # This case should not be reached with the current strategies.
            raise ValueError(f"Unexpected classifier_input_size: {classifier_input_size}")

        # Initialize the loss function based on config.
        loss_config = getattr(config, 'loss_function', {'name': 'SentimentWeightedLoss', 'params': {}})
        loss_name = loss_config.get('name', 'SentimentWeightedLoss')
        loss_params = loss_config.get('params', {})
        if loss_name == "SentimentWeightedLoss":
            self.loss_fct = SentimentWeightedLoss()  # SentimentWeightedLoss takes no arguments
        elif loss_name == "SentimentFocalLoss":
            # Pass only the params SentimentFocalLoss expects: 'gamma_focal' and 'label_smoothing_epsilon'.
            self.loss_fct = SentimentFocalLoss(**loss_params)
        else:
            raise ValueError(f"Unsupported loss function: {loss_name}")
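        # The loss selection above expects config.loss_function to look like
        #   {'name': 'SentimentFocalLoss', 'params': {'gamma_focal': 2.0, 'label_smoothing_epsilon': 0.05}}
        # (the param values shown here are illustrative, not defaults of the loss classes).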

        self.post_init()  # Initialize weights and apply final processing

    def _mean_pool(self, last_hidden_state, attention_mask):
        # last_hidden_state: (batch_size, seq_len, hidden_size); attention_mask: (batch_size, seq_len)
        if attention_mask is None:
            # Fall back to attending over every position: a (batch_size, seq_len) mask of ones.
            attention_mask = torch.ones_like(last_hidden_state[:, :, 0])
        input_mask_expanded = attention_mask.unsqueeze(-1).expand(last_hidden_state.size()).float()
        sum_embeddings = torch.sum(last_hidden_state * input_mask_expanded, 1)
        sum_mask = torch.clamp(input_mask_expanded.sum(1), min=1e-9)
        return sum_embeddings / sum_mask

    def _weighted_layer_pool(self, all_hidden_states):
        # all_hidden_states contains the embedding output plus the output of every encoder layer.
        # We want the outputs of the last num_weighted_layers.
        # Example: 12 layers -> all_hidden_states has 13 items (embeddings + 12 layers);
        # num_weighted_layers = 4 -> use layers 9, 10, 11, 12 (indices -4, -3, -2, -1).
        layers_to_weigh = torch.stack(all_hidden_states[-self.num_weighted_layers:], dim=0)
        # layers_to_weigh shape: (num_weighted_layers, batch_size, sequence_length, hidden_size)

        # Normalize the learned layer weights so they sum to 1.
        normalized_weights = F.softmax(self.layer_weights, dim=-1)

        # Weighted sum across layers; reshape weights for broadcasting: (num_weighted_layers, 1, 1, 1).
        weighted_hidden_states = layers_to_weigh * normalized_weights.view(-1, 1, 1, 1)
        weighted_sum_hidden_states = torch.sum(weighted_hidden_states, dim=0)
        # weighted_sum_hidden_states shape: (batch_size, sequence_length, hidden_size)

        # Pool the result by taking the [CLS] token of the weighted sum.
        return weighted_sum_hidden_states[:, 0]
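
    # Note: layer_weights is initialized to 1 / num_weighted_layers per layer (e.g. 0.25 each
    # for 4 layers). Softmax over equal values is still uniform, so training starts from an
    # even blend of the top layers and then learns to favour some layers over others.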

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        labels=None,
        lengths=None,
        return_dict=None,
        **kwargs
    ):
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        bert_outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            return_dict=return_dict,
            output_hidden_states=self.config.output_hidden_states  # Controlled by train.py
        )
        last_hidden_state = bert_outputs[0]  # Or bert_outputs.last_hidden_state

        pooled_features = None
        if self.pooling_strategy == 'cls':
            pooled_features = last_hidden_state[:, 0]  # [CLS] token
        elif self.pooling_strategy == 'mean':
            pooled_features = self._mean_pool(last_hidden_state, attention_mask)
        elif self.pooling_strategy == 'cls_mean_concat':
            cls_output = last_hidden_state[:, 0]
            mean_output = self._mean_pool(last_hidden_state, attention_mask)
            pooled_features = torch.cat((cls_output, mean_output), dim=1)
        elif self.pooling_strategy == 'weighted_layer':
            if not self.config.output_hidden_states or bert_outputs.hidden_states is None:
                raise ValueError("Weighted layer pooling requires output_hidden_states=True and hidden_states in BERT output.")
            all_hidden_states = bert_outputs.hidden_states
            pooled_features = self._weighted_layer_pool(all_hidden_states)
        elif self.pooling_strategy == 'cls_weighted_concat':
            if not self.config.output_hidden_states or bert_outputs.hidden_states is None:
                raise ValueError("Weighted layer pooling requires output_hidden_states=True and hidden_states in BERT output.")
            cls_output = last_hidden_state[:, 0]
            all_hidden_states = bert_outputs.hidden_states
            weighted_output = self._weighted_layer_pool(all_hidden_states)
            pooled_features = torch.cat((cls_output, weighted_output), dim=1)
        else:
            raise ValueError(f"Unknown pooling_strategy: {self.pooling_strategy}")
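
        # pooled_features is (batch_size, hidden_size) for 'cls', 'mean' and 'weighted_layer',
        # and (batch_size, 2 * hidden_size) for the concat strategies.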
        pooled_features = self.features_dropout(pooled_features)
        logits = self.classifier(pooled_features)

        loss = None
        if labels is not None:
            if lengths is None:
                raise ValueError("lengths must be provided when labels are specified for loss calculation.")
            loss = self.loss_fct(logits.squeeze(-1), labels, lengths)

        if not return_dict:
            # Handle the encoder outputs whether they come back as a tuple or as a ModelOutput.
            bert_model_outputs = bert_outputs[1:] if isinstance(bert_outputs, tuple) else (bert_outputs.hidden_states, bert_outputs.attentions)
            output = (logits,) + bert_model_outputs
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=bert_outputs.hidden_states,
            attentions=bert_outputs.attentions,
        )
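

# A minimal usage sketch. It assumes the public ModernBERT checkpoint
# "answerdotai/ModernBERT-base" and that the extra config attributes below
# (pooling_strategy, num_weighted_layers, loss_function) are set the way
# train.py would set them; the concrete values are illustrative only.
if __name__ == "__main__":
    from transformers import AutoConfig, AutoTokenizer

    model_name = "answerdotai/ModernBERT-base"  # assumed checkpoint name
    config = AutoConfig.from_pretrained(model_name)
    config.num_labels = 1
    config.pooling_strategy = "cls_weighted_concat"
    config.num_weighted_layers = 4
    config.output_hidden_states = True  # required by the weighted-layer strategies
    config.loss_function = {"name": "SentimentWeightedLoss", "params": {}}

    model = ModernBertForSentiment.from_pretrained(model_name, config=config)
    model.eval()
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    batch = tokenizer(["A wonderful film.", "A dull, lifeless plot."],
                      padding=True, return_tensors="pt")
    with torch.no_grad():
        logits = model(**batch).logits  # (batch_size, num_labels)
    print(logits.shape)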