# TTP/mmdet/models/layers/transformer/grounding_dino_layers.py
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
from mmcv.cnn import build_norm_layer
from mmcv.cnn.bricks.transformer import FFN, MultiheadAttention
from mmcv.ops import MultiScaleDeformableAttention
from mmengine.model import ModuleList
from torch import Tensor
from mmdet.models.utils.vlfuse_helper import SingleScaleBiAttentionBlock
from mmdet.utils import ConfigType, OptConfigType
from .deformable_detr_layers import (DeformableDetrTransformerDecoderLayer,
DeformableDetrTransformerEncoder,
DeformableDetrTransformerEncoderLayer)
from .detr_layers import DetrTransformerEncoderLayer
from .dino_layers import DinoTransformerDecoder
from .utils import MLP, get_text_sine_pos_embed
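
# ``checkpoint_wrapper`` (from the optional ``fairscale`` package) lets
# ``GroundingDinoTransformerEncoder`` apply activation checkpointing to its
# first ``num_cp`` encoder/fusion layers, trading extra compute for GPU memory.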
try:
from fairscale.nn.checkpoint import checkpoint_wrapper
except Exception:
checkpoint_wrapper = None


class GroundingDinoTransformerDecoderLayer(
DeformableDetrTransformerDecoderLayer):

    def __init__(self,
cross_attn_text_cfg: OptConfigType = dict(
embed_dims=256,
num_heads=8,
dropout=0.0,
batch_first=True),
**kwargs) -> None:
"""Decoder layer of Deformable DETR."""
self.cross_attn_text_cfg = cross_attn_text_cfg
if 'batch_first' not in self.cross_attn_text_cfg:
self.cross_attn_text_cfg['batch_first'] = True
super().__init__(**kwargs)

    def _init_layers(self) -> None:
"""Initialize self_attn, cross-attn, ffn, and norms."""
self.self_attn = MultiheadAttention(**self.self_attn_cfg)
self.cross_attn_text = MultiheadAttention(**self.cross_attn_text_cfg)
self.cross_attn = MultiScaleDeformableAttention(**self.cross_attn_cfg)
self.embed_dims = self.self_attn.embed_dims
self.ffn = FFN(**self.ffn_cfg)
norms_list = [
build_norm_layer(self.norm_cfg, self.embed_dims)[1]
for _ in range(4)
]
self.norms = ModuleList(norms_list)

    def forward(self,
query: Tensor,
key: Tensor = None,
value: Tensor = None,
query_pos: Tensor = None,
key_pos: Tensor = None,
self_attn_mask: Tensor = None,
cross_attn_mask: Tensor = None,
key_padding_mask: Tensor = None,
memory_text: Tensor = None,
text_attention_mask: Tensor = None,
**kwargs) -> Tensor:
"""Implements decoder layer in Grounding DINO transformer.
Args:
query (Tensor): The input query, has shape (bs, num_queries, dim).
key (Tensor, optional): The input key, has shape (bs, num_keys,
dim). If `None`, the `query` will be used. Defaults to `None`.
value (Tensor, optional): The input value, has the same shape as
`key`, as in `nn.MultiheadAttention.forward`. If `None`, the
`key` will be used. Defaults to `None`.
query_pos (Tensor, optional): The positional encoding for `query`,
has the same shape as `query`. If not `None`, it will be added
to `query` before forward function. Defaults to `None`.
key_pos (Tensor, optional): The positional encoding for `key`, has
the same shape as `key`. If not `None`, it will be added to
`key` before forward function. If None, and `query_pos` has the
same shape as `key`, then `query_pos` will be used for
`key_pos`. Defaults to None.
self_attn_mask (Tensor, optional): ByteTensor mask, has shape
(num_queries, num_keys), as in `nn.MultiheadAttention.forward`.
Defaults to None.
cross_attn_mask (Tensor, optional): ByteTensor mask, has shape
(num_queries, num_keys), as in `nn.MultiheadAttention.forward`.
Defaults to None.
key_padding_mask (Tensor, optional): The `key_padding_mask` of
`self_attn` input. ByteTensor, has shape (bs, num_value).
Defaults to None.
memory_text (Tensor): Memory text. It has shape (bs, len_text,
text_embed_dims).
text_attention_mask (Tensor): Text token mask. It has shape (bs,
len_text).
Returns:
Tensor: forwarded results, has shape (bs, num_queries, dim).
"""
# self attention
query = self.self_attn(
query=query,
key=query,
value=query,
query_pos=query_pos,
key_pos=query_pos,
attn_mask=self_attn_mask,
**kwargs)
query = self.norms[0](query)
# cross attention between query and text
query = self.cross_attn_text(
query=query,
query_pos=query_pos,
key=memory_text,
value=memory_text,
key_padding_mask=text_attention_mask)
query = self.norms[1](query)
# cross attention between query and image
query = self.cross_attn(
query=query,
key=key,
value=value,
query_pos=query_pos,
key_pos=key_pos,
attn_mask=cross_attn_mask,
key_padding_mask=key_padding_mask,
**kwargs)
query = self.norms[2](query)
query = self.ffn(query)
query = self.norms[3](query)
return query
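

# Illustrative sketch, not part of the library: it exercises the attention
# ordering implemented above (self-attention -> text cross-attention ->
# deformable image cross-attention -> FFN, each followed by a LayerNorm).
# The helper name ``_example_decoder_layer_forward`` and every cfg value and
# shape below are assumptions for demonstration, not values taken from a
# released Grounding DINO config.
def _example_decoder_layer_forward() -> Tensor:
    """Run a single decoder layer on random inputs (CPU-friendly sketch)."""
    bs, num_queries, embed_dims, len_text = 2, 10, 256, 6
    # two flattened feature levels: 20x20 and 10x10
    spatial_shapes = torch.tensor([[20, 20], [10, 10]])
    level_start_index = torch.tensor([0, 20 * 20])
    num_value = int(spatial_shapes.prod(dim=-1).sum())  # 400 + 100 = 500
    layer = GroundingDinoTransformerDecoderLayer(
        self_attn_cfg=dict(
            embed_dims=embed_dims, num_heads=8, dropout=0.0, batch_first=True),
        cross_attn_cfg=dict(
            embed_dims=embed_dims, num_levels=2, batch_first=True),
        ffn_cfg=dict(
            embed_dims=embed_dims, feedforward_channels=1024, ffn_drop=0.0))
    query = torch.rand(bs, num_queries, embed_dims)
    memory = torch.rand(bs, num_value, embed_dims)  # flattened image features
    memory_text = torch.rand(bs, len_text, embed_dims)  # encoded text tokens
    # normalized (cx, cy) reference points, one per query and feature level
    reference_points = torch.rand(bs, num_queries, 2, 2)
    return layer(
        query=query,
        value=memory,
        query_pos=torch.rand(bs, num_queries, embed_dims),
        key_padding_mask=torch.zeros(bs, num_value, dtype=torch.bool),
        memory_text=memory_text,
        text_attention_mask=torch.zeros(bs, len_text, dtype=torch.bool),
        reference_points=reference_points,
        spatial_shapes=spatial_shapes,
        level_start_index=level_start_index)

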
class GroundingDinoTransformerEncoder(DeformableDetrTransformerEncoder):

    def __init__(self, text_layer_cfg: ConfigType,
                 fusion_layer_cfg: ConfigType, **kwargs) -> None:
        """Text-image feature fusion encoder of Grounding DINO.

        Args:
            text_layer_cfg (:obj:`ConfigDict` or dict): Config for the text
                transformer encoder layers.
            fusion_layer_cfg (:obj:`ConfigDict` or dict): Config for the
                image-text bi-attention fusion layers.
        """
        self.text_layer_cfg = text_layer_cfg
        self.fusion_layer_cfg = fusion_layer_cfg
        super().__init__(**kwargs)

    def _init_layers(self) -> None:
"""Initialize encoder layers."""
self.layers = ModuleList([
DeformableDetrTransformerEncoderLayer(**self.layer_cfg)
for _ in range(self.num_layers)
])
self.text_layers = ModuleList([
DetrTransformerEncoderLayer(**self.text_layer_cfg)
for _ in range(self.num_layers)
])
self.fusion_layers = ModuleList([
SingleScaleBiAttentionBlock(**self.fusion_layer_cfg)
for _ in range(self.num_layers)
])
self.embed_dims = self.layers[0].embed_dims
if self.num_cp > 0:
if checkpoint_wrapper is None:
                raise NotImplementedError(
                    'If you want to reduce GPU memory usage, '
                    'please install fairscale by executing the '
                    'following command: pip install fairscale.')
for i in range(self.num_cp):
self.layers[i] = checkpoint_wrapper(self.layers[i])
self.fusion_layers[i] = checkpoint_wrapper(
self.fusion_layers[i])

    def forward(self,
query: Tensor,
query_pos: Tensor,
key_padding_mask: Tensor,
spatial_shapes: Tensor,
level_start_index: Tensor,
valid_ratios: Tensor,
memory_text: Tensor = None,
text_attention_mask: Tensor = None,
pos_text: Tensor = None,
text_self_attention_masks: Tensor = None,
position_ids: Tensor = None):
"""Forward function of Transformer encoder.
Args:
query (Tensor): The input query, has shape (bs, num_queries, dim).
query_pos (Tensor): The positional encoding for query, has shape
(bs, num_queries, dim).
key_padding_mask (Tensor): The `key_padding_mask` of `self_attn`
input. ByteTensor, has shape (bs, num_queries).
spatial_shapes (Tensor): Spatial shapes of features in all levels,
has shape (num_levels, 2), last dimension represents (h, w).
level_start_index (Tensor): The start index of each level.
A tensor has shape (num_levels, ) and can be represented
as [0, h_0*w_0, h_0*w_0+h_1*w_1, ...].
valid_ratios (Tensor): The ratios of the valid width and the valid
height relative to the width and the height of features in all
levels, has shape (bs, num_levels, 2).
memory_text (Tensor, optional): Memory text. It has shape (bs,
len_text, text_embed_dims).
text_attention_mask (Tensor, optional): Text token mask. It has
                shape (bs, len_text).
pos_text (Tensor, optional): The positional encoding for text.
Defaults to None.
text_self_attention_masks (Tensor, optional): Text self attention
mask. Defaults to None.
position_ids (Tensor, optional): Text position ids.
Defaults to None.
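
        Returns:
            Tuple[Tensor, Tensor]: The processed image features ``output``
            with shape (bs, num_queries, dim) and the updated text features
            ``memory_text`` with shape (bs, len_text, text_embed_dims).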
"""
output = query
reference_points = self.get_encoder_reference_points(
spatial_shapes, valid_ratios, device=query.device)
if self.text_layers:
# generate pos_text
bs, n_text, _ = memory_text.shape
if pos_text is None and position_ids is None:
pos_text = (
torch.arange(n_text,
device=memory_text.device).float().unsqueeze(
0).unsqueeze(-1).repeat(bs, 1, 1))
pos_text = get_text_sine_pos_embed(
pos_text, num_pos_feats=256, exchange_xy=False)
if position_ids is not None:
pos_text = get_text_sine_pos_embed(
position_ids[..., None],
num_pos_feats=256,
exchange_xy=False)
# main process
for layer_id, layer in enumerate(self.layers):
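            # fuse image and text features with bi-directional cross-attention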
if self.fusion_layers:
output, memory_text = self.fusion_layers[layer_id](
visual_feature=output,
lang_feature=memory_text,
attention_mask_v=key_padding_mask,
attention_mask_l=text_attention_mask,
)
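            # refine the text features with a text-only encoder layer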
if self.text_layers:
text_num_heads = self.text_layers[
layer_id].self_attn_cfg.num_heads
memory_text = self.text_layers[layer_id](
query=memory_text,
query_pos=(pos_text if pos_text is not None else None),
attn_mask=~text_self_attention_masks.repeat(
text_num_heads, 1, 1), # note we use ~ for mask here
key_padding_mask=None,
)
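            # refine the image features with a multi-scale deformable layer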
output = layer(
query=output,
query_pos=query_pos,
reference_points=reference_points,
spatial_shapes=spatial_shapes,
level_start_index=level_start_index,
key_padding_mask=key_padding_mask)
return output, memory_text
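

# Construction-only sketch, not part of the library: ``_init_layers`` above
# builds three parallel stacks per encoder layer (deformable visual layers,
# vanilla text layers, and image-text bi-attention fusion layers), and
# ``forward`` interleaves them as fusion -> text -> visual at every layer.
# The helper name and all cfg values are assumptions patterned after typical
# Grounding DINO settings, not a released config.
def _example_encoder_construction() -> GroundingDinoTransformerEncoder:
    """Build a one-layer Grounding DINO encoder (no forward pass)."""
    from mmengine.config import ConfigDict  # real configs are ConfigDicts
    return GroundingDinoTransformerEncoder(
        num_layers=1,
        # visual branch: multi-scale deformable self-attention + FFN
        layer_cfg=ConfigDict(
            self_attn_cfg=dict(embed_dims=256, num_levels=4, dropout=0.0),
            ffn_cfg=dict(
                embed_dims=256, feedforward_channels=2048, ffn_drop=0.0)),
        # text branch: vanilla self-attention + FFN; ``forward`` reads
        # ``self_attn_cfg.num_heads`` by attribute, hence the ConfigDict
        text_layer_cfg=ConfigDict(
            self_attn_cfg=dict(embed_dims=256, num_heads=4, dropout=0.0),
            ffn_cfg=dict(
                embed_dims=256, feedforward_channels=1024, ffn_drop=0.0)),
        # bi-directional image-text cross-attention (SingleScaleBiAttentionBlock)
        fusion_layer_cfg=ConfigDict(
            v_dim=256, l_dim=256, embed_dim=1024, num_heads=4,
            init_values=1e-4))

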
class GroundingDinoTransformerDecoder(DinoTransformerDecoder):

    def _init_layers(self) -> None:
"""Initialize decoder layers."""
self.layers = ModuleList([
GroundingDinoTransformerDecoderLayer(**self.layer_cfg)
for _ in range(self.num_layers)
])
self.embed_dims = self.layers[0].embed_dims
if self.post_norm_cfg is not None:
            raise ValueError(
                f'There is no post_norm in {self._get_name()}')
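        # As in DINO, ``ref_point_head`` maps the sine encoding of the 4-d
        # reference boxes (2 * embed_dims features) into per-layer positional
        # queries inside ``DinoTransformerDecoder.forward``.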
self.ref_point_head = MLP(self.embed_dims * 2, self.embed_dims,
self.embed_dims, 2)
self.norm = nn.LayerNorm(self.embed_dims)
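

# Construction-only sketch, not part of the library: the decoder requires
# ``post_norm_cfg=None`` (it creates its own ``nn.LayerNorm`` and
# ``ref_point_head`` in ``_init_layers``), and ``layer_cfg`` must match
# ``GroundingDinoTransformerDecoderLayer``. The helper name and cfg values
# below are assumptions for demonstration, not a released config.
def _example_decoder_construction() -> GroundingDinoTransformerDecoder:
    """Build a one-layer Grounding DINO decoder (no forward pass)."""
    return GroundingDinoTransformerDecoder(
        num_layers=1,
        return_intermediate=True,
        post_norm_cfg=None,  # any other value raises the ValueError above
        layer_cfg=dict(
            self_attn_cfg=dict(
                embed_dims=256, num_heads=8, dropout=0.0, batch_first=True),
            # image cross-attention: multi-scale deformable attention
            cross_attn_cfg=dict(embed_dims=256, num_levels=4, dropout=0.0),
            # text cross-attention: plain multi-head attention
            cross_attn_text_cfg=dict(
                embed_dims=256, num_heads=8, dropout=0.0, batch_first=True),
            ffn_cfg=dict(
                embed_dims=256, feedforward_channels=2048, ffn_drop=0.0)))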