from transformers import ModernBertModel, ModernBertPreTrainedModel
from transformers.modeling_outputs import SequenceClassifierOutput
from torch import nn
import torch
from train_utils import SentimentWeightedLoss, SentimentFocalLoss
import torch.nn.functional as F

from classifiers import ClassifierHead, ConcatClassifierHead


class ModernBertForSentiment(ModernBertPreTrainedModel):
    """ModernBERT encoder with a dynamically configurable classification head and pooling strategy."""

    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.bert = ModernBertModel(config) # Base BERT model, config may have output_hidden_states=True

        # Store pooling strategy from config
        self.pooling_strategy = getattr(config, 'pooling_strategy', 'mean')
        self.num_weighted_layers = getattr(config, 'num_weighted_layers', 4)

        if self.pooling_strategy in ['weighted_layer', 'cls_weighted_concat'] and not config.output_hidden_states:
            # train.py is expected to set output_hidden_states=True for these
            # strategies; fail fast here if it did not.
            raise ValueError(
                "output_hidden_states must be True in the config for weighted_layer or cls_weighted_concat pooling."
            )

        # Initialize weights for weighted layer pooling
        if self.pooling_strategy in ['weighted_layer', 'cls_weighted_concat']:
            # num_weighted_layers specifies how many *top* layers of BERT to use.
            # If num_weighted_layers is e.g. 4, we use the last 4 layers.
            self.layer_weights = nn.Parameter(torch.ones(self.num_weighted_layers) / self.num_weighted_layers)

        # Determine classifier input size and choose head
        classifier_input_size = config.hidden_size
        if self.pooling_strategy in ['cls_mean_concat', 'cls_weighted_concat']:
            classifier_input_size = config.hidden_size * 2
        
        # Dropout for features fed into the classifier head. ModernBertConfig
        # defines classifier_dropout; hidden_dropout_prob is a BERT-era field
        # it may not have, so fall back defensively.
        classifier_dropout_prob = (
            config.classifier_dropout
            if config.classifier_dropout is not None
            else getattr(config, 'hidden_dropout_prob', 0.1)
        )
        self.features_dropout = nn.Dropout(classifier_dropout_prob)

        # Select the appropriate classifier head based on input feature dimension
        if classifier_input_size == config.hidden_size:
            self.classifier = ClassifierHead(
                hidden_size=config.hidden_size, # input_size for ClassifierHead is just hidden_size
                num_labels=config.num_labels,
                dropout_prob=classifier_dropout_prob
            )
        elif classifier_input_size == config.hidden_size * 2:
            self.classifier = ConcatClassifierHead(
                input_size=config.hidden_size * 2,
                hidden_size=config.hidden_size, # Internal hidden size of the head
                num_labels=config.num_labels,
                dropout_prob=classifier_dropout_prob
            )
        else:
            # This case should ideally not be reached with current strategies
            raise ValueError(f"Unexpected classifier_input_size: {classifier_input_size}")

        # Initialize loss function based on config
        loss_config = getattr(config, 'loss_function', {'name': 'SentimentWeightedLoss', 'params': {}})
        loss_name = loss_config.get('name', 'SentimentWeightedLoss')
        loss_params = loss_config.get('params', {})

        if loss_name == "SentimentWeightedLoss":
            self.loss_fct = SentimentWeightedLoss() # SentimentWeightedLoss takes no arguments
        elif loss_name == "SentimentFocalLoss":
            # SentimentFocalLoss expects only 'gamma_focal' and
            # 'label_smoothing_epsilon' in loss_params.
            self.loss_fct = SentimentFocalLoss(**loss_params)
        else:
            raise ValueError(f"Unsupported loss function: {loss_name}")

        self.post_init() # Initialize weights and apply final processing
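
    # Illustrative config fields consumed by __init__ above (hypothetical
    # values; the actual schema is whatever train.py writes into the config):
    #   config.pooling_strategy     = "cls_weighted_concat"
    #   config.num_weighted_layers  = 4
    #   config.output_hidden_states = True
    #   config.loss_function = {
    #       "name": "SentimentFocalLoss",
    #       "params": {"gamma_focal": 2.0, "label_smoothing_epsilon": 0.05},
    #   }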

    def _mean_pool(self, last_hidden_state, attention_mask):
        if attention_mask is None:
            # No mask provided: treat every position as a real token.
            attention_mask = torch.ones_like(last_hidden_state[:, :, 0])  # (batch, seq_len) of ones
        input_mask_expanded = attention_mask.unsqueeze(-1).expand(last_hidden_state.size()).float()
        sum_embeddings = torch.sum(last_hidden_state * input_mask_expanded, 1)
        sum_mask = torch.clamp(input_mask_expanded.sum(1), min=1e-9)
        return sum_embeddings / sum_mask

    def _weighted_layer_pool(self, all_hidden_states):
        # all_hidden_states includes embeddings + output of each layer.
        # We want the outputs of the last num_weighted_layers.
        # Example: 12 layers -> all_hidden_states have 13 items (embeddings + 12 layers)
        # num_weighted_layers = 4 -> use layers 9, 10, 11, 12 (indices -4, -3, -2, -1)
        layers_to_weigh = torch.stack(all_hidden_states[-self.num_weighted_layers:], dim=0)
        # layers_to_weigh shape: (num_weighted_layers, batch_size, sequence_length, hidden_size)
        
        # Normalize weights to sum to 1 (softmax or simple division)
        normalized_weights = F.softmax(self.layer_weights, dim=-1)
        
        # Weighted sum across layers
        # Reshape weights for broadcasting: (num_weighted_layers, 1, 1, 1)
        weighted_hidden_states = layers_to_weigh * normalized_weights.view(-1, 1, 1, 1)
        weighted_sum_hidden_states = torch.sum(weighted_hidden_states, dim=0)
        # weighted_sum_hidden_states shape: (batch_size, sequence_length, hidden_size)
        
        # Pool the result (e.g., take [CLS] token of this weighted sum)
        return weighted_sum_hidden_states[:, 0] # Return CLS token of the weighted sum
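
    # Shape sketch for the method above (B=batch, S=seq_len, H=hidden_size,
    # K=num_weighted_layers; the weight values are hypothetical):
    #   layer_weights (K,) --softmax--> e.g. [0.25, 0.25, 0.25, 0.25]
    #   layers_to_weigh (K, B, S, H) * weights.view(K, 1, 1, 1)
    #     --sum over dim 0--> (B, S, H) --[:, 0]--> (B, H)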

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        labels=None,
        lengths=None,
        return_dict=None,
        **kwargs
    ):
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Always request a dict from the encoder so attribute access below
        # (last_hidden_state, hidden_states) is safe; the final return value
        # is still formatted according to `return_dict`.
        bert_outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            return_dict=True,
            output_hidden_states=self.config.output_hidden_states  # Controlled by train.py
        )

        last_hidden_state = bert_outputs.last_hidden_state
        pooled_features = None

        if self.pooling_strategy == 'cls':
            pooled_features = last_hidden_state[:, 0] # CLS token
        elif self.pooling_strategy == 'mean':
            pooled_features = self._mean_pool(last_hidden_state, attention_mask)
        elif self.pooling_strategy == 'cls_mean_concat':
            cls_output = last_hidden_state[:, 0]
            mean_output = self._mean_pool(last_hidden_state, attention_mask)
            pooled_features = torch.cat((cls_output, mean_output), dim=1)
        elif self.pooling_strategy == 'weighted_layer':
            if bert_outputs.hidden_states is None:
                raise ValueError("weighted_layer pooling requires output_hidden_states=True in the config.")
            pooled_features = self._weighted_layer_pool(bert_outputs.hidden_states)
        elif self.pooling_strategy == 'cls_weighted_concat':
            if bert_outputs.hidden_states is None:
                raise ValueError("cls_weighted_concat pooling requires output_hidden_states=True in the config.")
            cls_output = last_hidden_state[:, 0]
            weighted_output = self._weighted_layer_pool(bert_outputs.hidden_states)
            pooled_features = torch.cat((cls_output, weighted_output), dim=1)
        else:
            raise ValueError(f"Unknown pooling_strategy: {self.pooling_strategy}")

        pooled_features = self.features_dropout(pooled_features)
        logits = self.classifier(pooled_features)

        loss = None
        if labels is not None:
            if lengths is None:
                raise ValueError("lengths must be provided when labels are specified for loss calculation.")
            loss = self.loss_fct(logits.squeeze(-1), labels, lengths)

        if not return_dict:
            # Mirror the standard HF tuple layout: (loss?, logits, hidden_states?, attentions?)
            extras = tuple(
                v for v in (bert_outputs.hidden_states, bert_outputs.attentions) if v is not None
            )
            output = (logits,) + extras
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=bert_outputs.hidden_states,
            attentions=bert_outputs.attentions,
        )
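

# Minimal usage sketch, kept out of library code paths. Assumptions: the
# "answerdotai/ModernBERT-base" checkpoint is reachable, ClassifierHead maps
# (batch, hidden) -> (batch, num_labels), and the extra config fields match
# what train.py normally injects. This is a shape check, not a recipe.
if __name__ == "__main__":
    from transformers import AutoConfig, AutoTokenizer

    config = AutoConfig.from_pretrained("answerdotai/ModernBERT-base", num_labels=1)
    config.pooling_strategy = "mean"       # 'mean' needs no hidden states
    config.output_hidden_states = False
    config.reference_compile = False       # skip torch.compile for a quick check
    config.loss_function = {"name": "SentimentWeightedLoss", "params": {}}

    model = ModernBertForSentiment(config)  # randomly initialized encoder weights
    tokenizer = AutoTokenizer.from_pretrained("answerdotai/ModernBERT-base")

    batch = tokenizer(["great movie", "terrible plot"], padding=True, return_tensors="pt")
    with torch.no_grad():
        out = model(input_ids=batch["input_ids"], attention_mask=batch["attention_mask"])
    print(out.logits.shape)  # expected: torch.Size([2, 1])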