# Copyright (c) OpenMMLab. All rights reserved.
import copy
import torch
from mmcv.ops import batched_nms
from torch import Tensor, nn
from mmdet.structures.bbox import bbox_cxcywh_to_xyxy
from .deformable_detr_layers import DeformableDetrTransformerDecoder
from .utils import MLP, coordinate_to_encoding, inverse_sigmoid


class DDQTransformerDecoder(DeformableDetrTransformerDecoder):
"""Transformer decoder of DDQ."""
def _init_layers(self) -> None:
"""Initialize encoder layers."""
super()._init_layers()
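        # For a 4-d (cx, cy, w, h) reference, `coordinate_to_encoding` with
        # num_feats = embed_dims // 2 yields a 2 * embed_dims vector, which
        # is why the MLP input width below is `embed_dims * 2`.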
self.ref_point_head = MLP(self.embed_dims * 2, self.embed_dims,
self.embed_dims, 2)
        self.norm = nn.LayerNorm(self.embed_dims)

    def select_distinct_queries(self, reference_points: Tensor, query: Tensor,
                                self_attn_mask: Tensor,
                                layer_index: int) -> Tensor:
"""Get updated `self_attn_mask` for distinct queries selection, it is
used in self attention layers of decoder.
Args:
reference_points (Tensor): The input reference of decoder,
has shape (bs, num_queries, 4) with the last dimension
arranged as (cx, cy, w, h).
query (Tensor): The input query of decoder, has shape
(bs, num_queries, dims).
            self_attn_mask (Tensor): The input self attention mask of the
                last decoder layer, has shape (bs * num_heads,
                num_queries_total, num_queries_total).
            layer_index (int): Index of the current decoder layer, used to
                pick the classification branch that scores this layer's
                output for distinct query selection.

        Returns:
            Tensor: `self_attn_mask` used in the self-attention layers of
            the decoder, has shape (bs * num_heads, num_queries_total,
            num_queries_total).
        """
num_imgs = len(reference_points)
dis_start, num_dis = self.cache_dict['dis_query_info']
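        # `dis_query_info` locates the distinct queries inside the full query
        # sequence as (start index, count); during training the sequence
        # presumably also carries denoising and dense queries around them.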
        # shape of self_attn_mask:
        # (batch * num_heads, num_queries_total, num_queries_total)
dis_mask = self_attn_mask[:, dis_start:dis_start + num_dis,
dis_start:dis_start + num_dis]
# cls_branches from DDQDETRHead
scores = self.cache_dict['cls_branches'][layer_index](
query[:, dis_start:dis_start + num_dis]).sigmoid().max(-1).values
proposals = reference_points[:, dis_start:dis_start + num_dis]
proposals = bbox_cxcywh_to_xyxy(proposals)
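        # `batched_nms` expects (x1, y1, x2, y2) boxes, hence the conversion
        # from the (cx, cy, w, h) reference format above.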
attn_mask_list = []
for img_id in range(num_imgs):
single_proposals = proposals[img_id]
single_scores = scores[img_id]
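            # All heads repeat the same per-image pattern, so the first
            # head's mask is read; True in `dis_mask` means "masked out",
            # hence the negation. Row 0 is taken as representative of which
            # queries are still attendable, i.e. still distinct.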
attn_mask = ~dis_mask[img_id * self.cache_dict['num_heads']][0]
# distinct query inds in this layer
ori_index = attn_mask.nonzero().view(-1)
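            # All-ones class ids make this class-agnostic NMS; `dqs_cfg`
            # holds the NMS settings (e.g. the IoU threshold) used for
            # distinct query selection.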
_, keep_idxs = batched_nms(single_proposals[ori_index],
single_scores[ori_index],
torch.ones(len(ori_index)),
self.cache_dict['dqs_cfg'])
real_keep_index = ori_index[keep_idxs]
attn_mask = torch.ones_like(dis_mask[0]).bool()
            # Such an attn_mask gives the best result.
            # If index i is to be kept, every cell in row i or column i
            # should allow attention. For example, if
            # `real_keep_index` = [1, 4] and the `attn_mask` size is [8, 8],
            # then all cells at rows or columns [1, 4] allow attention and
            # all other cells are masked out. With 1 denoting "attention
            # allowed" (the inverse of the boolean `attn_mask`, where True
            # means masked out), the pattern is:
#
# target\source 0 1 2 3 4 5 6 7
# 0 [ 0 1 0 0 1 0 0 0 ]
# 1 [ 1 1 1 1 1 1 1 1 ]
# 2 [ 0 1 0 0 1 0 0 0 ]
# 3 [ 0 1 0 0 1 0 0 0 ]
# 4 [ 1 1 1 1 1 1 1 1 ]
# 5 [ 0 1 0 0 1 0 0 0 ]
# 6 [ 0 1 0 0 1 0 0 0 ]
# 7 [ 0 1 0 0 1 0 0 0 ]
attn_mask[real_keep_index] = False
attn_mask[:, real_keep_index] = False
attn_mask = attn_mask[None].repeat(self.cache_dict['num_heads'], 1,
1)
attn_mask_list.append(attn_mask)
attn_mask = torch.cat(attn_mask_list)
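        # Copy before writing back so the previous layer's mask tensor, which
        # may still be referenced elsewhere, is not modified in place.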
self_attn_mask = copy.deepcopy(self_attn_mask)
self_attn_mask[:, dis_start:dis_start + num_dis,
dis_start:dis_start + num_dis] = attn_mask
# will be used in loss and inference
self.cache_dict['distinct_query_mask'].append(~attn_mask)
        return self_attn_mask

    def forward(self, query: Tensor, value: Tensor, key_padding_mask: Tensor,
self_attn_mask: Tensor, reference_points: Tensor,
spatial_shapes: Tensor, level_start_index: Tensor,
valid_ratios: Tensor, reg_branches: nn.ModuleList,
                **kwargs) -> tuple:
"""Forward function of Transformer decoder.
Args:
query (Tensor): The input query, has shape (bs, num_queries,
dims).
value (Tensor): The input values, has shape (bs, num_value, dim).
key_padding_mask (Tensor): The `key_padding_mask` of `cross_attn`
input. ByteTensor, has shape (bs, num_value).
self_attn_mask (Tensor): The attention mask to prevent information
leakage from different denoising groups, distinct queries and
dense queries, has shape (num_queries_total,
num_queries_total). It will be updated for distinct queries
selection in this forward function. It is `None` when
`self.training` is `False`.
reference_points (Tensor): The initial reference, has shape
(bs, num_queries, 4) with the last dimension arranged as
(cx, cy, w, h).
spatial_shapes (Tensor): Spatial shapes of features in all levels,
has shape (num_levels, 2), last dimension represents (h, w).
level_start_index (Tensor): The start index of each level.
A tensor has shape (num_levels, ) and can be represented
as [0, h_0*w_0, h_0*w_0+h_1*w_1, ...].
valid_ratios (Tensor): The ratios of the valid width and the valid
height relative to the width and the height of features in all
levels, has shape (bs, num_levels, 2).
            reg_branches (obj:`nn.ModuleList`): Used for refining the
                regression results.

        Returns:
            tuple[Tensor]: Output queries and references of Transformer
            decoder.

- query (Tensor): Output embeddings of the last decoder, has
shape (bs, num_queries, embed_dims) when `return_intermediate`
                is `False`. Otherwise, intermediate output embeddings of all
decoder layers, has shape (num_decoder_layers, bs, num_queries,
embed_dims).
- reference_points (Tensor): The reference of the last decoder
layer, has shape (bs, num_queries, 4) when `return_intermediate`
                is `False`. Otherwise, intermediate references of all decoder
layers, has shape (1 + num_decoder_layers, bs, num_queries, 4).
The coordinates are arranged as (cx, cy, w, h).
"""
intermediate = []
intermediate_reference_points = [reference_points]
self.cache_dict['distinct_query_mask'] = []
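        # `self_attn_mask` is None at inference (no denoising groups), so an
        # all-False, fully visible mask is built instead.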
if self_attn_mask is None:
self_attn_mask = torch.zeros((query.size(1), query.size(1)),
device=query.device).bool()
        # shape is (batch * num_heads, num_queries, num_queries)
self_attn_mask = self_attn_mask[None].repeat(
len(query) * self.cache_dict['num_heads'], 1, 1)
for layer_index, layer in enumerate(self.layers):
if reference_points.shape[-1] == 4:
reference_points_input = \
reference_points[:, :, None] * torch.cat(
[valid_ratios, valid_ratios], -1)[:, None]
else:
assert reference_points.shape[-1] == 2
reference_points_input = \
reference_points[:, :, None] * valid_ratios[:, None]
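            # Positional queries are regenerated at every layer from the
            # current (level-0 scaled) reference points.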
query_sine_embed = coordinate_to_encoding(
reference_points_input[:, :, 0, :],
num_feats=self.embed_dims // 2)
query_pos = self.ref_point_head(query_sine_embed)
query = layer(
query,
query_pos=query_pos,
value=value,
key_padding_mask=key_padding_mask,
self_attn_mask=self_attn_mask,
spatial_shapes=spatial_shapes,
level_start_index=level_start_index,
valid_ratios=valid_ratios,
reference_points=reference_points_input,
**kwargs)
if not self.training:
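                # Dense queries are a train-time auxiliary in DDQ, so at
                # inference every query goes through the shared regression
                # branch.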
tmp = reg_branches[layer_index](query)
assert reference_points.shape[-1] == 4
new_reference_points = tmp + inverse_sigmoid(
reference_points, eps=1e-3)
new_reference_points = new_reference_points.sigmoid()
reference_points = new_reference_points.detach()
if layer_index < (len(self.layers) - 1):
self_attn_mask = self.select_distinct_queries(
reference_points, query, self_attn_mask, layer_index)
else:
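                # The last `num_dense_queries` entries are the dense queries;
                # they are refined by the auxiliary regression branches while
                # the remaining queries use the shared ones.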
num_dense = self.cache_dict['num_dense_queries']
tmp = reg_branches[layer_index](query[:, :-num_dense])
tmp_dense = self.aux_reg_branches[layer_index](
query[:, -num_dense:])
tmp = torch.cat([tmp, tmp_dense], dim=1)
assert reference_points.shape[-1] == 4
new_reference_points = tmp + inverse_sigmoid(
reference_points, eps=1e-3)
new_reference_points = new_reference_points.sigmoid()
reference_points = new_reference_points.detach()
if layer_index < (len(self.layers) - 1):
self_attn_mask = self.select_distinct_queries(
reference_points, query, self_attn_mask, layer_index)
if self.return_intermediate:
intermediate.append(self.norm(query))
intermediate_reference_points.append(new_reference_points)
if self.return_intermediate:
return torch.stack(intermediate), torch.stack(
intermediate_reference_points)
return query, reference_points
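

if __name__ == '__main__':
    # Standalone sketch, not part of the upstream module: it reproduces the
    # keep-pattern mask construction from `select_distinct_queries` for the
    # documented example where queries [1, 4] survive NMS.
    real_keep_index = torch.tensor([1, 4])
    attn_mask = torch.ones(8, 8, dtype=torch.bool)  # True = masked out
    attn_mask[real_keep_index] = False  # kept rows may attend everywhere
    attn_mask[:, real_keep_index] = False  # everyone may attend kept queries
    # Printing the inverse reproduces the 0/1 "attention allowed" table.
    print((~attn_mask).int())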