from mmcv.cnn import build_norm_layer
from mmengine.model import ModuleList
from torch import Tensor

from .deformable_detr_layers import DeformableDetrTransformerEncoder
from .detr_layers import DetrTransformerDecoder, DetrTransformerDecoderLayer


class Mask2FormerTransformerEncoder(DeformableDetrTransformerEncoder):
    """Encoder in PixelDecoder of Mask2Former."""

    def forward(self, query: Tensor, query_pos: Tensor,
                key_padding_mask: Tensor, spatial_shapes: Tensor,
                level_start_index: Tensor, valid_ratios: Tensor,
                reference_points: Tensor, **kwargs) -> Tensor:
        """Forward function of Transformer encoder.

        Args:
            query (Tensor): The input query, has shape (bs, num_queries, dim).
            query_pos (Tensor): The positional encoding for query, has shape
                (bs, num_queries, dim). It is added to `query` inside each
                encoder layer.
            key_padding_mask (Tensor): The `key_padding_mask` of `self_attn`
                input. ByteTensor, has shape (bs, num_queries).
            spatial_shapes (Tensor): Spatial shapes of features in all levels,
                has shape (num_levels, 2), last dimension represents (h, w).
            level_start_index (Tensor): The start index of each level.
                A tensor has shape (num_levels, ) and can be represented
                as [0, h_0*w_0, h_0*w_0+h_1*w_1, ...].
            valid_ratios (Tensor): The ratios of the valid width and the valid
                height relative to the width and the height of features in all
                levels, has shape (bs, num_levels, 2).
            reference_points (Tensor): The initial reference, has shape
                (bs, num_queries, 2) with the last dimension arranged
                as (cx, cy).

        Returns:
            Tensor: Output queries of Transformer encoder, which is also
            called 'encoder output embeddings' or 'memory', has shape
            (bs, num_queries, dim).
        """
        for layer in self.layers:
            query = layer(
                query=query,
                query_pos=query_pos,
                key_padding_mask=key_padding_mask,
                spatial_shapes=spatial_shapes,
                level_start_index=level_start_index,
                valid_ratios=valid_ratios,
                reference_points=reference_points,
                **kwargs)
        return query
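

# A minimal usage sketch (not part of the original module): it assumes the
# `num_layers`/`layer_cfg` construction API of the DeformableDETR encoder and
# small illustrative dimensions. The call mirrors the docstring above, except
# that the reference points are expanded to one set per feature level, the
# shape consumed by `MultiScaleDeformableAttention`.
def _encoder_usage_sketch() -> Tensor:
    import torch

    encoder = Mask2FormerTransformerEncoder(
        num_layers=1,
        layer_cfg=dict(
            self_attn_cfg=dict(
                embed_dims=32,
                num_heads=4,
                num_levels=1,
                num_points=4,
                batch_first=True),
            ffn_cfg=dict(embed_dims=32, feedforward_channels=64)))
    bs, h, w, dim = 2, 8, 8, 32
    num_queries = h * w  # flattened single-level feature map
    query = torch.randn(bs, num_queries, dim)
    query_pos = torch.randn(bs, num_queries, dim)
    key_padding_mask = torch.zeros(bs, num_queries, dtype=torch.bool)
    spatial_shapes = torch.tensor([[h, w]])
    level_start_index = torch.tensor([0])
    valid_ratios = torch.ones(bs, 1, 2)
    # Normalized (cx, cy) reference points, one set per feature level.
    reference_points = torch.rand(bs, num_queries, 1, 2)
    return encoder(
        query=query,
        query_pos=query_pos,
        key_padding_mask=key_padding_mask,
        spatial_shapes=spatial_shapes,
        level_start_index=level_start_index,
        valid_ratios=valid_ratios,
        reference_points=reference_points)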


class Mask2FormerTransformerDecoder(DetrTransformerDecoder):
    """Decoder of Mask2Former."""

    def _init_layers(self) -> None:
        """Initialize decoder layers."""
        self.layers = ModuleList([
            Mask2FormerTransformerDecoderLayer(**self.layer_cfg)
            for _ in range(self.num_layers)
        ])
        self.embed_dims = self.layers[0].embed_dims
        self.post_norm = build_norm_layer(self.post_norm_cfg,
                                          self.embed_dims)[1]
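

# A minimal construction sketch (not part of the original module), assuming
# the `DetrTransformerDecoder` base-class arguments (`num_layers`,
# `layer_cfg`, `return_intermediate`) and the decoder layer defaults reached
# through an empty `layer_cfg`; all values are illustrative only.
def _decoder_init_sketch() -> Mask2FormerTransformerDecoder:
    decoder = Mask2FormerTransformerDecoder(
        num_layers=3,
        layer_cfg=dict(),
        return_intermediate=True)
    # `_init_layers` above exposes the layer width and adds a final post-norm.
    assert decoder.embed_dims == decoder.layers[0].embed_dims
    return decoder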


class Mask2FormerTransformerDecoderLayer(DetrTransformerDecoderLayer):
    """Implements decoder layer in Mask2Former transformer."""

    def forward(self,
                query: Tensor,
                key: Tensor = None,
                value: Tensor = None,
                query_pos: Tensor = None,
                key_pos: Tensor = None,
                self_attn_mask: Tensor = None,
                cross_attn_mask: Tensor = None,
                key_padding_mask: Tensor = None,
                **kwargs) -> Tensor:
        """
        Args:
            query (Tensor): The input query, has shape (bs, num_queries, dim).
            key (Tensor, optional): The input key, has shape (bs, num_keys,
                dim). If `None`, the `query` will be used. Defaults to `None`.
            value (Tensor, optional): The input value, has the same shape as
                `key`, as in `nn.MultiheadAttention.forward`. If `None`, the
                `key` will be used. Defaults to `None`.
            query_pos (Tensor, optional): The positional encoding for `query`,
                has the same shape as `query`. If not `None`, it will be added
                to `query` before forward function. Defaults to `None`.
            key_pos (Tensor, optional): The positional encoding for `key`, has
                the same shape as `key`. If not `None`, it will be added to
                `key` before forward function. If `None`, and `query_pos` has
                the same shape as `key`, then `query_pos` will be used for
                `key_pos`. Defaults to `None`.
            self_attn_mask (Tensor, optional): ByteTensor mask, has shape
                (num_queries, num_keys), as in `nn.MultiheadAttention.forward`.
                Defaults to `None`.
            cross_attn_mask (Tensor, optional): ByteTensor mask, has shape
                (num_queries, num_keys), as in `nn.MultiheadAttention.forward`.
                Defaults to `None`.
            key_padding_mask (Tensor, optional): The `key_padding_mask` of
                `cross_attn` input. ByteTensor, has shape (bs, num_value).
                Defaults to `None`.

        Returns:
            Tensor: forwarded results, has shape (bs, num_queries, dim).
        """
        query = self.cross_attn(
            query=query,
            key=key,
            value=value,
            query_pos=query_pos,
            key_pos=key_pos,
            attn_mask=cross_attn_mask,
            key_padding_mask=key_padding_mask,
            **kwargs)
        query = self.norms[0](query)
        query = self.self_attn(
            query=query,
            key=query,
            value=query,
            query_pos=query_pos,
            key_pos=query_pos,
            attn_mask=self_attn_mask,
            **kwargs)
        query = self.norms[1](query)
        query = self.ffn(query)
        query = self.norms[2](query)
        return query
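

# A minimal forward sketch (not part of the original module), assuming the
# decoder layer's default configuration (256-dim attention and FFN) and dummy
# tensors shaped as in the docstring above. The all-False `cross_attn_mask`
# stands in for the per-query attention mask predicted by Mask2Former.
def _decoder_layer_forward_sketch() -> Tensor:
    import torch

    layer = Mask2FormerTransformerDecoderLayer()
    bs, num_queries, num_keys, dim = 2, 10, 64, 256
    query = torch.randn(bs, num_queries, dim)  # object queries
    memory = torch.randn(bs, num_keys, dim)  # flattened image features
    query_pos = torch.randn(bs, num_queries, dim)
    key_pos = torch.randn(bs, num_keys, dim)
    # Boolean mask: a True entry means the key is *not* attended to.
    cross_attn_mask = torch.zeros(num_queries, num_keys, dtype=torch.bool)
    return layer(
        query=query,
        key=memory,
        value=memory,
        query_pos=query_pos,
        key_pos=key_pos,
        cross_attn_mask=cross_attn_mask)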