import torch
import torch.nn.functional as F
from torch import nn
from transformers import ModernBertModel, ModernBertPreTrainedModel
from transformers.modeling_outputs import SequenceClassifierOutput

from classifiers import ClassifierHead, ConcatClassifierHead
from train_utils import SentimentFocalLoss, SentimentWeightedLoss

# Pooling strategies that need the per-layer hidden states of the encoder.
_WEIGHTED_STRATEGIES = ("weighted_layer", "cls_weighted_concat")
# Pooling strategies that concatenate two feature vectors (2 * hidden_size).
_CONCAT_STRATEGIES = ("cls_mean_concat", "cls_weighted_concat")


class ModernBertForSentiment(ModernBertPreTrainedModel):
    """ModernBERT encoder with a configurable pooling strategy and classifier head.

    The following optional attributes are read from ``config``:

    * ``pooling_strategy`` -- one of ``'cls'``, ``'mean'``,
      ``'cls_mean_concat'``, ``'weighted_layer'`` or
      ``'cls_weighted_concat'`` (default ``'mean'``).
    * ``num_weighted_layers`` -- how many of the last hidden states are mixed
      by the weighted-layer strategies (default ``4``).
    * ``loss_function`` -- mapping with a ``'name'``
      (``'SentimentWeightedLoss'`` or ``'SentimentFocalLoss'``) and optional
      ``'params'`` forwarded to ``SentimentFocalLoss``.
    """

    def __init__(self, config):
        """Build the encoder, pooling parameters, classifier head and loss.

        Raises:
            ValueError: if a weighted-layer strategy is selected while
                ``config.output_hidden_states`` is falsy, if
                ``num_weighted_layers`` is smaller than 1 for such a strategy,
                or if the configured loss name is not supported.
        """
        super().__init__(config)
        self.num_labels = config.num_labels
        self.bert = ModernBertModel(config)

        self.pooling_strategy = getattr(config, "pooling_strategy", "mean")
        self.num_weighted_layers = getattr(config, "num_weighted_layers", 4)

        if self.pooling_strategy in _WEIGHTED_STRATEGIES:
            if not config.output_hidden_states:
                # The training script is expected to enable this flag.
                raise ValueError(
                    "output_hidden_states must be True in BertConfig for weighted_layer pooling."
                )
            if self.num_weighted_layers < 1:
                # A non-positive count would create an empty weight vector and
                # fail later with an opaque broadcasting error.
                raise ValueError(
                    f"num_weighted_layers must be >= 1, got {self.num_weighted_layers}"
                )
            # One learnable scalar per mixed layer, initialised uniformly.
            self.layer_weights = nn.Parameter(
                torch.ones(self.num_weighted_layers) / self.num_weighted_layers
            )

        # Dropout applied to the pooled features. Fall back gracefully when
        # the config does not define the BERT-style `hidden_dropout_prob`.
        classifier_dropout_prob = getattr(config, "classifier_dropout", None)
        if classifier_dropout_prob is None:
            classifier_dropout_prob = getattr(config, "hidden_dropout_prob", 0.0)
        self.features_dropout = nn.Dropout(classifier_dropout_prob)

        # Concatenating strategies feed 2 * hidden_size features to the head.
        if self.pooling_strategy in _CONCAT_STRATEGIES:
            self.classifier = ConcatClassifierHead(
                input_size=config.hidden_size * 2,
                hidden_size=config.hidden_size,
                num_labels=config.num_labels,
                dropout_prob=classifier_dropout_prob,
            )
        else:
            self.classifier = ClassifierHead(
                hidden_size=config.hidden_size,
                num_labels=config.num_labels,
                dropout_prob=classifier_dropout_prob,
            )

        # Loss selection; a missing or empty entry means the default loss.
        loss_config = getattr(config, "loss_function", None) or {
            "name": "SentimentWeightedLoss",
            "params": {},
        }
        loss_name = loss_config.get("name", "SentimentWeightedLoss")
        loss_params = loss_config.get("params") or {}

        if loss_name == "SentimentWeightedLoss":
            # Constructed without arguments; `params` is ignored for this loss.
            self.loss_fct = SentimentWeightedLoss()
        elif loss_name == "SentimentFocalLoss":
            self.loss_fct = SentimentFocalLoss(**loss_params)
        else:
            raise ValueError(f"Unsupported loss function: {loss_name}")

        self.post_init()

    def _mean_pool(self, last_hidden_state, attention_mask):
        """Average token vectors over dimension 1, ignoring masked positions.

        ``last_hidden_state`` is indexed as a rank-3 tensor; when
        ``attention_mask`` is ``None`` every position is treated as valid.
        """
        if attention_mask is None:
            # Slice away the feature axis to obtain an all-ones mask with the
            # shape of the first two dimensions.
            attention_mask = torch.ones_like(last_hidden_state[:, :, 0])
        mask = attention_mask.unsqueeze(-1).expand(last_hidden_state.size()).float()
        summed = torch.sum(last_hidden_state * mask, 1)
        # Clamp so fully masked rows do not divide by zero.
        counts = torch.clamp(mask.sum(1), min=1e-9)
        return summed / counts

    def _weighted_layer_pool(self, all_hidden_states):
        """Mix the last ``num_weighted_layers`` hidden states and take position 0.

        The layers are combined with softmax-normalised ``layer_weights`` and
        the vector at index 0 of dimension 1 of the mixture is returned.

        Raises:
            ValueError: if fewer than ``num_weighted_layers`` hidden states
                are available.
        """
        selected = all_hidden_states[-self.num_weighted_layers:]
        if len(selected) != self.num_weighted_layers:
            raise ValueError(
                f"num_weighted_layers={self.num_weighted_layers} exceeds the "
                f"{len(all_hidden_states)} hidden states returned by the encoder."
            )
        stacked = torch.stack(selected, dim=0)
        weights = F.softmax(self.layer_weights, dim=-1)
        # Reshape to (num_layers, 1, 1, 1) so the weights broadcast per layer.
        mixed = torch.sum(stacked * weights.view(-1, 1, 1, 1), dim=0)
        return mixed[:, 0]

    def _pool_features(self, last_hidden_state, hidden_states, attention_mask):
        """Apply the configured pooling strategy and return the pooled features."""
        strategy = self.pooling_strategy

        if strategy == "cls":
            return last_hidden_state[:, 0]
        if strategy == "mean":
            return self._mean_pool(last_hidden_state, attention_mask)
        if strategy == "cls_mean_concat":
            cls_output = last_hidden_state[:, 0]
            mean_output = self._mean_pool(last_hidden_state, attention_mask)
            return torch.cat((cls_output, mean_output), dim=1)
        if strategy in _WEIGHTED_STRATEGIES:
            if not self.config.output_hidden_states or hidden_states is None:
                raise ValueError(
                    "Weighted layer pooling requires output_hidden_states=True and hidden_states in BERT output."
                )
            weighted_output = self._weighted_layer_pool(hidden_states)
            if strategy == "weighted_layer":
                return weighted_output
            return torch.cat((last_hidden_state[:, 0], weighted_output), dim=1)

        raise ValueError(f"Unknown pooling_strategy: {self.pooling_strategy}")

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        labels=None,
        lengths=None,
        return_dict=None,
        **kwargs
    ):
        """Encode the inputs, pool, classify and optionally compute the loss.

        Args:
            input_ids: passed positionally to the encoder.
            attention_mask: passed to the encoder and used by mean pooling.
            labels: when given, the loss is computed from the logits.
            lengths: forwarded to the loss function; required with ``labels``.
            return_dict: when falsy a tuple is returned, otherwise a
                ``SequenceClassifierOutput``; defaults to
                ``config.use_return_dict``.
            **kwargs: accepted for call compatibility and ignored.

        Returns:
            ``SequenceClassifierOutput``, or a tuple
            ``([loss,] logits, [hidden_states,] [attentions])`` in which
            entries that are ``None`` are omitted.

        Raises:
            ValueError: if ``labels`` is given without ``lengths``, or the
                pooling strategy is unknown or lacks hidden states.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Always ask the encoder for a structured output: the pooling code
        # reads `.hidden_states`, which does not exist on a plain tuple. The
        # caller-facing format is produced at the end of this method.
        bert_outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            return_dict=True,
            output_hidden_states=self.config.output_hidden_states,
        )
        last_hidden_state = bert_outputs.last_hidden_state
        hidden_states = bert_outputs.hidden_states
        attentions = bert_outputs.attentions

        pooled_features = self._pool_features(last_hidden_state, hidden_states, attention_mask)
        pooled_features = self.features_dropout(pooled_features)
        logits = self.classifier(pooled_features)

        loss = None
        if labels is not None:
            if lengths is None:
                raise ValueError(
                    "lengths must be provided when labels are specified for loss calculation."
                )
            loss = self.loss_fct(logits.squeeze(-1), labels, lengths)

        if not return_dict:
            # Same layout as the encoder's own tuple output after its first
            # element: only the extras that are actually present.
            extras = tuple(v for v in (hidden_states, attentions) if v is not None)
            output = (logits,) + extras
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=hidden_states,
            attentions=attentions,
        )