import torch
import torch.nn as nn
import torch.nn.functional as F

from utils.utils import batched_bincount
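
# `batched_bincount(x, dim, max_value)` comes from this repo's utils. From its
# call sites below, it is assumed to map a [num_rows, num_cols] integer tensor
# to per-row label counts of shape [num_rows, max_value]; a minimal sketch of
# that assumed contract (not necessarily the repo's actual implementation):
#
#     def batched_bincount(x, dim, max_value):
#         counts = torch.zeros(x.size(0), max_value, dtype=torch.long, device=x.device)
#         counts.scatter_add_(dim, x, torch.ones_like(x))
#         return counts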


class GeneralCrossEntropy(nn.Module):
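    """Dispatches to one of the loss variants below based on `weight_type`
    ("seq_cbce", "cbce", "wce", or "ce"). All variants call F.nll_loss, so
    `preds` must be log-probabilities (e.g. log_softmax output)."""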
    def __init__(self, weight_type: str, beta: float = 0.99, is_sequential: bool = True):
        super().__init__()
        self.weight_type = weight_type
        self.beta = beta
        if weight_type == "seq_cbce":
            assert is_sequential, "seq_cbce requires sequential inputs"
            self.loss_func = SeqCBCrossEntropy(beta=beta)
        elif weight_type == "cbce":
            self.loss_func = CBCrossEntropy(beta=beta, is_sequential=is_sequential)
        elif weight_type == "wce":
            self.loss_func = WeightedCrossEntropy(is_sequential=is_sequential)
        elif weight_type == "ce":
            self.loss_func = CrossEntropy(is_sequential=is_sequential)
        else:
            raise NotImplementedError(f"Unknown weight_type: {weight_type}")

    def forward(self, 
                preds: torch.Tensor, 
                labels: torch.Tensor,
                pad_mask: torch.Tensor = None):
        return self.loss_func(preds, labels, pad_mask)


class SeqCBCrossEntropy(nn.Module):
    def __init__(self, beta: float = 0.99):
        super().__init__()
        self.beta = beta

    def forward(self,
                preds: torch.Tensor, 
                labels: torch.Tensor,
                pad_mask: torch.Tensor):
        """
        Sequential Class-alanced Cross Entropy Loss (Our proposal)

        Parameters
        -----------
        preds: torch.Tensor [batch_size, max_seq_length, num_classes]
        labels: torch.Tensor [batch_size, max_seq_length]
        pad_mask: torch.Tensor [batch_size, max_seq_length]

        Returns
        -------
        loss: torch.Tensor [1]
        """
        seq_length_batch = pad_mask.sum(-1) # [batch_size]
        seq_length_list = torch.unique(seq_length_batch) # [num_unique_seq_length]
        batch_size = preds.size(0)
        loss = 0
        for seq_length in seq_length_list:
            extracted_batch = (seq_length_batch == seq_length) # [batch_size]
            extracted_preds = preds[extracted_batch]   # [num_extracted, max_seq_length, num_classes]
            extracted_labels = labels[extracted_batch] # [num_extracted, max_seq_length]
            extracted_batch_size = extracted_labels.size(0)
            # Per-time-step label counts within this length group.
            bincounts = batched_bincount(extracted_labels.T, 1, extracted_preds.size(-1)) # [max_seq_length, num_classes]
            weight = (1 - self.beta) / (1 - self.beta**bincounts + 1e-8)
            for seq_no in range(seq_length.item()):
                loss += (extracted_batch_size / batch_size) * F.nll_loss(
                    extracted_preds[:, seq_no], extracted_labels[:, seq_no], weight=weight[seq_no])
        return loss

class CBCrossEntropy(nn.Module):
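    """Class-balanced cross-entropy: each class is weighted by the inverse
    effective number of samples, (1 - beta) / (1 - beta**n), with counts n
    taken over the current batch (Cui et al., 2019). For example, with
    beta = 0.99 and counts [900, 100, 10], the weights come out to roughly
    [0.0100, 0.0158, 0.1046], upweighting the rarest class by about 10x."""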
    def __init__(self, beta: float = 0.99, is_sequential: bool = True):
        super().__init__()
        self.beta = beta
        self.is_sequential = is_sequential

    def forward(self,
                preds: torch.Tensor,
                labels: torch.Tensor,
                pad_mask: torch.Tensor = None):
        num_classes = preds.size(-1)
        if self.is_sequential:
            mask = pad_mask.view(-1)
            preds = preds.view(-1, num_classes)
            # minlength guarantees one weight per class even when a class is
            # absent from the batch (F.nll_loss requires len(weight) == num_classes).
            bincounts = labels.view(-1)[mask].bincount(minlength=num_classes)
            weight = (1 - self.beta) / (1 - self.beta**bincounts + 1e-8)
            loss = F.nll_loss(preds[mask], labels.view(-1)[mask], weight=weight)
        else:
            bincounts = labels.view(-1).bincount(minlength=num_classes)
            weight = (1 - self.beta) / (1 - self.beta**bincounts + 1e-8)
            loss = F.nll_loss(preds, labels.squeeze(-1), weight=weight)
        return loss

class WeightedCrossEntropy(nn.Module):
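    """Frequency-weighted cross-entropy: each class is weighted by
    norm(counts) / count, where norm is torch.min or torch.max, so rarer
    classes receive proportionally larger weights."""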
    def __init__(self, is_sequential: bool = True, norm: str = "min"):
        super().__init__()
        self.is_sequential = is_sequential
        if norm == "min":
            self.norm = torch.min
        elif norm == "max":
            self.norm = torch.max
        else:
            raise ValueError(f"norm must be 'min' or 'max', got {norm!r}")

    def forward(self,
                preds: torch.Tensor,
                labels: torch.Tensor,
                pad_mask: torch.Tensor = None):
        num_classes = preds.size(-1)
        if self.is_sequential:
            mask = pad_mask.view(-1)
            preds = preds.view(-1, num_classes)
            bincounts = labels.view(-1)[mask].bincount(minlength=num_classes)
            # Normalize by the min/max of the nonzero counts so classes absent
            # from the batch cannot zero out every weight.
            weight = self.norm(bincounts[bincounts > 0]) / (bincounts + 1e-8)
            loss = F.nll_loss(preds[mask], labels.view(-1)[mask], weight=weight)
        else:
            bincounts = labels.view(-1).bincount(minlength=num_classes)
            weight = self.norm(bincounts[bincounts > 0]) / (bincounts + 1e-8)
            loss = F.nll_loss(preds, labels.squeeze(-1), weight=weight)
        return loss

class CrossEntropy(nn.Module):
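    """Plain (unweighted) cross-entropy over valid positions; like the other
    variants, expects log-probability predictions (uses F.nll_loss)."""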
    def __init__(self, is_sequential: bool = True):
        super().__init__()
        self.is_sequential = is_sequential

    def forward(self,
                preds: torch.Tensor, 
                labels: torch.Tensor,
                pad_mask: torch.Tensor = None):
        if self.is_sequential:
            mask = pad_mask.view(-1)
            preds = preds.view(-1, preds.size(-1))
            loss = F.nll_loss(preds[mask], labels.view(-1)[mask])
        else:
            loss = F.nll_loss(preds, labels.squeeze(-1))
        return loss
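

if __name__ == "__main__":
    # Minimal smoke test under assumed shapes (not part of the original file):
    # 4 sequences, max length 6, 5 classes. `preds` are log-probabilities and
    # `pad_mask` is a boolean mask of valid positions, matching the usage above.
    # The "seq_cbce" case additionally needs utils.utils.batched_bincount.
    torch.manual_seed(0)
    batch_size, max_len, num_classes = 4, 6, 5
    preds = F.log_softmax(torch.randn(batch_size, max_len, num_classes), dim=-1)
    labels = torch.randint(0, num_classes, (batch_size, max_len))
    lengths = torch.tensor([6, 6, 4, 4])
    pad_mask = torch.arange(max_len).unsqueeze(0) < lengths.unsqueeze(1)
    for weight_type in ("seq_cbce", "cbce", "wce", "ce"):
        criterion = GeneralCrossEntropy(weight_type=weight_type)
        print(weight_type, criterion(preds, labels, pad_mask).item())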