from typing import List

import torch
import torch.nn as nn
from mmcv.cnn import build_norm_layer
from mmcv.cnn.bricks.transformer import FFN
from mmengine.model import ModuleList
from torch import Tensor

from .detr_layers import (DetrTransformerDecoder, DetrTransformerDecoderLayer,
                          DetrTransformerEncoder, DetrTransformerEncoderLayer)
from .utils import (MLP, ConditionalAttention, coordinate_to_encoding,
                    inverse_sigmoid)

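# DAB-DETR (Liu et al., ICLR 2022) formulates decoder queries as 4D anchor
# boxes (cx, cy, w, h) that are sine-encoded into positional queries and
# refined layer by layer inside the decoder.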
class DABDetrTransformerDecoderLayer(DetrTransformerDecoderLayer):
    """Implements decoder layer in DAB-DETR transformer."""

    def _init_layers(self):
        """Initialize self-attention, cross-attention, FFN, normalization and
        others."""
        self.self_attn = ConditionalAttention(**self.self_attn_cfg)
        self.cross_attn = ConditionalAttention(**self.cross_attn_cfg)
        self.embed_dims = self.self_attn.embed_dims
        self.ffn = FFN(**self.ffn_cfg)
        norms_list = [
            build_norm_layer(self.norm_cfg, self.embed_dims)[1]
            for _ in range(3)
        ]
        self.norms = ModuleList(norms_list)
        self.keep_query_pos = self.cross_attn.keep_query_pos

    def forward(self,
                query: Tensor,
                key: Tensor,
                query_pos: Tensor,
                key_pos: Tensor,
                ref_sine_embed: Tensor = None,
                self_attn_masks: Tensor = None,
                cross_attn_masks: Tensor = None,
                key_padding_mask: Tensor = None,
                is_first: bool = False,
                **kwargs) -> Tensor:
        """
        Args:
            query (Tensor): The input query with shape
                [bs, num_queries, dim].
            key (Tensor): The key tensor with shape [bs, num_keys, dim].
            query_pos (Tensor): The positional encoding for `query` in self
                attention, with the same shape as `query`.
            key_pos (Tensor): The positional encoding for `key`, with the
                same shape as `key`.
            ref_sine_embed (Tensor): The positional encoding for `query` in
                cross attention, with the same shape as `query`.
                Defaults to None.
            self_attn_masks (Tensor): ByteTensor mask with shape
                [num_queries, num_keys], same as in
                `nn.MultiheadAttention.forward`. Defaults to None.
            cross_attn_masks (Tensor): ByteTensor mask with shape
                [num_queries, num_keys], same as in
                `nn.MultiheadAttention.forward`. Defaults to None.
            key_padding_mask (Tensor): ByteTensor with shape [bs, num_keys].
                Defaults to None.
            is_first (bool): An indicator of whether the current layer is
                the first decoder layer. Defaults to False.

        Returns:
            Tensor: Forwarded results with shape [bs, num_queries, dim].
        """
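        # Self-attention among queries: the anchor-derived positional
        # embedding `query_pos` conditions both query and key.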
        query = self.self_attn(
            query=query,
            key=query,
            query_pos=query_pos,
            key_pos=query_pos,
            attn_mask=self_attn_masks,
            **kwargs)
        query = self.norms[0](query)
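        # Conditional cross-attention to the encoder memory; `ref_sine_embed`
        # carries the sinusoidal embedding of the current anchor box.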
        query = self.cross_attn(
            query=query,
            key=key,
            query_pos=query_pos,
            key_pos=key_pos,
            ref_sine_embed=ref_sine_embed,
            attn_mask=cross_attn_masks,
            key_padding_mask=key_padding_mask,
            is_first=is_first,
            **kwargs)
        query = self.norms[1](query)
        query = self.ffn(query)
        query = self.norms[2](query)

        return query


class DABDetrTransformerDecoder(DetrTransformerDecoder):
    """Decoder of DAB-DETR.

    Args:
        query_dim (int): The last dimension of the query pos,
            4 for anchor format, 2 for point format.
            Defaults to 4.
        query_scale_type (str): Type of transformation applied
            to the content query. Defaults to `cond_elewise`.
        with_modulated_hw_attn (bool): Whether to inject height and width
            information during conditional cross-attention. Defaults to True.
    """

    def __init__(self,
                 *args,
                 query_dim: int = 4,
                 query_scale_type: str = 'cond_elewise',
                 with_modulated_hw_attn: bool = True,
                 **kwargs):
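        # These attributes are consumed by `_init_layers()`, which is invoked
        # inside `super().__init__()`, so they must be set beforehand.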
        self.query_dim = query_dim
        self.query_scale_type = query_scale_type
        self.with_modulated_hw_attn = with_modulated_hw_attn

        super().__init__(*args, **kwargs)

    def _init_layers(self):
        """Initialize decoder layers and other layers."""
        assert self.query_dim in [2, 4], \
            'dab-detr only supports anchor prior or reference point prior'
        assert self.query_scale_type in [
            'cond_elewise', 'cond_scalar', 'fix_elewise'
        ]

        self.layers = ModuleList([
            DABDetrTransformerDecoderLayer(**self.layer_cfg)
            for _ in range(self.num_layers)
        ])

        embed_dims = self.layers[0].embed_dims
        self.embed_dims = embed_dims

        self.post_norm = build_norm_layer(self.post_norm_cfg, embed_dims)[1]
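        # `query_scale` predicts the transformation applied to the positional
        # query in each decoder layer: an element-wise vector, a scalar, or a
        # fixed learned vector per layer.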
        if self.query_scale_type == 'cond_elewise':
            self.query_scale = MLP(embed_dims, embed_dims, embed_dims, 2)
        elif self.query_scale_type == 'cond_scalar':
            self.query_scale = MLP(embed_dims, embed_dims, 1, 2)
        elif self.query_scale_type == 'fix_elewise':
            self.query_scale = nn.Embedding(self.num_layers, embed_dims)
        else:
            raise NotImplementedError('Unknown query_scale_type: {}'.format(
                self.query_scale_type))
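        # Projects the sinusoidal anchor embedding (query_dim // 2 *
        # embed_dims channels) to the positional query of the decoder.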
        self.ref_point_head = MLP(self.query_dim // 2 * embed_dims,
                                  embed_dims, embed_dims, 2)

        if self.with_modulated_hw_attn and self.query_dim == 4:
            self.ref_anchor_head = MLP(embed_dims, embed_dims, 2, 2)
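        # When `keep_query_pos` is False, only the first decoder layer adds
        # the positional query inside cross-attention, so the query-position
        # projections of the later layers are unused and removed here.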
        self.keep_query_pos = self.layers[0].keep_query_pos
        if not self.keep_query_pos:
            for layer_id in range(self.num_layers - 1):
                self.layers[layer_id + 1].cross_attn.qpos_proj = None

    def forward(self,
                query: Tensor,
                key: Tensor,
                query_pos: Tensor,
                key_pos: Tensor,
                reg_branches: nn.Module,
                key_padding_mask: Tensor = None,
                **kwargs) -> List[Tensor]:
        """Forward function of decoder.

        Args:
            query (Tensor): The input query with shape (bs, num_queries, dim).
            key (Tensor): The input key with shape (bs, num_keys, dim).
            query_pos (Tensor): The positional encoding for `query`, with the
                same shape as `query`.
            key_pos (Tensor): The positional encoding for `key`, with the
                same shape as `key`.
            reg_branches (nn.Module): The regression branch for dynamically
                updating references in each layer.
            key_padding_mask (Tensor): ByteTensor with shape (bs, num_keys).
                Defaults to `None`.

        Returns:
            List[Tensor]: A list of two tensors: the forwarded hidden states
            with shape (num_decoder_layers, bs, num_queries, dim) if
            `return_intermediate` is True, otherwise (1, bs, num_queries,
            dim), and the references with shape (num_decoder_layers, bs,
            num_queries, 2/4).
        """
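        # `query_pos` holds unsigmoided reference anchors; squash them to
        # [0, 1] and refine them layer by layer (dynamic anchor boxes).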
        output = query
        unsigmoid_references = query_pos

        reference_points = unsigmoid_references.sigmoid()
        intermediate_reference_points = [reference_points]

        intermediate = []
        for layer_id, layer in enumerate(self.layers):
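            # Sine-encode the current anchor and project it to the positional
            # query fed to this layer's self-attention.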
            obj_center = reference_points[..., :self.query_dim]
            ref_sine_embed = coordinate_to_encoding(
                coord_tensor=obj_center, num_feats=self.embed_dims // 2)
            query_pos = self.ref_point_head(ref_sine_embed)
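            # Conditional spatial query: scale the sine embedding with a
            # transformation predicted from the content query (identity at
            # the first layer for the conditional variants).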
            if self.query_scale_type != 'fix_elewise':
                if layer_id == 0:
                    pos_transformation = 1
                else:
                    pos_transformation = self.query_scale(output)
            else:
                pos_transformation = self.query_scale.weight[layer_id]

            ref_sine_embed = ref_sine_embed[
                ..., :self.embed_dims] * pos_transformation
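            # Modulated HW attention: rescale the two halves of the sine
            # embedding by the ratios between the predicted reference
            # width/height and the anchor width/height.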
            if self.with_modulated_hw_attn:
                assert obj_center.size(-1) == 4
                ref_hw = self.ref_anchor_head(output).sigmoid()
                ref_sine_embed[..., self.embed_dims // 2:] *= \
                    (ref_hw[..., 0] / obj_center[..., 2]).unsqueeze(-1)
                ref_sine_embed[..., :self.embed_dims // 2] *= \
                    (ref_hw[..., 1] / obj_center[..., 3]).unsqueeze(-1)

            output = layer(
                output,
                key,
                query_pos=query_pos,
                ref_sine_embed=ref_sine_embed,
                key_pos=key_pos,
                key_padding_mask=key_padding_mask,
                is_first=(layer_id == 0),
                **kwargs)
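            # Iterative anchor update: predict offsets from the content
            # query, add them to the unsigmoided anchors, and detach the
            # refined anchors before feeding them to the next layer.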
            tmp_reg_preds = reg_branches(output)
            tmp_reg_preds[..., :self.query_dim] += inverse_sigmoid(
                reference_points)
            new_reference_points = tmp_reg_preds[
                ..., :self.query_dim].sigmoid()
            if layer_id != self.num_layers - 1:
                intermediate_reference_points.append(new_reference_points)
            reference_points = new_reference_points.detach()

            if self.return_intermediate:
                intermediate.append(self.post_norm(output))

        output = self.post_norm(output)

        if self.return_intermediate:
            return [
                torch.stack(intermediate),
                torch.stack(intermediate_reference_points),
            ]
        else:
            return [
                output.unsqueeze(0),
                torch.stack(intermediate_reference_points)
            ]


class DABDetrTransformerEncoder(DetrTransformerEncoder):
    """Encoder of DAB-DETR."""

    def _init_layers(self):
        """Initialize encoder layers."""
        self.layers = ModuleList([
            DetrTransformerEncoderLayer(**self.layer_cfg)
            for _ in range(self.num_layers)
        ])
        embed_dims = self.layers[0].embed_dims
        self.embed_dims = embed_dims
        self.query_scale = MLP(embed_dims, embed_dims, embed_dims, 2)

    def forward(self, query: Tensor, query_pos: Tensor,
                key_padding_mask: Tensor, **kwargs):
        """Forward function of encoder.

        Args:
            query (Tensor): Input queries of encoder, has shape
                (bs, num_queries, dim).
            query_pos (Tensor): The positional embeddings of the queries,
                has shape (bs, num_feat_points, dim).
            key_padding_mask (Tensor): ByteTensor, the key padding mask
                of the queries, has shape (bs, num_feat_points).

        Returns:
            Tensor: Forwarded results with shape (bs, num_queries, dim).
        """
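        # Before each layer, rescale the positional encodings with a
        # content-conditioned factor predicted by `query_scale`.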
        for layer in self.layers:
            pos_scales = self.query_scale(query)
            query = layer(
                query,
                query_pos=query_pos * pos_scales,
                key_padding_mask=key_padding_mask,
                **kwargs)

        return query