# Copyright (c) OpenMMLab. All rights reserved.
import copy

import torch
from mmcv.ops import batched_nms
from torch import Tensor, nn

from mmdet.structures.bbox import bbox_cxcywh_to_xyxy
from .deformable_detr_layers import DeformableDetrTransformerDecoder
from .utils import MLP, coordinate_to_encoding, inverse_sigmoid


class DDQTransformerDecoder(DeformableDetrTransformerDecoder):
    """Transformer decoder of DDQ."""

    def _init_layers(self) -> None:
        """Initialize encoder layers."""
        super()._init_layers()
        self.ref_point_head = MLP(self.embed_dims * 2, self.embed_dims,
                                  self.embed_dims, 2)
        self.norm = nn.LayerNorm(self.embed_dims)

    def select_distinct_queries(self, reference_points: Tensor, query: Tensor,
                                self_attn_mask: Tensor,
                                layer_index: int) -> Tensor:
        """Get updated `self_attn_mask` for distinct queries selection, it is
        used in self attention layers of decoder.

        Args:
            reference_points (Tensor): The input reference of decoder,
                has shape (bs, num_queries, 4) with the last dimension
                arranged as (cx, cy, w, h).
            query (Tensor): The input query of decoder, has shape
                (bs, num_queries, dims).
            self_attn_mask (Tensor): The input self attention mask of
                last decoder layer, has shape (bs, num_queries_total,
                num_queries_total).
            layer_index (int): Index of the decoder layer that has just run,
                used to get the classification scores of that layer's output
                for distinct query selection.

        Returns:
            Tensor: `self_attn_mask` used in self attention layers
                of decoder, has shape (bs, num_queries_total,
                num_queries_total).
        """
        num_imgs = len(reference_points)
        dis_start, num_dis = self.cache_dict['dis_query_info']
        # shape of self_attn_mask:
        # (batch * num_heads, num_queries_total, num_queries_total)
        dis_mask = self_attn_mask[:, dis_start:dis_start + num_dis,
                                  dis_start:dis_start + num_dis]
        # cls_branches from DDQDETRHead
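        # Each query's max class probability serves as its NMS score.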
        scores = self.cache_dict['cls_branches'][layer_index](
            query[:, dis_start:dis_start + num_dis]).sigmoid().max(-1).values
        proposals = reference_points[:, dis_start:dis_start + num_dis]
        proposals = bbox_cxcywh_to_xyxy(proposals)

        attn_mask_list = []
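        # Rebuild the distinct-query mask image by image: NMS the current
        # proposals, then unmask only rows/columns of surviving queries.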
        for img_id in range(num_imgs):
            single_proposals = proposals[img_id]
            single_scores = scores[img_id]
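            # `dis_mask` uses True for masked-out cells; invert row 0 of this
            # image's first attention head to find still-active queries.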
            attn_mask = ~dis_mask[img_id * self.cache_dict['num_heads']][0]
            # distinct query inds in this layer
            ori_index = attn_mask.nonzero().view(-1)
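            # Class-agnostic NMS: all boxes share one dummy class id, so
            # suppression depends only on IoU and score.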
            _, keep_idxs = batched_nms(single_proposals[ori_index],
                                       single_scores[ori_index],
                                       torch.ones(len(ori_index)),
                                       self.cache_dict['dqs_cfg'])

            real_keep_index = ori_index[keep_idxs]

            attn_mask = torch.ones_like(dis_mask[0]).bool()
            # Such an attn_mask gives the best result.
            # If index i is to be kept, then every cell in row i or column i
            #   must stay unmasked. For example, if
            #   `real_keep_index` = [1, 4] and `attn_mask` size = [8, 8],
            #   then all cells at rows or columns [1, 4] are kept and all
            #   other cells are masked out. With 1 denoting a kept
            #   (unmasked) cell, i.e. the value of `~attn_mask`:
            #
            # target\source   0 1 2 3 4 5 6 7
            #             0 [ 0 1 0 0 1 0 0 0 ]
            #             1 [ 1 1 1 1 1 1 1 1 ]
            #             2 [ 0 1 0 0 1 0 0 0 ]
            #             3 [ 0 1 0 0 1 0 0 0 ]
            #             4 [ 1 1 1 1 1 1 1 1 ]
            #             5 [ 0 1 0 0 1 0 0 0 ]
            #             6 [ 0 1 0 0 1 0 0 0 ]
            #             7 [ 0 1 0 0 1 0 0 0 ]
            attn_mask[real_keep_index] = False
            attn_mask[:, real_keep_index] = False

            attn_mask = attn_mask[None].repeat(self.cache_dict['num_heads'], 1,
                                               1)
            attn_mask_list.append(attn_mask)
        attn_mask = torch.cat(attn_mask_list)
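        # Write the refreshed distinct-query block back into a copy of the
        # full mask, leaving denoising and dense regions untouched.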
        self_attn_mask = copy.deepcopy(self_attn_mask)
        self_attn_mask[:, dis_start:dis_start + num_dis,
                       dis_start:dis_start + num_dis] = attn_mask
        # will be used in loss and inference
        self.cache_dict['distinct_query_mask'].append(~attn_mask)
        return self_attn_mask

    def forward(self, query: Tensor, value: Tensor, key_padding_mask: Tensor,
                self_attn_mask: Tensor, reference_points: Tensor,
                spatial_shapes: Tensor, level_start_index: Tensor,
                valid_ratios: Tensor, reg_branches: nn.ModuleList,
                **kwargs) -> Tensor:
        """Forward function of Transformer decoder.

        Args:
            query (Tensor): The input query, has shape (bs, num_queries,
                dims).
            value (Tensor): The input values, has shape (bs, num_value, dim).
            key_padding_mask (Tensor): The `key_padding_mask` of `cross_attn`
                input. ByteTensor, has shape (bs, num_value).
            self_attn_mask (Tensor): The attention mask to prevent information
                leakage from different denoising groups, distinct queries and
                dense queries, has shape (num_queries_total,
                num_queries_total). It will be updated for distinct queries
                selection in this forward function. It is `None` when
                `self.training` is `False`.
            reference_points (Tensor): The initial reference, has shape
                (bs, num_queries, 4) with the last dimension arranged as
                (cx, cy, w, h).
            spatial_shapes (Tensor): Spatial shapes of features in all levels,
                has shape (num_levels, 2), last dimension represents (h, w).
            level_start_index (Tensor): The start index of each level.
                A tensor has shape (num_levels, ) and can be represented
                as [0, h_0*w_0, h_0*w_0+h_1*w_1, ...].
            valid_ratios (Tensor): The ratios of the valid width and the valid
                height relative to the width and the height of features in all
                levels, has shape (bs, num_levels, 2).
            reg_branches (nn.ModuleList): Used for refining the regression
                results.

        Returns:
            tuple[Tensor]: Output queries and references of the Transformer
                decoder.

            - query (Tensor): Output embeddings of the last decoder layer,
              has shape (bs, num_queries, embed_dims) when
              `return_intermediate` is `False`. Otherwise, intermediate
              output embeddings of all decoder layers, has shape
              (num_decoder_layers, bs, num_queries, embed_dims).
            - reference_points (Tensor): The reference of the last decoder
              layer, has shape (bs, num_queries, 4) when `return_intermediate`
              is `False`. Otherwise, intermediate references of all decoder
              layers, has shape (1 + num_decoder_layers, bs, num_queries, 4).
              The coordinates are arranged as (cx, cy, w, h).
        """
        intermediate = []
        intermediate_reference_points = [reference_points]
        self.cache_dict['distinct_query_mask'] = []
        if self_attn_mask is None:
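            # At inference there are no denoising groups, so start from an
            # all-False (fully visible) mask.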
            self_attn_mask = torch.zeros((query.size(1), query.size(1)),
                                         device=query.device).bool()
        # shape is (batch * num_heads, num_queries_total, num_queries_total)
        self_attn_mask = self_attn_mask[None].repeat(
            len(query) * self.cache_dict['num_heads'], 1, 1)
        for layer_index, layer in enumerate(self.layers):
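            # Scale the normalized references by each level's valid ratio so
            # they address the unpadded region of every feature level.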
            if reference_points.shape[-1] == 4:
                reference_points_input = \
                    reference_points[:, :, None] * torch.cat(
                        [valid_ratios, valid_ratios], -1)[:, None]
            else:
                assert reference_points.shape[-1] == 2
                reference_points_input = \
                    reference_points[:, :, None] * valid_ratios[:, None]

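            # Positional queries: sine-encode the level-0 scaled references
            # and project them through `ref_point_head`.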
            query_sine_embed = coordinate_to_encoding(
                reference_points_input[:, :, 0, :],
                num_feats=self.embed_dims // 2)
            query_pos = self.ref_point_head(query_sine_embed)

            query = layer(
                query,
                query_pos=query_pos,
                value=value,
                key_padding_mask=key_padding_mask,
                self_attn_mask=self_attn_mask,
                spatial_shapes=spatial_shapes,
                level_start_index=level_start_index,
                valid_ratios=valid_ratios,
                reference_points=reference_points_input,
                **kwargs)

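            # Refine references in inverse-sigmoid (logit) space, squash back
            # with sigmoid, and detach so no gradient flows through them.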
            if not self.training:
                tmp = reg_branches[layer_index](query)
                assert reference_points.shape[-1] == 4
                new_reference_points = tmp + inverse_sigmoid(
                    reference_points, eps=1e-3)
                new_reference_points = new_reference_points.sigmoid()
                reference_points = new_reference_points.detach()
                if layer_index < (len(self.layers) - 1):
                    self_attn_mask = self.select_distinct_queries(
                        reference_points, query, self_attn_mask, layer_index)

            else:
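                # During training the tail of `query` holds dense (auxiliary)
                # queries, which are refined by separate aux_reg_branches.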
                num_dense = self.cache_dict['num_dense_queries']
                tmp = reg_branches[layer_index](query[:, :-num_dense])
                tmp_dense = self.aux_reg_branches[layer_index](
                    query[:, -num_dense:])

                tmp = torch.cat([tmp, tmp_dense], dim=1)
                assert reference_points.shape[-1] == 4
                new_reference_points = tmp + inverse_sigmoid(
                    reference_points, eps=1e-3)
                new_reference_points = new_reference_points.sigmoid()
                reference_points = new_reference_points.detach()
                if layer_index < (len(self.layers) - 1):
                    self_attn_mask = self.select_distinct_queries(
                        reference_points, query, self_attn_mask, layer_index)

            if self.return_intermediate:
                intermediate.append(self.norm(query))
                intermediate_reference_points.append(new_reference_points)

        if self.return_intermediate:
            return torch.stack(intermediate), torch.stack(
                intermediate_reference_points)

        return query, reference_points