# Copyright (c) OpenMMLab. All rights reserved. import torch import torch.nn as nn from mmcv.cnn import build_norm_layer from mmcv.cnn.bricks.transformer import FFN, MultiheadAttention from mmcv.ops import MultiScaleDeformableAttention from mmengine.model import ModuleList from torch import Tensor from mmdet.models.utils.vlfuse_helper import SingleScaleBiAttentionBlock from mmdet.utils import ConfigType, OptConfigType from .deformable_detr_layers import (DeformableDetrTransformerDecoderLayer, DeformableDetrTransformerEncoder, DeformableDetrTransformerEncoderLayer) from .detr_layers import DetrTransformerEncoderLayer from .dino_layers import DinoTransformerDecoder from .utils import MLP, get_text_sine_pos_embed try: from fairscale.nn.checkpoint import checkpoint_wrapper except Exception: checkpoint_wrapper = None class GroundingDinoTransformerDecoderLayer( DeformableDetrTransformerDecoderLayer): def __init__(self, cross_attn_text_cfg: OptConfigType = dict( embed_dims=256, num_heads=8, dropout=0.0, batch_first=True), **kwargs) -> None: """Decoder layer of Deformable DETR.""" self.cross_attn_text_cfg = cross_attn_text_cfg if 'batch_first' not in self.cross_attn_text_cfg: self.cross_attn_text_cfg['batch_first'] = True super().__init__(**kwargs) def _init_layers(self) -> None: """Initialize self_attn, cross-attn, ffn, and norms.""" self.self_attn = MultiheadAttention(**self.self_attn_cfg) self.cross_attn_text = MultiheadAttention(**self.cross_attn_text_cfg) self.cross_attn = MultiScaleDeformableAttention(**self.cross_attn_cfg) self.embed_dims = self.self_attn.embed_dims self.ffn = FFN(**self.ffn_cfg) norms_list = [ build_norm_layer(self.norm_cfg, self.embed_dims)[1] for _ in range(4) ] self.norms = ModuleList(norms_list) def forward(self, query: Tensor, key: Tensor = None, value: Tensor = None, query_pos: Tensor = None, key_pos: Tensor = None, self_attn_mask: Tensor = None, cross_attn_mask: Tensor = None, key_padding_mask: Tensor = None, memory_text: Tensor = None, text_attention_mask: Tensor = None, **kwargs) -> Tensor: """Implements decoder layer in Grounding DINO transformer. Args: query (Tensor): The input query, has shape (bs, num_queries, dim). key (Tensor, optional): The input key, has shape (bs, num_keys, dim). If `None`, the `query` will be used. Defaults to `None`. value (Tensor, optional): The input value, has the same shape as `key`, as in `nn.MultiheadAttention.forward`. If `None`, the `key` will be used. Defaults to `None`. query_pos (Tensor, optional): The positional encoding for `query`, has the same shape as `query`. If not `None`, it will be added to `query` before forward function. Defaults to `None`. key_pos (Tensor, optional): The positional encoding for `key`, has the same shape as `key`. If not `None`, it will be added to `key` before forward function. If None, and `query_pos` has the same shape as `key`, then `query_pos` will be used for `key_pos`. Defaults to None. self_attn_mask (Tensor, optional): ByteTensor mask, has shape (num_queries, num_keys), as in `nn.MultiheadAttention.forward`. Defaults to None. cross_attn_mask (Tensor, optional): ByteTensor mask, has shape (num_queries, num_keys), as in `nn.MultiheadAttention.forward`. Defaults to None. key_padding_mask (Tensor, optional): The `key_padding_mask` of `self_attn` input. ByteTensor, has shape (bs, num_value). Defaults to None. memory_text (Tensor): Memory text. It has shape (bs, len_text, text_embed_dims). text_attention_mask (Tensor): Text token mask. It has shape (bs, len_text). Returns: Tensor: forwarded results, has shape (bs, num_queries, dim). """ # self attention query = self.self_attn( query=query, key=query, value=query, query_pos=query_pos, key_pos=query_pos, attn_mask=self_attn_mask, **kwargs) query = self.norms[0](query) # cross attention between query and text query = self.cross_attn_text( query=query, query_pos=query_pos, key=memory_text, value=memory_text, key_padding_mask=text_attention_mask) query = self.norms[1](query) # cross attention between query and image query = self.cross_attn( query=query, key=key, value=value, query_pos=query_pos, key_pos=key_pos, attn_mask=cross_attn_mask, key_padding_mask=key_padding_mask, **kwargs) query = self.norms[2](query) query = self.ffn(query) query = self.norms[3](query) return query class GroundingDinoTransformerEncoder(DeformableDetrTransformerEncoder): def __init__(self, text_layer_cfg: ConfigType, fusion_layer_cfg: ConfigType, **kwargs) -> None: self.text_layer_cfg = text_layer_cfg self.fusion_layer_cfg = fusion_layer_cfg super().__init__(**kwargs) def _init_layers(self) -> None: """Initialize encoder layers.""" self.layers = ModuleList([ DeformableDetrTransformerEncoderLayer(**self.layer_cfg) for _ in range(self.num_layers) ]) self.text_layers = ModuleList([ DetrTransformerEncoderLayer(**self.text_layer_cfg) for _ in range(self.num_layers) ]) self.fusion_layers = ModuleList([ SingleScaleBiAttentionBlock(**self.fusion_layer_cfg) for _ in range(self.num_layers) ]) self.embed_dims = self.layers[0].embed_dims if self.num_cp > 0: if checkpoint_wrapper is None: raise NotImplementedError( 'If you want to reduce GPU memory usage, \ please install fairscale by executing the \ following command: pip install fairscale.') for i in range(self.num_cp): self.layers[i] = checkpoint_wrapper(self.layers[i]) self.fusion_layers[i] = checkpoint_wrapper( self.fusion_layers[i]) def forward(self, query: Tensor, query_pos: Tensor, key_padding_mask: Tensor, spatial_shapes: Tensor, level_start_index: Tensor, valid_ratios: Tensor, memory_text: Tensor = None, text_attention_mask: Tensor = None, pos_text: Tensor = None, text_self_attention_masks: Tensor = None, position_ids: Tensor = None): """Forward function of Transformer encoder. Args: query (Tensor): The input query, has shape (bs, num_queries, dim). query_pos (Tensor): The positional encoding for query, has shape (bs, num_queries, dim). key_padding_mask (Tensor): The `key_padding_mask` of `self_attn` input. ByteTensor, has shape (bs, num_queries). spatial_shapes (Tensor): Spatial shapes of features in all levels, has shape (num_levels, 2), last dimension represents (h, w). level_start_index (Tensor): The start index of each level. A tensor has shape (num_levels, ) and can be represented as [0, h_0*w_0, h_0*w_0+h_1*w_1, ...]. valid_ratios (Tensor): The ratios of the valid width and the valid height relative to the width and the height of features in all levels, has shape (bs, num_levels, 2). memory_text (Tensor, optional): Memory text. It has shape (bs, len_text, text_embed_dims). text_attention_mask (Tensor, optional): Text token mask. It has shape (bs,len_text). pos_text (Tensor, optional): The positional encoding for text. Defaults to None. text_self_attention_masks (Tensor, optional): Text self attention mask. Defaults to None. position_ids (Tensor, optional): Text position ids. Defaults to None. """ output = query reference_points = self.get_encoder_reference_points( spatial_shapes, valid_ratios, device=query.device) if self.text_layers: # generate pos_text bs, n_text, _ = memory_text.shape if pos_text is None and position_ids is None: pos_text = ( torch.arange(n_text, device=memory_text.device).float().unsqueeze( 0).unsqueeze(-1).repeat(bs, 1, 1)) pos_text = get_text_sine_pos_embed( pos_text, num_pos_feats=256, exchange_xy=False) if position_ids is not None: pos_text = get_text_sine_pos_embed( position_ids[..., None], num_pos_feats=256, exchange_xy=False) # main process for layer_id, layer in enumerate(self.layers): if self.fusion_layers: output, memory_text = self.fusion_layers[layer_id]( visual_feature=output, lang_feature=memory_text, attention_mask_v=key_padding_mask, attention_mask_l=text_attention_mask, ) if self.text_layers: text_num_heads = self.text_layers[ layer_id].self_attn_cfg.num_heads memory_text = self.text_layers[layer_id]( query=memory_text, query_pos=(pos_text if pos_text is not None else None), attn_mask=~text_self_attention_masks.repeat( text_num_heads, 1, 1), # note we use ~ for mask here key_padding_mask=None, ) output = layer( query=output, query_pos=query_pos, reference_points=reference_points, spatial_shapes=spatial_shapes, level_start_index=level_start_index, key_padding_mask=key_padding_mask) return output, memory_text class GroundingDinoTransformerDecoder(DinoTransformerDecoder): def _init_layers(self) -> None: """Initialize decoder layers.""" self.layers = ModuleList([ GroundingDinoTransformerDecoderLayer(**self.layer_cfg) for _ in range(self.num_layers) ]) self.embed_dims = self.layers[0].embed_dims if self.post_norm_cfg is not None: raise ValueError('There is not post_norm in ' f'{self._get_name()}') self.ref_point_head = MLP(self.embed_dims * 2, self.embed_dims, self.embed_dims, 2) self.norm = nn.LayerNorm(self.embed_dims)