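"""ModernBERT sentiment model.

Wraps a ModernBertModel encoder with a configurable pooling strategy
('cls', 'mean', 'cls_mean_concat', 'weighted_layer', 'cls_weighted_concat'),
a matching classification head, and a configurable training loss.
"""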
from transformers import ModernBertModel, ModernBertPreTrainedModel
from transformers.modeling_outputs import SequenceClassifierOutput
from torch import nn
import torch
from train_utils import SentimentWeightedLoss, SentimentFocalLoss
import torch.nn.functional as F
from classifiers import ClassifierHead, ConcatClassifierHead


class ModernBertForSentiment(ModernBertPreTrainedModel):
    """ModernBERT encoder with a dynamically configurable classification head and pooling strategy."""

    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.bert = ModernBertModel(config)  # Base ModernBERT encoder; config may have output_hidden_states=True

        # Store pooling strategy from config
        self.pooling_strategy = getattr(config, 'pooling_strategy', 'mean')
        self.num_weighted_layers = getattr(config, 'num_weighted_layers', 4)

        if self.pooling_strategy in ['weighted_layer', 'cls_weighted_concat'] and not config.output_hidden_states:
            # This check is effectively an assertion; train.py should already set output_hidden_states=True.
            raise ValueError(
                "output_hidden_states must be True in the model config for weighted-layer pooling strategies."
            )

        # Initialize weights for weighted layer pooling
        if self.pooling_strategy in ['weighted_layer', 'cls_weighted_concat']:
            # num_weighted_layers specifies how many *top* encoder layers to use,
            # e.g. num_weighted_layers=4 means the last 4 layers are combined.
            self.layer_weights = nn.Parameter(torch.ones(self.num_weighted_layers) / self.num_weighted_layers)

        # Determine classifier input size and choose head
        classifier_input_size = config.hidden_size
        if self.pooling_strategy in ['cls_mean_concat', 'cls_weighted_concat']:
            classifier_input_size = config.hidden_size * 2

        # Dropout for features fed into the classifier head
        classifier_dropout_prob = (
            config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
        )
        self.features_dropout = nn.Dropout(classifier_dropout_prob)

        # Select the appropriate classifier head based on input feature dimension
        if classifier_input_size == config.hidden_size:
            self.classifier = ClassifierHead(
                hidden_size=config.hidden_size,  # input size for ClassifierHead is just hidden_size
                num_labels=config.num_labels,
                dropout_prob=classifier_dropout_prob,
            )
        elif classifier_input_size == config.hidden_size * 2:
            self.classifier = ConcatClassifierHead(
                input_size=config.hidden_size * 2,
                hidden_size=config.hidden_size,  # internal hidden size of the head
                num_labels=config.num_labels,
                dropout_prob=classifier_dropout_prob,
            )
        else:
            # This case should not be reachable with the current strategies.
            raise ValueError(f"Unexpected classifier_input_size: {classifier_input_size}")

        # Initialize loss function based on config
        loss_config = getattr(config, 'loss_function', {'name': 'SentimentWeightedLoss', 'params': {}})
        loss_name = loss_config.get('name', 'SentimentWeightedLoss')
        loss_params = loss_config.get('params', {})
        if loss_name == "SentimentWeightedLoss":
            self.loss_fct = SentimentWeightedLoss()  # SentimentWeightedLoss takes no arguments
        elif loss_name == "SentimentFocalLoss":
            # SentimentFocalLoss expects only 'gamma_focal' and 'label_smoothing_epsilon' in loss_params.
            self.loss_fct = SentimentFocalLoss(**loss_params)
        else:
            raise ValueError(f"Unsupported loss function: {loss_name}")

        self.post_init()  # Initialize weights and apply final processing

    def _mean_pool(self, last_hidden_state, attention_mask):
        if attention_mask is None:
            # Fall back to an all-ones mask of shape (batch_size, sequence_length).
            attention_mask = torch.ones_like(last_hidden_state[:, :, 0])
        input_mask_expanded = attention_mask.unsqueeze(-1).expand(last_hidden_state.size()).float()
        sum_embeddings = torch.sum(last_hidden_state * input_mask_expanded, 1)
        sum_mask = torch.clamp(input_mask_expanded.sum(1), min=1e-9)
        return sum_embeddings / sum_mask

    def _weighted_layer_pool(self, all_hidden_states):
        # all_hidden_states contains the embedding output plus the output of each layer;
        # we weigh the outputs of the last num_weighted_layers.
        # Example: 12 layers -> 13 entries (embeddings + 12 layers);
        # num_weighted_layers = 4 -> use the last 4 layers (indices -4, -3, -2, -1).
        layers_to_weigh = torch.stack(all_hidden_states[-self.num_weighted_layers:], dim=0)
        # layers_to_weigh shape: (num_weighted_layers, batch_size, sequence_length, hidden_size)

        # Normalize the learned weights so they sum to 1 (softmax).
        normalized_weights = F.softmax(self.layer_weights, dim=-1)

        # Weighted sum across layers; reshape weights to (num_weighted_layers, 1, 1, 1) for broadcasting.
        weighted_hidden_states = layers_to_weigh * normalized_weights.view(-1, 1, 1, 1)
        weighted_sum_hidden_states = torch.sum(weighted_hidden_states, dim=0)
        # weighted_sum_hidden_states shape: (batch_size, sequence_length, hidden_size)

        # Pool the result by taking the [CLS] token of the weighted sum.
        return weighted_sum_hidden_states[:, 0]

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        labels=None,
        lengths=None,
        return_dict=None,
        **kwargs,
    ):
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        bert_outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            return_dict=return_dict,
            output_hidden_states=self.config.output_hidden_states,  # controlled by train.py
        )
        last_hidden_state = bert_outputs[0]  # or bert_outputs.last_hidden_state

        pooled_features = None
        if self.pooling_strategy == 'cls':
            pooled_features = last_hidden_state[:, 0]  # [CLS] token
        elif self.pooling_strategy == 'mean':
            pooled_features = self._mean_pool(last_hidden_state, attention_mask)
        elif self.pooling_strategy == 'cls_mean_concat':
            cls_output = last_hidden_state[:, 0]
            mean_output = self._mean_pool(last_hidden_state, attention_mask)
            pooled_features = torch.cat((cls_output, mean_output), dim=1)
        elif self.pooling_strategy == 'weighted_layer':
            if not self.config.output_hidden_states or bert_outputs.hidden_states is None:
                raise ValueError("Weighted-layer pooling requires output_hidden_states=True and hidden_states in the encoder output.")
            all_hidden_states = bert_outputs.hidden_states
            pooled_features = self._weighted_layer_pool(all_hidden_states)
        elif self.pooling_strategy == 'cls_weighted_concat':
            if not self.config.output_hidden_states or bert_outputs.hidden_states is None:
                raise ValueError("CLS + weighted-layer pooling requires output_hidden_states=True and hidden_states in the encoder output.")
            cls_output = last_hidden_state[:, 0]
            all_hidden_states = bert_outputs.hidden_states
            weighted_output = self._weighted_layer_pool(all_hidden_states)
            pooled_features = torch.cat((cls_output, weighted_output), dim=1)
        else:
            raise ValueError(f"Unknown pooling_strategy: {self.pooling_strategy}")

        pooled_features = self.features_dropout(pooled_features)
        logits = self.classifier(pooled_features)

        loss = None
        if labels is not None:
            if lengths is None:
                raise ValueError("lengths must be provided when labels are specified for loss calculation.")
            loss = self.loss_fct(logits.squeeze(-1), labels, lengths)

        if not return_dict:
            # When the encoder returns a tuple, keep everything after last_hidden_state;
            # otherwise expose hidden_states and attentions explicitly.
            bert_model_outputs = bert_outputs[1:] if isinstance(bert_outputs, tuple) else (bert_outputs.hidden_states, bert_outputs.attentions)
            output = (logits,) + bert_model_outputs
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=bert_outputs.hidden_states,
            attentions=bert_outputs.attentions,
        )
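

if __name__ == "__main__":
    # Minimal usage sketch, not part of the model itself. It assumes a ModernBERT
    # checkpoint name, that `classifiers` and `train_utils` are importable, and that
    # the extra config fields (pooling_strategy, num_weighted_layers, loss_function,
    # output_hidden_states) mirror what train.py is expected to set; adjust as needed.
    from transformers import AutoConfig, AutoTokenizer

    model_name = "answerdotai/ModernBERT-base"  # assumed base checkpoint
    config = AutoConfig.from_pretrained(model_name)
    config.num_labels = 1                        # single sentiment logit
    config.pooling_strategy = "cls_weighted_concat"
    config.num_weighted_layers = 4
    config.output_hidden_states = True           # required for weighted-layer pooling
    config.loss_function = {
        "name": "SentimentFocalLoss",
        "params": {"gamma_focal": 2.0, "label_smoothing_epsilon": 0.05},  # assumed hyperparameters
    }

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = ModernBertForSentiment.from_pretrained(model_name, config=config)

    batch = tokenizer(["I loved this film.", "Utterly terrible."], padding=True, return_tensors="pt")
    lengths = batch["attention_mask"].sum(dim=1)  # assumed: per-example token counts for the weighted loss
    labels = torch.tensor([1.0, 0.0])
    outputs = model(**batch, labels=labels, lengths=lengths)
    print(outputs.loss, outputs.logits.shape)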