File size: 7,772 Bytes
e85fecb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
"""

Copied from RT-DETR (https://github.com/lyuwenyu/RT-DETR)

Copyright(c) 2023 lyuwenyu. All Rights Reserved.

"""

import torch
import torch.distributed
import torch.nn.functional as F
import torchvision

from ...core import register
from ...misc import box_ops, dist_utils


@register()
class DetCriterion(torch.nn.Module):
    """Detection criterion: match predictions to targets, then compute weighted losses.

    ``forward`` runs the injected ``matcher`` once, derives a normalisation constant
    from the number of matched pairs, and evaluates every loss named in ``losses``.
    Each returned loss term is multiplied by its entry in ``weight_dict``; terms that
    have no entry in ``weight_dict`` are dropped from the result.
    """

    # Consumed by the project's registry/config machinery (see ``...core``).
    __share__ = ["num_classes"]
    __inject__ = ["matcher"]

    def __init__(
        self,
        losses,
        weight_dict,
        num_classes=80,
        alpha=0.75,
        gamma=2.0,
        box_fmt="cxcywh",
        matcher=None,
    ):
        """
        Args:
            losses (list[str]): requested losses; supported names are
                ['boxes', 'giou', 'vfl', 'focal'] (see ``get_loss``).
            weight_dict (dict[str, float]): weight of each loss term, keyed by
                ['loss_bbox', 'loss_giou', 'loss_vfl', 'loss_focal'].
            num_classes (int): number of foreground classes; the value ``num_classes``
                itself is used internally as the "no object" class index.
            alpha (float): alpha passed to the focal loss; also scales the negative
                term of the varifocal weight.
            gamma (float): gamma passed to the focal loss; also the exponent of the
                negative term of the varifocal weight.
            box_fmt (str): format of predicted and target boxes, 'cxcywh' or 'xyxy'.
            matcher (Matcher): matcher used to match source to target. Required.
        """
        super().__init__()
        self.losses = losses
        self.weight_dict = weight_dict
        self.alpha = alpha
        self.gamma = gamma
        self.num_classes = num_classes
        self.box_fmt = box_fmt
        assert matcher is not None, "DetCriterion requires a matcher, got None"
        self.matcher = matcher

    def forward(self, outputs, targets, **kwargs):
        """
        Args:
            outputs: Dict[Tensor] with 'pred_boxes' and 'pred_logits'. The whole dict
                is handed to the matcher, so it may carry further entries (e.g. 'meta').
            targets, List[Dict[str, Tensor]], len(targets) == batch_size. Each dict is
                read for 'boxes' and/or 'labels' depending on the requested losses.
            kwargs, store other information such as current epoch id. Accepted for
                interface compatibility; not used by this implementation.

        Return:
            losses, Dict[str, Tensor]: weighted loss terms.
        """
        # Only the matched index pairs are needed; other matcher outputs are ignored.
        indices = self.matcher(outputs, targets)["indices"]
        num_boxes = self._get_positive_nums(indices)

        # Compute all the requested losses
        losses = {}
        for loss in self.losses:
            l_dict = self.get_loss(loss, outputs, targets, indices, num_boxes)
            # Keep only the terms that have a configured weight.
            losses.update(
                {k: v * self.weight_dict[k] for k, v in l_dict.items() if k in self.weight_dict}
            )
        return losses

    def _get_src_permutation_idx(self, indices):
        """Flatten per-image matched prediction indices into (batch_idx, src_idx)."""
        # permute predictions following indices
        batch_idx = torch.cat([torch.full_like(src, i) for i, (src, _) in enumerate(indices)])
        src_idx = torch.cat([src for (src, _) in indices])
        return batch_idx, src_idx

    def _get_tgt_permutation_idx(self, indices):
        """Flatten per-image matched target indices into (batch_idx, tgt_idx)."""
        # permute targets following indices
        batch_idx = torch.cat([torch.full_like(tgt, i) for i, (_, tgt) in enumerate(indices)])
        tgt_idx = torch.cat([tgt for (_, tgt) in indices])
        return batch_idx, tgt_idx

    def _get_positive_nums(self, indices):
        """Return the number of matched pairs as a Python float, never below 1.

        When distributed training is initialised, the count is summed over all
        processes and divided by the world size before clamping.
        """
        # number of positive samples
        num_pos = sum(len(i) for (i, _) in indices)
        num_pos = torch.as_tensor([num_pos], dtype=torch.float32, device=indices[0][0].device)
        if dist_utils.is_dist_available_and_initialized():
            torch.distributed.all_reduce(num_pos)
        # Clamp to 1 so that the losses are never divided by zero.
        num_pos = torch.clamp(num_pos / dist_utils.get_world_size(), min=1).item()
        return num_pos

    def _gather_matched_targets(self, targets, indices, key):
        """Concatenate ``t[key]`` at the matched target indices over all images."""
        return torch.cat([t[key][j] for t, (_, j) in zip(targets, indices)], dim=0)

    def _build_target_classes(self, src_logits, targets, indices, idx):
        """Build the per-prediction class-id map used by the classification losses.

        The map has the first two dimensions of ``src_logits``. Every entry starts as
        ``self.num_classes`` ("no object"); matched predictions get their target label.
        """
        target_classes = torch.full(
            src_logits.shape[:2], self.num_classes, dtype=torch.int64, device=src_logits.device
        )
        target_classes[idx] = self._gather_matched_targets(targets, indices, "labels")
        return target_classes

    def _to_xyxy(self, boxes):
        """Convert boxes from ``self.box_fmt`` to 'xyxy'."""
        return torchvision.ops.box_convert(boxes, in_fmt=self.box_fmt, out_fmt="xyxy")

    def _giou_loss(self, src_boxes, target_boxes, num_boxes):
        """Sum of ``1 - GIoU`` over matched pairs, divided by ``num_boxes``.

        Both box tensors are expected in ``self.box_fmt``; they are converted here.
        """
        giou = box_ops.elementwise_generalized_box_iou(
            self._to_xyxy(src_boxes), self._to_xyxy(target_boxes)
        )
        return (1 - giou).sum() / num_boxes

    def loss_labels_focal(self, outputs, targets, indices, num_boxes, **kwargs):
        """Sigmoid focal classification loss; returns ``{'loss_focal': Tensor}``.

        Extra keyword arguments are accepted (``get_loss`` forwards them) and ignored.
        """
        assert "pred_logits" in outputs
        src_logits = outputs["pred_logits"]

        idx = self._get_src_permutation_idx(indices)
        target_classes = self._build_target_classes(src_logits, targets, indices, idx)

        # One-hot with an extra "no object" column that is sliced off, so unmatched
        # predictions end up with an all-zero target row.
        target = F.one_hot(target_classes, num_classes=self.num_classes + 1)[..., :-1].to(
            src_logits.dtype
        )
        loss = torchvision.ops.sigmoid_focal_loss(
            src_logits, target, self.alpha, self.gamma, reduction="none"
        )
        loss = loss.sum() / num_boxes
        return {"loss_focal": loss}

    def loss_labels_vfl(self, outputs, targets, indices, num_boxes, **kwargs):
        """Varifocal-style classification loss; returns ``{'loss_vfl': Tensor}``.

        The target score of a matched prediction is its IoU with the matched target
        box (computed on detached predictions); all other entries have target 0.
        Extra keyword arguments are accepted (``get_loss`` forwards them) and ignored.
        """
        assert "pred_boxes" in outputs
        assert "pred_logits" in outputs
        idx = self._get_src_permutation_idx(indices)

        src_boxes = self._to_xyxy(outputs["pred_boxes"][idx])
        target_boxes = self._to_xyxy(self._gather_matched_targets(targets, indices, "boxes"))
        # IoU is used as a soft label only, so no gradient flows through the boxes.
        iou, _ = box_ops.elementwise_box_iou(src_boxes.detach(), target_boxes)

        src_logits: torch.Tensor = outputs["pred_logits"]
        target_classes = self._build_target_classes(src_logits, targets, indices, idx)
        target = F.one_hot(target_classes, num_classes=self.num_classes + 1)[..., :-1]

        target_score_o = torch.zeros_like(target_classes, dtype=src_logits.dtype)
        target_score_o[idx] = iou.to(src_logits.dtype)
        # Non-zero only at the (matched prediction, target class) entries.
        target_score = target_score_o.unsqueeze(-1) * target

        # Negatives are weighted by alpha * p^gamma, positives by their IoU score.
        src_score = F.sigmoid(src_logits.detach())
        weight = self.alpha * src_score.pow(self.gamma) * (1 - target) + target_score

        loss = F.binary_cross_entropy_with_logits(
            src_logits, target_score, weight=weight, reduction="none"
        )
        loss = loss.sum() / num_boxes
        return {"loss_vfl": loss}

    def loss_boxes(self, outputs, targets, indices, num_boxes, **kwargs):
        """L1 and GIoU box losses; returns ``{'loss_bbox': ..., 'loss_giou': ...}``.

        Extra keyword arguments are accepted (``get_loss`` forwards them) and ignored.
        """
        assert "pred_boxes" in outputs
        idx = self._get_src_permutation_idx(indices)
        src_boxes = outputs["pred_boxes"][idx]
        target_boxes = self._gather_matched_targets(targets, indices, "boxes")

        # L1 is computed on the boxes as given (in ``self.box_fmt``), before conversion.
        loss_bbox = F.l1_loss(src_boxes, target_boxes, reduction="none")
        return {
            "loss_bbox": loss_bbox.sum() / num_boxes,
            "loss_giou": self._giou_loss(src_boxes, target_boxes, num_boxes),
        }

    def loss_boxes_giou(self, outputs, targets, indices, num_boxes, **kwargs):
        """GIoU box loss only; returns ``{'loss_giou': Tensor}``.

        Extra keyword arguments are accepted (``get_loss`` forwards them) and ignored.
        """
        assert "pred_boxes" in outputs
        idx = self._get_src_permutation_idx(indices)
        src_boxes = outputs["pred_boxes"][idx]
        target_boxes = self._gather_matched_targets(targets, indices, "boxes")
        return {"loss_giou": self._giou_loss(src_boxes, target_boxes, num_boxes)}

    def get_loss(self, loss, outputs, targets, indices, num_boxes, **kwargs):
        """Dispatch ``loss`` (one of 'boxes', 'giou', 'vfl', 'focal') to its method.

        Returns the dict of unweighted loss terms produced by that method. Any extra
        keyword arguments are forwarded to the loss method.
        """
        loss_map = {
            "boxes": self.loss_boxes,
            "giou": self.loss_boxes_giou,
            "vfl": self.loss_labels_vfl,
            "focal": self.loss_labels_focal,
        }
        assert loss in loss_map, f"do you really want to compute {loss} loss?"
        return loss_map[loss](outputs, targets, indices, num_boxes, **kwargs)