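"""Custom loss functions for binary sentiment classification.

Provides a confidence- and length-weighted BCE loss (SentimentWeightedLoss) and a
focal / label-smoothed variant with the same weighting scheme (SentimentFocalLoss).
"""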
import math

import torch
from torch import nn


class SentimentWeightedLoss(nn.Module):
    """BCEWithLogits + dynamic weighting.

    We weight each sample by:
      • length_weight:  sqrt(num_tokens) / sqrt(max_tokens)
      • confidence_weight: |sigmoid(logits) - 0.5|  (higher confidence ⇒ larger weight)

    The two weights are combined multiplicatively then normalized.
    """

    def __init__(self):
        super().__init__()
        # Initialize BCE loss without reduction, since we're applying per-sample weights
        self.bce = nn.BCEWithLogitsLoss(reduction="none")
        self.min_len_weight_sqrt = 0.1  # Minimum length weight

    def forward(self, logits, targets, lengths):
        logits_flat = logits.view(-1)
        targets_flat = targets.view(-1).float()

        if lengths.numel() == 0:
            # Empty batch: the mean of an empty tensor is NaN, so return an explicit 0.0 loss.
            return torch.tensor(0.0, device=logits.device, requires_grad=logits.requires_grad)

        base_loss = self.bce(logits_flat, targets_flat)  # shape [B]

        prob = torch.sigmoid(logits_flat)
        confidence_weight = (prob - 0.5).abs() * 2  # ∈ [0, 1]

        length_weight = torch.sqrt(lengths.float()) / math.sqrt(lengths.max().item())
        length_weight = length_weight.clamp(self.min_len_weight_sqrt, 1.0)  # clamp to avoid extreme weights

        weights = confidence_weight * length_weight
        weights = weights / (weights.mean() + 1e-8)  # normalize so E[w] ≈ 1
        return (base_loss * weights).mean()
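

# Illustrative usage sketch (not part of the original training code): shows the
# expected call signature and shapes for SentimentWeightedLoss. The tensor values
# and the helper name below are made up purely for demonstration.
def _demo_sentiment_weighted_loss() -> torch.Tensor:
    criterion = SentimentWeightedLoss()
    logits = torch.tensor([1.2, -0.3, 0.8])     # raw model outputs, shape [B]
    targets = torch.tensor([1.0, 0.0, 1.0])     # binary labels, shape [B]
    lengths = torch.tensor([120, 15, 64])       # token counts per review, shape [B]
    return criterion(logits, targets, lengths)  # scalar, weighted mean BCE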


class SentimentFocalLoss(nn.Module):
    """
    This loss function incorporates:
    1. Base BCEWithLogitsLoss.
    2. Label Smoothing.
    3. Focal Loss modulation to focus more on hard examples (can be reversed to focus on easy examples).
    4. Sample weighting based on review length.
    5. Sample weighting based on prediction confidence.

    The final loss for each sample is calculated roughly as:
    Loss_sample = FocalModulator(pt, gamma) * BCE(logits, smoothed_targets) * NormalizedExternalWeight
    NormalizedExternalWeight = (ConfidenceWeight * LengthWeight) / Mean(ConfidenceWeight * LengthWeight)
    """

    def __init__(self, gamma_focal: float = 0.1, label_smoothing_epsilon: float = 0.05):
        """
        Args:
            gamma_focal (float): Gamma parameter for Focal Loss.
                - If gamma_focal > 0 (e.g., 2.0), applies standard Focal Loss,
                  down-weighting easy examples (focus on hard examples).
                - If gamma_focal < 0 (e.g., -2.0), applies a reversed focal modulation,
                  down-weighting hard examples (focus on easy examples via pt^|gamma|).
                - If gamma_focal = 0, no Focal Loss modulation is applied.
            label_smoothing_epsilon (float): Epsilon for label smoothing (0.0 <= epsilon < 1.0).
                Converts hard labels (0, 1) to soft labels (epsilon, 1 - epsilon).
                - If 0.0, no label smoothing is applied.
        """
        super().__init__()
        if not (0.0 <= label_smoothing_epsilon < 1.0):
            raise ValueError("label_smoothing_epsilon must be in [0.0, 1.0).")
        
        self.gamma_focal = gamma_focal
        self.label_smoothing_epsilon = label_smoothing_epsilon
        # Initialize BCE loss without reduction, since we're applying per-sample weights
        self.bce_loss_no_reduction = nn.BCEWithLogitsLoss(reduction="none")

    def forward(self, logits: torch.Tensor, targets: torch.Tensor, lengths: torch.Tensor) -> torch.Tensor:
        """
        Computes the custom loss.

        Args:
            logits (torch.Tensor): Raw logits from the model. Expected shape [B] or [B, 1].
            targets (torch.Tensor): Ground truth labels (0 or 1). Expected shape [B] or [B, 1].
            lengths (torch.Tensor): Number of tokens in each review. Expected shape [B].

        Returns:
            torch.Tensor: The computed scalar loss.
        """
        B = logits.size(0)
        if B == 0: # Handle empty batch case
            return torch.tensor(0.0, device=logits.device, requires_grad=True)

        logits_flat = logits.view(-1)
        original_targets_flat = targets.view(-1).float() # Ensure targets are float

        # 1. Label Smoothing
        if self.label_smoothing_epsilon > 0:
            # Smooth 1 to (1 - epsilon), and 0 to epsilon
            targets_for_bce = original_targets_flat * (1.0 - self.label_smoothing_epsilon) + \
                              (1.0 - original_targets_flat) * self.label_smoothing_epsilon
        else:
            targets_for_bce = original_targets_flat

        # 2. Calculate Base BCE loss terms (using potentially smoothed targets)
        base_bce_loss_terms = self.bce_loss_no_reduction(logits_flat, targets_for_bce)

        # 3. Focal Loss Modulation Component
        # For the focal modulator, 'pt' is the probability assigned by the model to the *original* ground truth class.
        probs = torch.sigmoid(logits_flat)
        # pt: probability of the original true class
        pt = torch.where(original_targets_flat.bool(), probs, 1.0 - probs)

        focal_modulator = torch.ones_like(pt) # Default to 1 (no modulation if gamma_focal is 0)
        if self.gamma_focal > 0:  # Standard Focal Loss: (1-pt)^gamma. Focus on hard examples (pt is small).
            focal_modulator = (1.0 - pt + 1e-8).pow(self.gamma_focal) # Epsilon for stability if pt is 1
        elif self.gamma_focal < 0:  # Reversed Focal: (pt)^|gamma|. Focus on easy examples (pt is large).
            focal_modulator = (pt + 1e-8).pow(abs(self.gamma_focal)) # Epsilon for stability if pt is 0
        
        modulated_loss_terms = focal_modulator * base_bce_loss_terms

        # 4. Confidence Weighting (based on how far probability is from 0.5)
        # Uses the same `probs` calculated for focal `pt`.
        confidence_w = (probs - 0.5).abs() * 2.0  # Scales to range [0, 1]

        # 5. Length Weighting (longer reviews potentially weighted more)
        lengths_flat = lengths.view(-1).float()
        max_len_in_batch = lengths_flat.max().item()
        
        if max_len_in_batch == 0: # Edge case: if all reviews in batch have 0 length
            length_w = torch.ones_like(lengths_flat)
        else:
            # Normalize by sqrt of max length in the current batch. Add epsilon for stability.
            length_w = torch.sqrt(lengths_flat) / (math.sqrt(max_len_in_batch) + 1e-8)
            length_w = torch.clamp(length_w, 0.0, 1.0) # Ensure weights are capped at 1

        # 6. Combine External Weights (Confidence and Length)
        # These weights are applied ON TOP of the focal-modulated loss terms.
        external_weights = confidence_w * length_w
        
        # Normalize these combined external_weights so their mean is approximately 1.
        # This prevents the weighting scheme from drastically changing the overall loss magnitude.
        if external_weights.sum() > 1e-8:  # Avoid division by zero if all weights are zero
            normalized_external_weights = external_weights / (external_weights.mean() + 1e-8)
        else:  # If all external weights are zero, fall back to ones so the loss is not zeroed out.
            normalized_external_weights = torch.ones_like(external_weights)

        # 7. Apply Normalized External Weights to the (Focal) Modulated Loss Terms
        final_loss_terms_per_sample = modulated_loss_terms * normalized_external_weights
        
        # 8. Final Reduction: Mean of the per-sample losses
        loss = final_loss_terms_per_sample.mean()
        
        return loss
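

if __name__ == "__main__":
    # Minimal smoke-test sketch (illustrative only): the gamma/epsilon settings and
    # the random inputs below are arbitrary examples, not values from the project.
    torch.manual_seed(0)
    logits = torch.randn(4)                       # raw model outputs, shape [B]
    targets = torch.tensor([1.0, 0.0, 1.0, 0.0])  # binary labels, shape [B]
    lengths = torch.tensor([200, 30, 75, 10])     # token counts per review, shape [B]

    focal = SentimentFocalLoss(gamma_focal=2.0, label_smoothing_epsilon=0.05)
    print("SentimentWeightedLoss:", SentimentWeightedLoss()(logits, targets, lengths).item())
    print("SentimentFocalLoss:   ", focal(logits, targets, lengths).item())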