from typing import Optional, Tuple, Union

import torch
from mmcv.cnn import build_norm_layer
from mmcv.cnn.bricks.transformer import FFN, MultiheadAttention
from mmcv.ops import MultiScaleDeformableAttention
from mmengine.model import ModuleList
from torch import Tensor, nn

from .detr_layers import (DetrTransformerDecoder, DetrTransformerDecoderLayer,
                          DetrTransformerEncoder, DetrTransformerEncoderLayer)
from .utils import inverse_sigmoid


class DeformableDetrTransformerEncoder(DetrTransformerEncoder):
    """Transformer encoder of Deformable DETR."""

    def _init_layers(self) -> None:
        """Initialize encoder layers."""
        self.layers = ModuleList([
            DeformableDetrTransformerEncoderLayer(**self.layer_cfg)
            for _ in range(self.num_layers)
        ])
        self.embed_dims = self.layers[0].embed_dims

    def forward(self, query: Tensor, query_pos: Tensor,
                key_padding_mask: Tensor, spatial_shapes: Tensor,
                level_start_index: Tensor, valid_ratios: Tensor,
                **kwargs) -> Tensor:
        """Forward function of Transformer encoder.

        Args:
            query (Tensor): The input query, has shape (bs, num_queries, dim).
            query_pos (Tensor): The positional encoding for query, has shape
                (bs, num_queries, dim).
            key_padding_mask (Tensor): The `key_padding_mask` of `self_attn`
                input. ByteTensor, has shape (bs, num_queries).
            spatial_shapes (Tensor): Spatial shapes of features in all levels,
                has shape (num_levels, 2), last dimension represents (h, w).
            level_start_index (Tensor): The start index of each level.
                A tensor has shape (num_levels, ) and can be represented
                as [0, h_0*w_0, h_0*w_0+h_1*w_1, ...].
            valid_ratios (Tensor): The ratios of the valid width and the valid
                height relative to the width and the height of features in all
                levels, has shape (bs, num_levels, 2).

        Returns:
            Tensor: Output queries of Transformer encoder, which is also
            called 'encoder output embeddings' or 'memory', has shape
            (bs, num_queries, dim).
        """
|
        reference_points = self.get_encoder_reference_points(
            spatial_shapes, valid_ratios, device=query.device)
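        # The same reference points (normalized pixel centres on every
        # feature level) are shared by all encoder layers; each layer samples
        # features around them with multi-scale deformable self-attention.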
|
        for layer in self.layers:
            query = layer(
                query=query,
                query_pos=query_pos,
                key_padding_mask=key_padding_mask,
                spatial_shapes=spatial_shapes,
                level_start_index=level_start_index,
                valid_ratios=valid_ratios,
                reference_points=reference_points,
                **kwargs)
        return query

    @staticmethod
    def get_encoder_reference_points(
            spatial_shapes: Tensor, valid_ratios: Tensor,
            device: Union[torch.device, str]) -> Tensor:
        """Get the reference points used in encoder.

        Args:
            spatial_shapes (Tensor): Spatial shapes of features in all levels,
                has shape (num_levels, 2), last dimension represents (h, w).
            valid_ratios (Tensor): The ratios of the valid width and the valid
                height relative to the width and the height of features in all
                levels, has shape (bs, num_levels, 2).
            device (obj:`device` or str): The device on which the
                `reference_points` should be created.

        Returns:
            Tensor: Reference points used in encoder, has shape (bs, length,
            num_levels, 2).
        """

        reference_points_list = []
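        # For every level, build a grid of pixel centres and normalize it by
        # the valid (un-padded) height and width of that level, so valid
        # positions fall inside [0, 1]. The per-level grids are concatenated
        # along the query dimension and then broadcast against `valid_ratios`,
        # giving each position one reference point per target level.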
|
        for lvl, (H, W) in enumerate(spatial_shapes):
            ref_y, ref_x = torch.meshgrid(
                torch.linspace(
                    0.5, H - 0.5, H, dtype=torch.float32, device=device),
                torch.linspace(
                    0.5, W - 0.5, W, dtype=torch.float32, device=device))
            ref_y = ref_y.reshape(-1)[None] / (
                valid_ratios[:, None, lvl, 1] * H)
            ref_x = ref_x.reshape(-1)[None] / (
                valid_ratios[:, None, lvl, 0] * W)
            ref = torch.stack((ref_x, ref_y), -1)
            reference_points_list.append(ref)
        reference_points = torch.cat(reference_points_list, 1)

        reference_points = reference_points[:, :, None] * valid_ratios[:, None]
        return reference_points


class DeformableDetrTransformerDecoder(DetrTransformerDecoder):
    """Transformer Decoder of Deformable DETR."""

    def _init_layers(self) -> None:
        """Initialize decoder layers."""
        self.layers = ModuleList([
            DeformableDetrTransformerDecoderLayer(**self.layer_cfg)
            for _ in range(self.num_layers)
        ])
        self.embed_dims = self.layers[0].embed_dims
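        # Unlike vanilla DETR, the Deformable DETR decoder applies no final
        # post-norm to the decoder outputs, so a `post_norm_cfg` is rejected.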
|
        if self.post_norm_cfg is not None:
            raise ValueError('There is no post_norm in '
                             f'{self._get_name()}')

    def forward(self,
                query: Tensor,
                query_pos: Tensor,
                value: Tensor,
                key_padding_mask: Tensor,
                reference_points: Tensor,
                spatial_shapes: Tensor,
                level_start_index: Tensor,
                valid_ratios: Tensor,
                reg_branches: Optional[nn.Module] = None,
                **kwargs) -> Tuple[Tensor]:
        """Forward function of Transformer decoder.

        Args:
            query (Tensor): The input queries, has shape (bs, num_queries,
                dim).
            query_pos (Tensor): The input positional query, has shape
                (bs, num_queries, dim). It will be added to `query` before
                the forward function.
            value (Tensor): The input values, has shape (bs, num_value, dim).
            key_padding_mask (Tensor): The `key_padding_mask` of `cross_attn`
                input. ByteTensor, has shape (bs, num_value).
            reference_points (Tensor): The initial reference, has shape
                (bs, num_queries, 4) with the last dimension arranged as
                (cx, cy, w, h) when `as_two_stage` is `True`, otherwise has
                shape (bs, num_queries, 2) with the last dimension arranged
                as (cx, cy).
            spatial_shapes (Tensor): Spatial shapes of features in all levels,
                has shape (num_levels, 2), last dimension represents (h, w).
            level_start_index (Tensor): The start index of each level.
                A tensor has shape (num_levels, ) and can be represented
                as [0, h_0*w_0, h_0*w_0+h_1*w_1, ...].
            valid_ratios (Tensor): The ratios of the valid width and the valid
                height relative to the width and the height of features in all
                levels, has shape (bs, num_levels, 2).
            reg_branches (obj:`nn.ModuleList`, optional): Used for refining
                the regression results. Only passed when `with_box_refine`
                is `True`, otherwise it is `None`.

        Returns:
            tuple[Tensor]: Outputs of Deformable Transformer Decoder.

            - output (Tensor): Output embeddings of the last decoder layer,
              has shape (bs, num_queries, embed_dims) when
              `return_intermediate` is `False`. Otherwise, intermediate output
              embeddings of all decoder layers, has shape
              (num_decoder_layers, bs, num_queries, embed_dims).
            - reference_points (Tensor): The reference of the last decoder
              layer, has shape (bs, num_queries, 4) when `return_intermediate`
              is `False`. Otherwise, intermediate references of all decoder
              layers, has shape (num_decoder_layers, bs, num_queries, 4). The
              coordinates are arranged as (cx, cy, w, h).
        """
        output = query
        intermediate = []
        intermediate_reference_points = []
        for layer_id, layer in enumerate(self.layers):
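            # Reference points are normalized w.r.t. the valid (un-padded)
            # image region; multiplying by `valid_ratios` maps them into each
            # level's feature map, producing one reference per level. Box
            # references (cx, cy, w, h) need the ratios duplicated, point
            # references (cx, cy) use them directly.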
|
            if reference_points.shape[-1] == 4:
                reference_points_input = \
                    reference_points[:, :, None] * \
                    torch.cat([valid_ratios, valid_ratios], -1)[:, None]
            else:
                assert reference_points.shape[-1] == 2
                reference_points_input = \
                    reference_points[:, :, None] * \
                    valid_ratios[:, None]
            output = layer(
                output,
                query_pos=query_pos,
                value=value,
                key_padding_mask=key_padding_mask,
                spatial_shapes=spatial_shapes,
                level_start_index=level_start_index,
                valid_ratios=valid_ratios,
                reference_points=reference_points_input,
                **kwargs)

            if reg_branches is not None:
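                # Iterative bounding box refinement: the regression branch of
                # this layer predicts an offset in unsigmoided (logit) space
                # on top of the current reference points; the refined points
                # are detached so the next layer does not back-propagate
                # through them.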
|
                tmp_reg_preds = reg_branches[layer_id](output)
                if reference_points.shape[-1] == 4:
                    new_reference_points = tmp_reg_preds + inverse_sigmoid(
                        reference_points)
                    new_reference_points = new_reference_points.sigmoid()
                else:
                    assert reference_points.shape[-1] == 2
                    new_reference_points = tmp_reg_preds
                    new_reference_points[..., :2] = tmp_reg_preds[
                        ..., :2] + inverse_sigmoid(reference_points)
                    new_reference_points = new_reference_points.sigmoid()
                reference_points = new_reference_points.detach()

            if self.return_intermediate:
                intermediate.append(output)
                intermediate_reference_points.append(reference_points)

        if self.return_intermediate:
            return torch.stack(intermediate), torch.stack(
                intermediate_reference_points)

        return output, reference_points


class DeformableDetrTransformerEncoderLayer(DetrTransformerEncoderLayer):
    """Encoder layer of Deformable DETR."""

    def _init_layers(self) -> None:
        """Initialize self_attn, ffn, and norms."""
        self.self_attn = MultiScaleDeformableAttention(**self.self_attn_cfg)
        self.embed_dims = self.self_attn.embed_dims
        self.ffn = FFN(**self.ffn_cfg)
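        # Multi-scale deformable attention replaces the dense self-attention
        # of the vanilla DETR encoder layer. The two norms below follow the
        # self-attention and the FFN respectively (see the base class
        # forward).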
|
        norms_list = [
            build_norm_layer(self.norm_cfg, self.embed_dims)[1]
            for _ in range(2)
        ]
        self.norms = ModuleList(norms_list)


class DeformableDetrTransformerDecoderLayer(DetrTransformerDecoderLayer):
    """Decoder layer of Deformable DETR."""

    def _init_layers(self) -> None:
        """Initialize self_attn, cross-attn, ffn, and norms."""
        self.self_attn = MultiheadAttention(**self.self_attn_cfg)
        self.cross_attn = MultiScaleDeformableAttention(**self.cross_attn_cfg)
        self.embed_dims = self.self_attn.embed_dims
        self.ffn = FFN(**self.ffn_cfg)
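        # Queries interact through standard multi-head self-attention, while
        # the cross-attention over the multi-level encoder memory uses
        # multi-scale deformable attention. The three norms below follow the
        # self-attention, the cross-attention and the FFN respectively (see
        # the base class forward).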
|
        norms_list = [
            build_norm_layer(self.norm_cfg, self.embed_dims)[1]
            for _ in range(3)
        ]
        self.norms = ModuleList(norms_list)
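# Usage sketch (illustrative values, not the library defaults): these modules
# are normally built from a detector config, roughly of the form
#
#   encoder = DeformableDetrTransformerEncoder(
#       num_layers=6,
#       layer_cfg=dict(
#           self_attn_cfg=dict(embed_dims=256, num_levels=4,
#                              batch_first=True),
#           ffn_cfg=dict(embed_dims=256, feedforward_channels=1024)))
#   decoder = DeformableDetrTransformerDecoder(
#       num_layers=6,
#       return_intermediate=True,
#       layer_cfg=dict(
#           self_attn_cfg=dict(embed_dims=256, num_heads=8, batch_first=True),
#           cross_attn_cfg=dict(embed_dims=256, num_levels=4,
#                               batch_first=True),
#           ffn_cfg=dict(embed_dims=256, feedforward_channels=1024)))
#
# where `layer_cfg`, `self_attn_cfg`, `cross_attn_cfg`, `ffn_cfg` and
# `norm_cfg` are consumed by the `_init_layers` methods above.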
|
|