# imdb-sentiment-demo / models.py
# Author: voxmenthe
# Commit 472f1d2: "add full app and model initial test"
from transformers import ModernBertModel, ModernBertPreTrainedModel
from transformers.modeling_outputs import SequenceClassifierOutput
from torch import nn
import torch
from train_utils import SentimentWeightedLoss, SentimentFocalLoss
import torch.nn.functional as F
from classifiers import ClassifierHead, ConcatClassifierHead
class ModernBertForSentiment(ModernBertPreTrainedModel):
    """ModernBERT encoder with a configurable pooling strategy and classification head.

    Extra config attributes read by this model (with their defaults):

    * ``pooling_strategy`` (``"mean"``): one of ``"cls"``, ``"mean"``,
      ``"cls_mean_concat"``, ``"weighted_layer"``, ``"cls_weighted_concat"``.
    * ``num_weighted_layers`` (``4``): how many of the top hidden-state entries
      the weighted-layer strategies combine.
    * ``loss_function`` (``{"name": "SentimentWeightedLoss", "params": {}}``):
      dict selecting the training loss by ``name`` with optional ``params``.

    The weighted-layer strategies require ``config.output_hidden_states=True``.
    """

    # Strategies that need per-layer hidden states and learnable layer weights.
    _WEIGHTED_STRATEGIES = ("weighted_layer", "cls_weighted_concat")
    # Strategies that concatenate two hidden_size vectors (2 * hidden_size input).
    _CONCAT_STRATEGIES = ("cls_mean_concat", "cls_weighted_concat")

    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        # Base encoder; config.output_hidden_states decides whether per-layer
        # hidden states are returned (needed by the weighted strategies).
        self.bert = ModernBertModel(config)

        self.pooling_strategy = getattr(config, "pooling_strategy", "mean")
        self.num_weighted_layers = getattr(config, "num_weighted_layers", 4)

        if self.pooling_strategy in self._WEIGHTED_STRATEGIES:
            if not config.output_hidden_states:
                # train.py is expected to set this; fail fast if it did not.
                raise ValueError(
                    "output_hidden_states must be True in the model config for "
                    f"'{self.pooling_strategy}' pooling."
                )
            if self.num_weighted_layers < 1:
                # A slice of [-0:] would silently select *all* layers while the
                # weight vector is empty, so reject non-positive values here.
                raise ValueError(
                    f"num_weighted_layers must be >= 1, got {self.num_weighted_layers}."
                )
            # One learnable weight per top layer, initialised uniformly; the
            # forward pass softmax-normalises them before mixing.
            self.layer_weights = nn.Parameter(
                torch.ones(self.num_weighted_layers) / self.num_weighted_layers
            )

        # Dropout probability for the pooled features and the head. Fall back
        # to hidden_dropout_prob only when classifier_dropout is unset;
        # ModernBertConfig may not define hidden_dropout_prob, hence getattr.
        classifier_dropout_prob = getattr(config, "classifier_dropout", None)
        if classifier_dropout_prob is None:
            classifier_dropout_prob = getattr(config, "hidden_dropout_prob", 0.0)
        self.features_dropout = nn.Dropout(classifier_dropout_prob)

        if self.pooling_strategy in self._CONCAT_STRATEGIES:
            # [CLS] vector concatenated with a second pooled vector.
            self.classifier = ConcatClassifierHead(
                input_size=config.hidden_size * 2,
                hidden_size=config.hidden_size,  # internal hidden size of the head
                num_labels=config.num_labels,
                dropout_prob=classifier_dropout_prob,
            )
        else:
            self.classifier = ClassifierHead(
                hidden_size=config.hidden_size,
                num_labels=config.num_labels,
                dropout_prob=classifier_dropout_prob,
            )

        # Loss selection; tolerate a missing/None loss_function or params entry.
        loss_config = getattr(config, "loss_function", None) or {
            "name": "SentimentWeightedLoss",
            "params": {},
        }
        loss_name = loss_config.get("name", "SentimentWeightedLoss")
        loss_params = loss_config.get("params") or {}
        if loss_name == "SentimentWeightedLoss":
            # SentimentWeightedLoss takes no constructor arguments; params are ignored.
            self.loss_fct = SentimentWeightedLoss()
        elif loss_name == "SentimentFocalLoss":
            # Expected params: 'gamma_focal' and 'label_smoothing_epsilon'.
            self.loss_fct = SentimentFocalLoss(**loss_params)
        else:
            raise ValueError(f"Unsupported loss function: {loss_name}")

        self.post_init()  # Initialize weights and apply final processing

    def _mean_pool(self, last_hidden_state, attention_mask):
        """Average token vectors over positions where ``attention_mask`` is non-zero.

        Assumes ``last_hidden_state`` is (batch, seq, hidden) and
        ``attention_mask`` is (batch, seq) with 1 for real tokens (TODO: confirm
        against the tokenizer). Without a mask every position is averaged.
        The average is computed in float32 and cast back to
        ``last_hidden_state.dtype`` so that half-precision models feed a
        matching dtype into the classifier head.
        """
        if attention_mask is None:
            # Every position counts: a (batch, seq) tensor of ones.
            attention_mask = last_hidden_state.new_ones(last_hidden_state.shape[:2])
        mask = attention_mask.unsqueeze(-1).float()  # (batch, seq, 1), broadcasts over hidden
        summed = torch.sum(last_hidden_state * mask, dim=1)
        # Clamp avoids division by zero for rows whose mask is all zeros.
        counts = torch.clamp(mask.sum(dim=1), min=1e-9)
        return (summed / counts).to(last_hidden_state.dtype)

    def _weighted_layer_pool(self, all_hidden_states):
        """Return the position-0 vector of a learned weighted sum of the top layers.

        ``all_hidden_states`` is the encoder's ``hidden_states`` sequence
        (embedding output followed by each layer's output). The last
        ``num_weighted_layers`` entries are stacked, mixed with
        ``softmax(layer_weights)``, and token 0 of the mixture is returned.

        Raises:
            ValueError: if fewer hidden states are available than
                ``num_weighted_layers``.
        """
        if len(all_hidden_states) < self.num_weighted_layers:
            raise ValueError(
                f"num_weighted_layers={self.num_weighted_layers} exceeds the "
                f"{len(all_hidden_states)} hidden states returned by the encoder."
            )
        # (num_weighted_layers, batch, seq, hidden)
        stacked = torch.stack(tuple(all_hidden_states[-self.num_weighted_layers:]), dim=0)
        # Softmax keeps the mixing weights positive and summing to 1.
        weights = F.softmax(self.layer_weights, dim=-1).view(-1, 1, 1, 1)
        mixed = torch.sum(stacked * weights, dim=0)  # (batch, seq, hidden)
        return mixed[:, 0]

    def _pool_features(self, last_hidden_state, attention_mask, hidden_states):
        """Dispatch to the configured pooling strategy and return pooled features.

        Raises:
            ValueError: for an unknown strategy, or when a weighted strategy is
                selected but the encoder returned no hidden states.
        """
        strategy = self.pooling_strategy
        if strategy == "cls":
            return last_hidden_state[:, 0]
        if strategy == "mean":
            return self._mean_pool(last_hidden_state, attention_mask)
        if strategy == "cls_mean_concat":
            cls_output = last_hidden_state[:, 0]
            mean_output = self._mean_pool(last_hidden_state, attention_mask)
            return torch.cat((cls_output, mean_output), dim=1)
        if strategy in self._WEIGHTED_STRATEGIES:
            if not self.config.output_hidden_states or hidden_states is None:
                raise ValueError("Weighted layer pooling requires output_hidden_states=True and hidden_states in BERT output.")
            weighted_output = self._weighted_layer_pool(hidden_states)
            if strategy == "weighted_layer":
                return weighted_output
            return torch.cat((last_hidden_state[:, 0], weighted_output), dim=1)
        raise ValueError(f"Unknown pooling_strategy: {strategy}")

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        labels=None,
        lengths=None,
        return_dict=None,
        **kwargs,
    ):
        """Encode, pool, classify and (optionally) compute the loss.

        Args:
            input_ids: token ids passed to the encoder.
            attention_mask: padding mask passed to the encoder and used by mean pooling.
            labels: targets; when given, ``lengths`` is required and a loss is computed.
            lengths: forwarded to the loss function together with logits and labels.
            return_dict: return a ``SequenceClassifierOutput`` (default from
                ``config.use_return_dict``) or a plain tuple.
            **kwargs: accepted and ignored.

        Returns:
            ``SequenceClassifierOutput`` when ``return_dict`` is true. Otherwise a
            tuple ``(loss?, logits, *encoder extras)``, where the extras are the
            encoder's non-None hidden states / attentions.

        Raises:
            ValueError: if ``labels`` is given without ``lengths``, or pooling fails.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Always ask the encoder for a ModelOutput so hidden states can be read
        # by attribute; the caller's return_dict only shapes *our* return value.
        # (With return_dict=False the encoder would return a plain tuple and
        # `.hidden_states` would raise AttributeError.)
        bert_outputs = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            return_dict=True,
            output_hidden_states=self.config.output_hidden_states,  # controlled by train.py
        )
        last_hidden_state = bert_outputs.last_hidden_state

        pooled_features = self._pool_features(
            last_hidden_state, attention_mask, bert_outputs.hidden_states
        )
        pooled_features = self.features_dropout(pooled_features)
        logits = self.classifier(pooled_features)

        loss = None
        if labels is not None:
            if lengths is None:
                raise ValueError("lengths must be provided when labels are specified for loss calculation.")
            loss = self.loss_fct(logits.squeeze(-1), labels, lengths)

        if not return_dict:
            # HF tuple convention: non-None encoder fields, minus last_hidden_state.
            output = (logits,) + bert_outputs.to_tuple()[1:]
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=bert_outputs.hidden_states,
            attentions=bert_outputs.attentions,
        )