Spaces:
Runtime error
Runtime error
File size: 11,894 Bytes
3b96cb1 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 |
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
from mmcv.cnn import build_norm_layer
from mmcv.cnn.bricks.transformer import FFN, MultiheadAttention
from mmcv.ops import MultiScaleDeformableAttention
from mmengine.model import ModuleList
from torch import Tensor
from mmdet.models.utils.vlfuse_helper import SingleScaleBiAttentionBlock
from mmdet.utils import ConfigType, OptConfigType
from .deformable_detr_layers import (DeformableDetrTransformerDecoderLayer,
DeformableDetrTransformerEncoder,
DeformableDetrTransformerEncoderLayer)
from .detr_layers import DetrTransformerEncoderLayer
from .dino_layers import DinoTransformerDecoder
from .utils import MLP, get_text_sine_pos_embed
try:
from fairscale.nn.checkpoint import checkpoint_wrapper
except Exception:
checkpoint_wrapper = None
class GroundingDinoTransformerDecoderLayer(
        DeformableDetrTransformerDecoderLayer):
    """Decoder layer of the Grounding DINO transformer.

    In addition to what the parent class sets up, this layer owns a text
    cross-attention module (``cross_attn_text``) that runs between the query
    self-attention and the image cross-attention.

    Args:
        cross_attn_text_cfg (:obj:`ConfigDict` or dict, optional): Config
            used to build the text cross-attention ``MultiheadAttention``.
            ``batch_first=True`` is filled in when the key is absent.
        **kwargs: Forwarded unchanged to the parent constructor.
    """

    def __init__(self,
                 cross_attn_text_cfg: OptConfigType = dict(
                     embed_dims=256,
                     num_heads=8,
                     dropout=0.0,
                     batch_first=True),
                 **kwargs) -> None:
        # Local import keeps this block self-contained; it only needs the
        # standard library.
        import copy

        # Work on a private copy: the default above is a single dict shared
        # by every instance, and a caller-supplied config must not be
        # modified behind the caller's back when `batch_first` is inserted.
        text_cfg = copy.deepcopy(cross_attn_text_cfg)
        if 'batch_first' not in text_cfg:
            text_cfg['batch_first'] = True
        # Must be set before `super().__init__`, since `_init_layers` reads
        # it (presumably invoked by the parent constructor — not visible
        # here).
        self.cross_attn_text_cfg = text_cfg
        super().__init__(**kwargs)

    def _init_layers(self) -> None:
        """Initialize self_attn, text/image cross-attn, ffn, and norms."""
        self.self_attn = MultiheadAttention(**self.self_attn_cfg)
        self.cross_attn_text = MultiheadAttention(**self.cross_attn_text_cfg)
        self.cross_attn = MultiScaleDeformableAttention(**self.cross_attn_cfg)
        self.embed_dims = self.self_attn.embed_dims
        self.ffn = FFN(**self.ffn_cfg)
        # One norm after each of the four sub-layers used in `forward`:
        # self-attn, text cross-attn, image cross-attn and FFN.
        norms_list = [
            build_norm_layer(self.norm_cfg, self.embed_dims)[1]
            for _ in range(4)
        ]
        self.norms = ModuleList(norms_list)

    def forward(self,
                query: Tensor,
                key: Tensor = None,
                value: Tensor = None,
                query_pos: Tensor = None,
                key_pos: Tensor = None,
                self_attn_mask: Tensor = None,
                cross_attn_mask: Tensor = None,
                key_padding_mask: Tensor = None,
                memory_text: Tensor = None,
                text_attention_mask: Tensor = None,
                **kwargs) -> Tensor:
        """Implements decoder layer in Grounding DINO transformer.

        The sub-layers run in this order, each followed by its own norm:
        query self-attention, query-to-text cross-attention, query-to-image
        cross-attention, FFN.

        Args:
            query (Tensor): The input query, has shape (bs, num_queries, dim).
            key (Tensor, optional): The input key, has shape (bs, num_keys,
                dim). Passed to the image cross-attention only.
                Defaults to `None`.
            value (Tensor, optional): The input value, has the same shape as
                `key`, as in `nn.MultiheadAttention.forward`. Passed to the
                image cross-attention only. Defaults to `None`.
            query_pos (Tensor, optional): The positional encoding for `query`,
                has the same shape as `query`. It is passed to all three
                attention modules and also serves as `key_pos` of the
                self-attention. Defaults to `None`.
            key_pos (Tensor, optional): The positional encoding for `key`, has
                the same shape as `key`. Passed to the image cross-attention
                only. Defaults to None.
            self_attn_mask (Tensor, optional): ByteTensor mask, has shape
                (num_queries, num_keys), as in `nn.MultiheadAttention.forward`.
                Defaults to None.
            cross_attn_mask (Tensor, optional): ByteTensor mask, has shape
                (num_queries, num_keys), as in `nn.MultiheadAttention.forward`.
                Defaults to None.
            key_padding_mask (Tensor, optional): The `key_padding_mask`
                passed to the image cross-attention. ByteTensor, has shape
                (bs, num_value). Defaults to None.
            memory_text (Tensor): Memory text. It has shape (bs, len_text,
                text_embed_dims). Used as key and value of the text
                cross-attention.
            text_attention_mask (Tensor): Text token mask. It has shape (bs,
                len_text). Used as `key_padding_mask` of the text
                cross-attention.
            **kwargs: Extra keyword arguments forwarded to both the
                self-attention and the image cross-attention.

        Returns:
            Tensor: forwarded results, has shape (bs, num_queries, dim).
        """
        # self attention
        query = self.self_attn(
            query=query,
            key=query,
            value=query,
            query_pos=query_pos,
            key_pos=query_pos,
            attn_mask=self_attn_mask,
            **kwargs)
        query = self.norms[0](query)
        # cross attention between query and text
        query = self.cross_attn_text(
            query=query,
            query_pos=query_pos,
            key=memory_text,
            value=memory_text,
            key_padding_mask=text_attention_mask)
        query = self.norms[1](query)
        # cross attention between query and image
        query = self.cross_attn(
            query=query,
            key=key,
            value=value,
            query_pos=query_pos,
            key_pos=key_pos,
            attn_mask=cross_attn_mask,
            key_padding_mask=key_padding_mask,
            **kwargs)
        query = self.norms[2](query)
        query = self.ffn(query)
        query = self.norms[3](query)
        return query
class GroundingDinoTransformerEncoder(DeformableDetrTransformerEncoder):
    """Encoder of the Grounding DINO transformer.

    Each encoder stage runs, in order, a vision-language fusion layer, a
    text self-attention layer and a deformable visual encoder layer.

    Args:
        text_layer_cfg (:obj:`ConfigDict` or dict): Config used to build
            every text ``DetrTransformerEncoderLayer``.
        fusion_layer_cfg (:obj:`ConfigDict` or dict): Config used to build
            every ``SingleScaleBiAttentionBlock`` fusion layer.
        **kwargs: Forwarded unchanged to the parent constructor.
    """

    def __init__(self, text_layer_cfg: ConfigType,
                 fusion_layer_cfg: ConfigType, **kwargs) -> None:
        # Stored before `super().__init__` because `_init_layers` reads
        # them (presumably invoked by the parent constructor — not visible
        # here).
        self.text_layer_cfg = text_layer_cfg
        self.fusion_layer_cfg = fusion_layer_cfg
        super().__init__(**kwargs)

    def _init_layers(self) -> None:
        """Initialize visual, text and fusion encoder layers.

        Raises:
            NotImplementedError: If ``num_cp > 0`` but fairscale's
                ``checkpoint_wrapper`` could not be imported.
        """
        self.layers = ModuleList([
            DeformableDetrTransformerEncoderLayer(**self.layer_cfg)
            for _ in range(self.num_layers)
        ])
        self.text_layers = ModuleList([
            DetrTransformerEncoderLayer(**self.text_layer_cfg)
            for _ in range(self.num_layers)
        ])
        self.fusion_layers = ModuleList([
            SingleScaleBiAttentionBlock(**self.fusion_layer_cfg)
            for _ in range(self.num_layers)
        ])
        self.embed_dims = self.layers[0].embed_dims
        if self.num_cp > 0:
            if checkpoint_wrapper is None:
                # Implicit string concatenation keeps the message free of
                # the source indentation that a backslash continuation
                # inside the literal would embed.
                raise NotImplementedError(
                    'If you want to reduce GPU memory usage, '
                    'please install fairscale by executing the '
                    'following command: pip install fairscale.')
            # Only the first `num_cp` visual and fusion layers are wrapped;
            # the text layers are left as they are.
            for i in range(self.num_cp):
                self.layers[i] = checkpoint_wrapper(self.layers[i])
                self.fusion_layers[i] = checkpoint_wrapper(
                    self.fusion_layers[i])

    def forward(self,
                query: Tensor,
                query_pos: Tensor,
                key_padding_mask: Tensor,
                spatial_shapes: Tensor,
                level_start_index: Tensor,
                valid_ratios: Tensor,
                memory_text: Tensor = None,
                text_attention_mask: Tensor = None,
                pos_text: Tensor = None,
                text_self_attention_masks: Tensor = None,
                position_ids: Tensor = None):
        """Forward function of Transformer encoder.

        Args:
            query (Tensor): The input query, has shape (bs, num_queries, dim).
            query_pos (Tensor): The positional encoding for query, has shape
                (bs, num_queries, dim).
            key_padding_mask (Tensor): The `key_padding_mask` of `self_attn`
                input. ByteTensor, has shape (bs, num_queries).
            spatial_shapes (Tensor): Spatial shapes of features in all levels,
                has shape (num_levels, 2), last dimension represents (h, w).
            level_start_index (Tensor): The start index of each level.
                A tensor has shape (num_levels, ) and can be represented
                as [0, h_0*w_0, h_0*w_0+h_1*w_1, ...].
            valid_ratios (Tensor): The ratios of the valid width and the valid
                height relative to the width and the height of features in all
                levels, has shape (bs, num_levels, 2).
            memory_text (Tensor, optional): Memory text. It has shape (bs,
                len_text, text_embed_dims).
            text_attention_mask (Tensor, optional): Text token mask. It has
                shape (bs,len_text).
            pos_text (Tensor, optional): The positional encoding for text.
                It is overwritten by an embedding of `position_ids` when
                those are given. Defaults to None.
            text_self_attention_masks (Tensor, optional): Text self attention
                mask. It is inverted and repeated once per attention head
                before use. If None, the text layers run without an
                attention mask. Defaults to None.
            position_ids (Tensor, optional): Text position ids.
                Defaults to None.

        Returns:
            tuple[Tensor, Tensor]: The encoded visual features and the
            updated text memory.
        """
        output = query
        reference_points = self.get_encoder_reference_points(
            spatial_shapes, valid_ratios, device=query.device)
        if self.text_layers:
            # generate pos_text
            bs, n_text, _ = memory_text.shape
            if pos_text is None and position_ids is None:
                # Fall back to plain token indices 0..n_text-1 per sample.
                pos_text = (
                    torch.arange(n_text,
                                 device=memory_text.device).float().unsqueeze(
                                     0).unsqueeze(-1).repeat(bs, 1, 1))
                pos_text = get_text_sine_pos_embed(
                    pos_text, num_pos_feats=256, exchange_xy=False)
            if position_ids is not None:
                pos_text = get_text_sine_pos_embed(
                    position_ids[..., None],
                    num_pos_feats=256,
                    exchange_xy=False)

        # main process
        for layer_id, layer in enumerate(self.layers):
            if self.fusion_layers:
                output, memory_text = self.fusion_layers[layer_id](
                    visual_feature=output,
                    lang_feature=memory_text,
                    attention_mask_v=key_padding_mask,
                    attention_mask_l=text_attention_mask,
                )
            if self.text_layers:
                text_layer = self.text_layers[layer_id]
                if text_self_attention_masks is None:
                    # The argument is optional; without a mask the text
                    # tokens attend to each other unrestricted instead of
                    # crashing on `None.repeat`.
                    text_attn_mask = None
                else:
                    text_num_heads = text_layer.self_attn_cfg.num_heads
                    # note we use ~ for mask here
                    text_attn_mask = ~text_self_attention_masks.repeat(
                        text_num_heads, 1, 1)
                memory_text = text_layer(
                    query=memory_text,
                    query_pos=pos_text,
                    attn_mask=text_attn_mask,
                    key_padding_mask=None,
                )
            output = layer(
                query=output,
                query_pos=query_pos,
                reference_points=reference_points,
                spatial_shapes=spatial_shapes,
                level_start_index=level_start_index,
                key_padding_mask=key_padding_mask)
        return output, memory_text
class GroundingDinoTransformerDecoder(DinoTransformerDecoder):
    """Decoder of the Grounding DINO transformer.

    Differs from the parent only in how its layers are built: every layer
    is a ``GroundingDinoTransformerDecoderLayer``.
    """

    def _init_layers(self) -> None:
        """Build the decoder layers, the reference-point head and the norm.

        Raises:
            ValueError: If ``post_norm_cfg`` is not None.
        """
        decoder_layers = []
        for _ in range(self.num_layers):
            decoder_layers.append(
                GroundingDinoTransformerDecoderLayer(**self.layer_cfg))
        self.layers = ModuleList(decoder_layers)

        first_layer = self.layers[0]
        self.embed_dims = first_layer.embed_dims

        # A post-norm config is rejected; this decoder always creates its
        # own LayerNorm below.
        if self.post_norm_cfg is not None:
            raise ValueError(f'There is not post_norm in {self._get_name()}')

        dims = self.embed_dims
        self.ref_point_head = MLP(dims * 2, dims, dims, 2)
        self.norm = nn.LayerNorm(dims)
|