# Copyright (c) OpenMMLab. All rights reserved. import copy import torch from mmcv.ops import batched_nms from torch import Tensor, nn from mmdet.structures.bbox import bbox_cxcywh_to_xyxy from .deformable_detr_layers import DeformableDetrTransformerDecoder from .utils import MLP, coordinate_to_encoding, inverse_sigmoid class DDQTransformerDecoder(DeformableDetrTransformerDecoder): """Transformer decoder of DDQ.""" def _init_layers(self) -> None: """Initialize encoder layers.""" super()._init_layers() self.ref_point_head = MLP(self.embed_dims * 2, self.embed_dims, self.embed_dims, 2) self.norm = nn.LayerNorm(self.embed_dims) def select_distinct_queries(self, reference_points: Tensor, query: Tensor, self_attn_mask: Tensor, layer_index): """Get updated `self_attn_mask` for distinct queries selection, it is used in self attention layers of decoder. Args: reference_points (Tensor): The input reference of decoder, has shape (bs, num_queries, 4) with the last dimension arranged as (cx, cy, w, h). query (Tensor): The input query of decoder, has shape (bs, num_queries, dims). self_attn_mask (Tensor): The input self attention mask of last decoder layer, has shape (bs, num_queries_total, num_queries_total). layer_index (int): Last decoder layer index, used to get classification score of last layer output, for distinct queries selection. Returns: Tensor: `self_attn_mask` used in self attention layers of decoder, has shape (bs, num_queries_total, num_queries_total). """ num_imgs = len(reference_points) dis_start, num_dis = self.cache_dict['dis_query_info'] # shape of self_attn_mask # (batchâ‹…num_heads, num_queries, embed_dims) dis_mask = self_attn_mask[:, dis_start:dis_start + num_dis, dis_start:dis_start + num_dis] # cls_branches from DDQDETRHead scores = self.cache_dict['cls_branches'][layer_index]( query[:, dis_start:dis_start + num_dis]).sigmoid().max(-1).values proposals = reference_points[:, dis_start:dis_start + num_dis] proposals = bbox_cxcywh_to_xyxy(proposals) attn_mask_list = [] for img_id in range(num_imgs): single_proposals = proposals[img_id] single_scores = scores[img_id] attn_mask = ~dis_mask[img_id * self.cache_dict['num_heads']][0] # distinct query inds in this layer ori_index = attn_mask.nonzero().view(-1) _, keep_idxs = batched_nms(single_proposals[ori_index], single_scores[ori_index], torch.ones(len(ori_index)), self.cache_dict['dqs_cfg']) real_keep_index = ori_index[keep_idxs] attn_mask = torch.ones_like(dis_mask[0]).bool() # such a attn_mask give best result # If it requires to keep index i, then all cells in row or column # i should be kept in `attn_mask` . For example, if # `real_keep_index` = [1, 4], and `attn_mask` size = [8, 8], # then all cells at rows or columns [1, 4] should be kept, and # all the other cells should be masked out. So the value of # `attn_mask` should be: # # target\source 0 1 2 3 4 5 6 7 # 0 [ 0 1 0 0 1 0 0 0 ] # 1 [ 1 1 1 1 1 1 1 1 ] # 2 [ 0 1 0 0 1 0 0 0 ] # 3 [ 0 1 0 0 1 0 0 0 ] # 4 [ 1 1 1 1 1 1 1 1 ] # 5 [ 0 1 0 0 1 0 0 0 ] # 6 [ 0 1 0 0 1 0 0 0 ] # 7 [ 0 1 0 0 1 0 0 0 ] attn_mask[real_keep_index] = False attn_mask[:, real_keep_index] = False attn_mask = attn_mask[None].repeat(self.cache_dict['num_heads'], 1, 1) attn_mask_list.append(attn_mask) attn_mask = torch.cat(attn_mask_list) self_attn_mask = copy.deepcopy(self_attn_mask) self_attn_mask[:, dis_start:dis_start + num_dis, dis_start:dis_start + num_dis] = attn_mask # will be used in loss and inference self.cache_dict['distinct_query_mask'].append(~attn_mask) return self_attn_mask def forward(self, query: Tensor, value: Tensor, key_padding_mask: Tensor, self_attn_mask: Tensor, reference_points: Tensor, spatial_shapes: Tensor, level_start_index: Tensor, valid_ratios: Tensor, reg_branches: nn.ModuleList, **kwargs) -> Tensor: """Forward function of Transformer decoder. Args: query (Tensor): The input query, has shape (bs, num_queries, dims). value (Tensor): The input values, has shape (bs, num_value, dim). key_padding_mask (Tensor): The `key_padding_mask` of `cross_attn` input. ByteTensor, has shape (bs, num_value). self_attn_mask (Tensor): The attention mask to prevent information leakage from different denoising groups, distinct queries and dense queries, has shape (num_queries_total, num_queries_total). It will be updated for distinct queries selection in this forward function. It is `None` when `self.training` is `False`. reference_points (Tensor): The initial reference, has shape (bs, num_queries, 4) with the last dimension arranged as (cx, cy, w, h). spatial_shapes (Tensor): Spatial shapes of features in all levels, has shape (num_levels, 2), last dimension represents (h, w). level_start_index (Tensor): The start index of each level. A tensor has shape (num_levels, ) and can be represented as [0, h_0*w_0, h_0*w_0+h_1*w_1, ...]. valid_ratios (Tensor): The ratios of the valid width and the valid height relative to the width and the height of features in all levels, has shape (bs, num_levels, 2). reg_branches: (obj:`nn.ModuleList`): Used for refining the regression results. Returns: tuple[Tensor]: Output queries and references of Transformer decoder - query (Tensor): Output embeddings of the last decoder, has shape (bs, num_queries, embed_dims) when `return_intermediate` is `False`. Otherwise, Intermediate output embeddings of all decoder layers, has shape (num_decoder_layers, bs, num_queries, embed_dims). - reference_points (Tensor): The reference of the last decoder layer, has shape (bs, num_queries, 4) when `return_intermediate` is `False`. Otherwise, Intermediate references of all decoder layers, has shape (1 + num_decoder_layers, bs, num_queries, 4). The coordinates are arranged as (cx, cy, w, h). """ intermediate = [] intermediate_reference_points = [reference_points] self.cache_dict['distinct_query_mask'] = [] if self_attn_mask is None: self_attn_mask = torch.zeros((query.size(1), query.size(1)), device=query.device).bool() # shape is (batch*number_heads, num_queries, num_queries) self_attn_mask = self_attn_mask[None].repeat( len(query) * self.cache_dict['num_heads'], 1, 1) for layer_index, layer in enumerate(self.layers): if reference_points.shape[-1] == 4: reference_points_input = \ reference_points[:, :, None] * torch.cat( [valid_ratios, valid_ratios], -1)[:, None] else: assert reference_points.shape[-1] == 2 reference_points_input = \ reference_points[:, :, None] * valid_ratios[:, None] query_sine_embed = coordinate_to_encoding( reference_points_input[:, :, 0, :], num_feats=self.embed_dims // 2) query_pos = self.ref_point_head(query_sine_embed) query = layer( query, query_pos=query_pos, value=value, key_padding_mask=key_padding_mask, self_attn_mask=self_attn_mask, spatial_shapes=spatial_shapes, level_start_index=level_start_index, valid_ratios=valid_ratios, reference_points=reference_points_input, **kwargs) if not self.training: tmp = reg_branches[layer_index](query) assert reference_points.shape[-1] == 4 new_reference_points = tmp + inverse_sigmoid( reference_points, eps=1e-3) new_reference_points = new_reference_points.sigmoid() reference_points = new_reference_points.detach() if layer_index < (len(self.layers) - 1): self_attn_mask = self.select_distinct_queries( reference_points, query, self_attn_mask, layer_index) else: num_dense = self.cache_dict['num_dense_queries'] tmp = reg_branches[layer_index](query[:, :-num_dense]) tmp_dense = self.aux_reg_branches[layer_index]( query[:, -num_dense:]) tmp = torch.cat([tmp, tmp_dense], dim=1) assert reference_points.shape[-1] == 4 new_reference_points = tmp + inverse_sigmoid( reference_points, eps=1e-3) new_reference_points = new_reference_points.sigmoid() reference_points = new_reference_points.detach() if layer_index < (len(self.layers) - 1): self_attn_mask = self.select_distinct_queries( reference_points, query, self_attn_mask, layer_index) if self.return_intermediate: intermediate.append(self.norm(query)) intermediate_reference_points.append(new_reference_points) if self.return_intermediate: return torch.stack(intermediate), torch.stack( intermediate_reference_points) return query, reference_points