# Copyright (c) OpenMMLab. All rights reserved.
import copy
import math
from typing import Callable, List, Optional, Sequence, Tuple, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import Scale
from mmcv.ops.modulated_deform_conv import ModulatedDeformConv2d
from mmengine.config import ConfigDict
from mmengine.model import BaseModel
from mmengine.structures import InstanceData
from torch import Tensor
try:
from transformers import BertConfig
except ImportError:
BertConfig = None
from mmdet.registry import MODELS
from mmdet.structures.bbox import cat_boxes
from mmdet.utils import InstanceList, OptInstanceList, reduce_mean
from ..utils import (BertEncoderLayer, VLFuse, filter_scores_and_topk,
permute_and_flatten, select_single_mlvl,
unpack_gt_instances)
from ..utils.vlfuse_helper import MAX_CLAMP_VALUE
from .atss_head import ATSSHead
def convert_grounding_to_cls_scores(logits: Tensor,
positive_maps: List[dict]) -> Tensor:
"""Convert logits to class scores."""
assert len(positive_maps) == logits.shape[0] # batch size
scores = torch.zeros(logits.shape[0], logits.shape[1],
len(positive_maps[0])).to(logits.device)
if positive_maps is not None:
if all(x == positive_maps[0] for x in positive_maps):
# only need to compute once
positive_map = positive_maps[0]
            for label_j in positive_map:
                scores[:, :, label_j - 1] = logits[
                    :, :, torch.LongTensor(positive_map[label_j])].mean(-1)
else:
for i, positive_map in enumerate(positive_maps):
for label_j in positive_map:
scores[i, :, label_j - 1] = logits[
i, :, torch.LongTensor(positive_map[label_j])].mean(-1)
return scores
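# Illustrative usage sketch (the positive map and shapes below are assumed,
# not taken from this file): for a prompt whose positive map is
# {1: [0], 2: [2]} (class 1 <- token 0, class 2 <- token 2) and ``logits`` of
# shape (1, num_priors, num_tokens),
#     scores = convert_grounding_to_cls_scores(logits, [{1: [0], 2: [2]}])
# returns ``scores`` of shape (1, num_priors, 2), where each class score is
# the mean of the logits of its mapped tokens.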
class Conv3x3Norm(nn.Module):
"""Conv3x3 and norm."""
def __init__(self,
in_channels: int,
out_channels: int,
stride: int,
groups: int = 1,
use_dcn: bool = False,
norm_type: Optional[Union[Sequence, str]] = None):
super().__init__()
if use_dcn:
self.conv = ModulatedDeformConv2d(
in_channels,
out_channels,
kernel_size=3,
stride=stride,
padding=1,
groups=groups)
else:
self.conv = nn.Conv2d(
in_channels,
out_channels,
kernel_size=3,
stride=stride,
padding=1,
groups=groups)
if isinstance(norm_type, Sequence):
assert len(norm_type) == 2
assert norm_type[0] == 'gn'
gn_group = norm_type[1]
norm_type = norm_type[0]
if norm_type == 'bn':
bn_op = nn.BatchNorm2d(out_channels)
elif norm_type == 'gn':
bn_op = nn.GroupNorm(
num_groups=gn_group, num_channels=out_channels)
if norm_type is not None:
self.bn = bn_op
else:
self.bn = None
def forward(self, x, **kwargs):
x = self.conv(x, **kwargs)
if self.bn:
x = self.bn(x)
return x
class DyReLU(nn.Module):
"""Dynamic ReLU."""
def __init__(self,
in_channels: int,
out_channels: int,
expand_ratio: int = 4):
super().__init__()
self.avg_pool = nn.AdaptiveAvgPool2d(1)
self.expand_ratio = expand_ratio
self.out_channels = out_channels
self.fc = nn.Sequential(
nn.Linear(in_channels, in_channels // expand_ratio),
nn.ReLU(inplace=True),
nn.Linear(in_channels // expand_ratio,
out_channels * self.expand_ratio),
nn.Hardsigmoid(inplace=True))
def forward(self, x) -> Tensor:
x_out = x
b, c, h, w = x.size()
x = self.avg_pool(x).view(b, c)
x = self.fc(x).view(b, -1, 1, 1)
a1, b1, a2, b2 = torch.split(x, self.out_channels, dim=1)
a1 = (a1 - 0.5) * 2 + 1.0
a2 = (a2 - 0.5) * 2
b1 = b1 - 0.5
b2 = b2 - 0.5
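        # Dynamic ReLU: two input-dependent linear branches y = a * x + b;
        # the shifts above centre a1 around 1 and a2/b1/b2 around 0, and the
        # element-wise max selects the active branch.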
out = torch.max(x_out * a1 + b1, x_out * a2 + b2)
return out
class DyConv(nn.Module):
"""Dynamic Convolution."""
def __init__(self,
conv_func: Callable,
in_channels: int,
out_channels: int,
use_dyfuse: bool = True,
use_dyrelu: bool = False,
use_dcn: bool = False):
super().__init__()
self.dyconvs = nn.ModuleList()
self.dyconvs.append(conv_func(in_channels, out_channels, 1))
self.dyconvs.append(conv_func(in_channels, out_channels, 1))
self.dyconvs.append(conv_func(in_channels, out_channels, 2))
if use_dyfuse:
self.attnconv = nn.Sequential(
nn.AdaptiveAvgPool2d(1),
nn.Conv2d(in_channels, 1, kernel_size=1),
nn.ReLU(inplace=True))
self.h_sigmoid = nn.Hardsigmoid(inplace=True)
else:
self.attnconv = None
if use_dyrelu:
self.relu = DyReLU(in_channels, out_channels)
else:
self.relu = nn.ReLU()
if use_dcn:
self.offset = nn.Conv2d(
in_channels, 27, kernel_size=3, stride=1, padding=1)
else:
self.offset = None
self.init_weights()
def init_weights(self):
for m in self.dyconvs.modules():
if isinstance(m, nn.Conv2d):
nn.init.normal_(m.weight.data, 0, 0.01)
if m.bias is not None:
m.bias.data.zero_()
if self.attnconv is not None:
for m in self.attnconv.modules():
if isinstance(m, nn.Conv2d):
nn.init.normal_(m.weight.data, 0, 0.01)
if m.bias is not None:
m.bias.data.zero_()
def forward(self, inputs: dict) -> dict:
visual_feats = inputs['visual']
out_vis_feats = []
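        # For each level, fuse dynamic convolutions of the current level
        # (dyconvs[1]), the higher-resolution level below it (dyconvs[2],
        # stride 2) and the lower-resolution level above it (dyconvs[0],
        # bilinearly upsampled); the results are averaged, optionally
        # weighted by the scale-aware attention from ``attnconv``.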
for level, feature in enumerate(visual_feats):
offset_conv_args = {}
if self.offset is not None:
offset_mask = self.offset(feature)
offset = offset_mask[:, :18, :, :]
mask = offset_mask[:, 18:, :, :].sigmoid()
offset_conv_args = dict(offset=offset, mask=mask)
temp_feats = [self.dyconvs[1](feature, **offset_conv_args)]
if level > 0:
temp_feats.append(self.dyconvs[2](visual_feats[level - 1],
**offset_conv_args))
            if level < len(visual_feats) - 1:
                temp_feats.append(
                    # F.upsample_bilinear is deprecated; F.interpolate with
                    # align_corners=True is its documented equivalent.
                    F.interpolate(
                        self.dyconvs[0](visual_feats[level + 1],
                                        **offset_conv_args),
                        size=[feature.size(2),
                              feature.size(3)],
                        mode='bilinear',
                        align_corners=True))
mean_feats = torch.mean(
torch.stack(temp_feats), dim=0, keepdim=False)
if self.attnconv is not None:
attn_feat = []
res_feat = []
for feat in temp_feats:
res_feat.append(feat)
attn_feat.append(self.attnconv(feat))
res_feat = torch.stack(res_feat)
spa_pyr_attn = self.h_sigmoid(torch.stack(attn_feat))
mean_feats = torch.mean(
res_feat * spa_pyr_attn, dim=0, keepdim=False)
out_vis_feats.append(mean_feats)
out_vis_feats = [self.relu(item) for item in out_vis_feats]
features_dict = {'visual': out_vis_feats, 'lang': inputs['lang']}
return features_dict
class VLFusionModule(BaseModel):
"""Visual-lang Fusion Module."""
def __init__(self,
in_channels: int,
feat_channels: int,
num_base_priors: int,
early_fuse: bool = False,
num_dyhead_blocks: int = 6,
lang_model_name: str = 'bert-base-uncased',
use_dyrelu: bool = True,
use_dyfuse: bool = True,
use_dcn: bool = True,
use_checkpoint: bool = False,
**kwargs) -> None:
super().__init__(**kwargs)
if BertConfig is None:
raise RuntimeError(
'transformers is not installed, please install it by: '
'pip install transformers.')
self.in_channels = in_channels
self.feat_channels = feat_channels
self.num_base_priors = num_base_priors
self.early_fuse = early_fuse
self.num_dyhead_blocks = num_dyhead_blocks
self.use_dyrelu = use_dyrelu
self.use_dyfuse = use_dyfuse
self.use_dcn = use_dcn
self.use_checkpoint = use_checkpoint
self.lang_cfg = BertConfig.from_pretrained(lang_model_name)
self.lang_dim = self.lang_cfg.hidden_size
self._init_layers()
def _init_layers(self) -> None:
"""Initialize layers of the model."""
bias_value = -math.log((1 - 0.01) / 0.01)
dyhead_tower = []
for i in range(self.num_dyhead_blocks):
if self.early_fuse:
# cross-modality fusion
dyhead_tower.append(VLFuse(use_checkpoint=self.use_checkpoint))
# lang branch
dyhead_tower.append(
BertEncoderLayer(
self.lang_cfg,
clamp_min_for_underflow=True,
clamp_max_for_overflow=True))
# vision branch
dyhead_tower.append(
DyConv(
lambda i, o, s: Conv3x3Norm(
i, o, s, use_dcn=self.use_dcn, norm_type=['gn', 16]),
self.in_channels if i == 0 else self.feat_channels,
self.feat_channels,
use_dyrelu=(self.use_dyrelu
and self.in_channels == self.feat_channels)
if i == 0 else self.use_dyrelu,
use_dyfuse=(self.use_dyfuse
and self.in_channels == self.feat_channels)
if i == 0 else self.use_dyfuse,
use_dcn=(self.use_dcn
and self.in_channels == self.feat_channels)
if i == 0 else self.use_dcn,
))
self.add_module('dyhead_tower', nn.Sequential(*dyhead_tower))
self.bbox_pred = nn.Conv2d(
self.feat_channels, self.num_base_priors * 4, kernel_size=1)
self.centerness = nn.Conv2d(
self.feat_channels, self.num_base_priors * 1, kernel_size=1)
self.dot_product_projection_text = nn.Linear(
self.lang_dim,
self.num_base_priors * self.feat_channels,
bias=True)
self.log_scale = nn.Parameter(torch.Tensor([0.0]), requires_grad=True)
self.bias_lang = nn.Parameter(
torch.zeros(self.lang_dim), requires_grad=True)
self.bias0 = nn.Parameter(
torch.Tensor([bias_value]), requires_grad=True)
self.scales = nn.ModuleList([Scale(1.0) for _ in range(5)])
def forward(self, visual_feats: Tuple[Tensor],
language_feats: dict) -> Tuple:
feat_inputs = {'visual': visual_feats, 'lang': language_feats}
dyhead_tower = self.dyhead_tower(feat_inputs)
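        # With early fusion, take the language hidden states refined inside
        # the dyhead tower; otherwise use the embeddings produced by the
        # language model directly.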
if self.early_fuse:
embedding = dyhead_tower['lang']['hidden']
else:
embedding = language_feats['embedded']
embedding = F.normalize(embedding, p=2, dim=-1)
dot_product_proj_tokens = self.dot_product_projection_text(embedding /
2.0)
dot_product_proj_tokens_bias = torch.matmul(
embedding, self.bias_lang) + self.bias0
bbox_preds = []
centerness = []
cls_logits = []
for i, feature in enumerate(visual_feats):
visual = dyhead_tower['visual'][i]
B, C, H, W = visual.shape
bbox_pred = self.scales[i](self.bbox_pred(visual))
bbox_preds.append(bbox_pred)
centerness.append(self.centerness(visual))
dot_product_proj_queries = permute_and_flatten(
visual, B, self.num_base_priors, C, H, W)
bias = dot_product_proj_tokens_bias.unsqueeze(1).repeat(
1, self.num_base_priors, 1)
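            # Token-level classification logits: temperature-scaled dot
            # product between per-anchor visual queries and the projected
            # text-token embeddings, plus a per-token language bias; the
            # result is clamped for numerical stability.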
dot_product_logit = (
torch.matmul(dot_product_proj_queries,
dot_product_proj_tokens.transpose(-1, -2)) /
self.log_scale.exp()) + bias
dot_product_logit = torch.clamp(
dot_product_logit, max=MAX_CLAMP_VALUE)
dot_product_logit = torch.clamp(
dot_product_logit, min=-MAX_CLAMP_VALUE)
cls_logits.append(dot_product_logit)
return bbox_preds, centerness, cls_logits
@MODELS.register_module()
class ATSSVLFusionHead(ATSSHead):
"""ATSS head with visual-language fusion module.
Args:
        early_fuse (bool): Whether to fuse visual and language features
            inside the dynamic head blocks. Defaults to False.
use_checkpoint (bool): Whether to use checkpoint. Defaults to False.
num_dyhead_blocks (int): Number of dynamic head blocks. Defaults to 6.
lang_model_name (str): Name of the language model.
Defaults to 'bert-base-uncased'.
"""
def __init__(self,
*args,
early_fuse: bool = False,
use_checkpoint: bool = False,
num_dyhead_blocks: int = 6,
lang_model_name: str = 'bert-base-uncased',
init_cfg=None,
**kwargs):
super().__init__(*args, **kwargs, init_cfg=init_cfg)
self.head = VLFusionModule(
in_channels=self.in_channels,
feat_channels=self.feat_channels,
num_base_priors=self.num_base_priors,
early_fuse=early_fuse,
use_checkpoint=use_checkpoint,
num_dyhead_blocks=num_dyhead_blocks,
lang_model_name=lang_model_name)
self.text_masks = None
def _init_layers(self) -> None:
"""No need to initialize the ATSS head layer."""
pass
def forward(self, visual_feats: Tuple[Tensor],
language_feats: dict) -> Tuple[Tensor]:
"""Forward function."""
bbox_preds, centerness, cls_logits = self.head(visual_feats,
language_feats)
return cls_logits, bbox_preds, centerness
def loss(self, visual_feats: Tuple[Tensor], language_feats: dict,
batch_data_samples):
outputs = unpack_gt_instances(batch_data_samples)
(batch_gt_instances, batch_gt_instances_ignore,
batch_img_metas) = outputs
outs = self(visual_feats, language_feats)
self.text_masks = language_feats['masks']
loss_inputs = outs + (batch_gt_instances, batch_img_metas,
batch_gt_instances_ignore)
losses = self.loss_by_feat(*loss_inputs)
return losses
def loss_by_feat(
self,
cls_scores: List[Tensor],
bbox_preds: List[Tensor],
centernesses: List[Tensor],
batch_gt_instances: InstanceList,
batch_img_metas: List[dict],
batch_gt_instances_ignore: OptInstanceList = None) -> dict:
"""Calculate the loss based on the features extracted by the detection
head.
Args:
            cls_scores (list[Tensor]): Box scores for each scale level.
                Has shape (N, num_anchors * num_classes, H, W).
            bbox_preds (list[Tensor]): Box energies / deltas for each scale
                level with shape (N, num_anchors * 4, H, W).
            centernesses (list[Tensor]): Centerness for each scale
                level with shape (N, num_anchors * 1, H, W).
batch_gt_instances (list[:obj:`InstanceData`]): Batch of
gt_instance. It usually includes ``bboxes`` and ``labels``
attributes.
batch_img_metas (list[dict]): Meta information of each image, e.g.,
image size, scaling factor, etc.
batch_gt_instances_ignore (list[:obj:`InstanceData`], Optional):
Batch of gt_instances_ignore. It includes ``bboxes`` attribute
data that is ignored during training and testing.
Defaults to None.
Returns:
dict[str, Tensor]: A dictionary of loss components.
"""
featmap_sizes = [featmap.size()[-2:] for featmap in bbox_preds]
assert len(featmap_sizes) == self.prior_generator.num_levels
device = cls_scores[0].device
anchor_list, valid_flag_list = self.get_anchors(
featmap_sizes, batch_img_metas, device=device)
cls_reg_targets = self.get_targets(
anchor_list,
valid_flag_list,
batch_gt_instances,
batch_img_metas,
batch_gt_instances_ignore=batch_gt_instances_ignore)
(anchor_list, labels_list, label_weights_list, bbox_targets_list,
bbox_weights_list, avg_factor) = cls_reg_targets
avg_factor = reduce_mean(
torch.tensor(avg_factor, dtype=torch.float, device=device)).item()
anchors = torch.cat(anchor_list, dim=1)
labels = torch.cat(labels_list, dim=1)
label_weights = torch.cat(label_weights_list, dim=1)
bbox_targets = torch.cat(bbox_targets_list, dim=1)
cls_scores = torch.cat(cls_scores, dim=1)
centernesses_ = []
bbox_preds_ = []
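        # Flatten each level's predictions from (B, A * C, H, W) to
        # (B, H * W * A, C) and concatenate them across levels so they align
        # with the concatenated targets.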
for bbox_pred, centerness in zip(bbox_preds, centernesses):
centernesses_.append(
centerness.permute(0, 2, 3,
1).reshape(cls_scores.size(0), -1, 1))
bbox_preds_.append(
bbox_pred.permute(0, 2, 3,
1).reshape(cls_scores.size(0), -1, 4))
bbox_preds = torch.cat(bbox_preds_, dim=1)
centernesses = torch.cat(centernesses_, dim=1)
losses_cls, losses_bbox, loss_centerness, bbox_avg_factor = \
self._loss_by_feat(
anchors,
cls_scores,
bbox_preds,
centernesses,
labels,
label_weights,
bbox_targets,
avg_factor=avg_factor)
bbox_avg_factor = reduce_mean(bbox_avg_factor).clamp_(min=1).item()
losses_bbox = losses_bbox / bbox_avg_factor
return dict(
loss_cls=losses_cls,
loss_bbox=losses_bbox,
loss_centerness=loss_centerness)
def _loss_by_feat(self, anchors: Tensor, cls_score: Tensor,
bbox_pred: Tensor, centerness: Tensor, labels: Tensor,
label_weights: Tensor, bbox_targets: Tensor,
avg_factor: float) -> dict:
"""Calculate the loss of all scale level based on the features
extracted by the detection head.
Returns:
dict[str, Tensor]: A dictionary of loss components.
"""
anchors = anchors.reshape(-1, 4)
# ===== this change =====
pos_inds = (labels.sum(-1) > 0).reshape(-1)
# Loss is not computed for the padded regions of the text.
assert (self.text_masks.dim() == 2)
text_mask = (self.text_masks > 0).unsqueeze(1)
text_mask = text_mask.repeat(1, cls_score.size(1), 1)
cls_score = torch.masked_select(cls_score, text_mask).contiguous()
labels = torch.masked_select(labels, text_mask)
label_weights = label_weights[...,
None].repeat(1, 1, text_mask.size(-1))
label_weights = torch.masked_select(label_weights, text_mask)
bbox_pred = bbox_pred.reshape(-1, 4)
centerness = centerness.reshape(-1)
bbox_targets = bbox_targets.reshape(-1, 4)
labels = labels.reshape(-1)
label_weights = label_weights.reshape(-1)
# classification loss
loss_cls = self.loss_cls(
cls_score, labels, label_weights, avg_factor=avg_factor)
if pos_inds.sum() > 0:
pos_bbox_targets = bbox_targets[pos_inds]
pos_bbox_pred = bbox_pred[pos_inds]
pos_anchors = anchors[pos_inds]
pos_centerness = centerness[pos_inds]
centerness_targets = self.centerness_target(
pos_anchors, pos_bbox_targets)
if torch.isnan(centerness_targets).any():
print('=====Centerness includes NaN=====')
mask = ~torch.isnan(centerness_targets)
centerness_targets = centerness_targets[mask]
pos_centerness = pos_centerness[mask]
pos_anchors = pos_anchors[mask]
pos_bbox_targets = pos_bbox_targets[mask]
pos_bbox_pred = pos_bbox_pred[mask]
if pos_bbox_targets.shape[0] == 0:
loss_bbox = bbox_pred.sum() * 0
loss_centerness = centerness.sum() * 0
centerness_targets = bbox_targets.new_tensor(0.)
return loss_cls, loss_bbox, loss_centerness, \
centerness_targets.sum()
            # Add back the 1-pixel offset removed in ``_get_targets_single``
            # so that decoding matches the official implementation.
pos_anchors[:, 2:] += 1
pos_decode_bbox_pred = self.bbox_coder.decode(
pos_anchors, pos_bbox_pred)
# regression loss
loss_bbox = self.loss_bbox(
pos_decode_bbox_pred,
pos_bbox_targets,
weight=centerness_targets,
avg_factor=1.0)
# centerness loss
loss_centerness = self.loss_centerness(
pos_centerness, centerness_targets, avg_factor=avg_factor)
else:
loss_bbox = bbox_pred.sum() * 0
loss_centerness = centerness.sum() * 0
centerness_targets = bbox_targets.new_tensor(0.)
return loss_cls, loss_bbox, loss_centerness, centerness_targets.sum()
def _get_targets_single(self,
flat_anchors: Tensor,
valid_flags: Tensor,
num_level_anchors: List[int],
gt_instances: InstanceData,
img_meta: dict,
gt_instances_ignore: Optional[InstanceData] = None,
unmap_outputs: bool = True) -> tuple:
"""Compute regression, classification targets for anchors in a single
image.
Args:
            flat_anchors (Tensor): Multi-level anchors of the image, which are
                concatenated into a single tensor of shape (num_anchors, 4).
valid_flags (Tensor): Multi level valid flags of the image,
which are concatenated into a single tensor of
shape (num_anchors,).
num_level_anchors (List[int]): Number of anchors of each scale
level.
gt_instances (:obj:`InstanceData`): Ground truth of instance
annotations. It usually includes ``bboxes`` and ``labels``
attributes.
img_meta (dict): Meta information for current image.
gt_instances_ignore (:obj:`InstanceData`, optional): Instances
to be ignored during training. It includes ``bboxes`` attribute
data that is ignored during training and testing.
Defaults to None.
unmap_outputs (bool): Whether to map outputs back to the original
set of anchors.
Returns:
            tuple: N is the total number of anchors in the image.
                labels (Tensor): Labels of all anchors in the image with shape
                    (N,).
                label_weights (Tensor): Label weights of all anchors in the
                    image with shape (N,).
                bbox_targets (Tensor): BBox targets of all anchors in the
                    image with shape (N, 4).
                bbox_weights (Tensor): BBox weights of all anchors in the
                    image with shape (N, 4).
                pos_inds (Tensor): Indices of positive anchors with shape
                    (num_pos,).
                neg_inds (Tensor): Indices of negative anchors with shape
                    (num_neg,).
sampling_result (:obj:`SamplingResult`): Sampling results.
"""
anchors = flat_anchors
        # Align with the official implementation: shrink the right/bottom
        # coordinates by 1 pixel.
anchors[:, 2:] -= 1
num_level_anchors_inside = num_level_anchors
pred_instances = InstanceData(priors=anchors)
assign_result = self.assigner.assign(pred_instances,
num_level_anchors_inside,
gt_instances, gt_instances_ignore)
sampling_result = self.sampler.sample(assign_result, pred_instances,
gt_instances)
num_valid_anchors = anchors.shape[0]
bbox_targets = torch.zeros_like(anchors)
bbox_weights = torch.zeros_like(anchors)
# ===== this change =====
labels = anchors.new_full((num_valid_anchors, self.feat_channels),
0,
dtype=torch.float32)
label_weights = anchors.new_zeros(num_valid_anchors, dtype=torch.float)
pos_inds = sampling_result.pos_inds
neg_inds = sampling_result.neg_inds
if len(pos_inds) > 0:
if self.reg_decoded_bbox:
pos_bbox_targets = sampling_result.pos_gt_bboxes
else:
pos_bbox_targets = self.bbox_coder.encode(
sampling_result.pos_priors, sampling_result.pos_gt_bboxes)
bbox_targets[pos_inds, :] = pos_bbox_targets
bbox_weights[pos_inds, :] = 1.0
# ===== this change =====
labels[pos_inds] = gt_instances.positive_maps[
sampling_result.pos_assigned_gt_inds]
if self.train_cfg['pos_weight'] <= 0:
label_weights[pos_inds] = 1.0
else:
label_weights[pos_inds] = self.train_cfg['pos_weight']
if len(neg_inds) > 0:
label_weights[neg_inds] = 1.0
return (anchors, labels, label_weights, bbox_targets, bbox_weights,
pos_inds, neg_inds, sampling_result)
def centerness_target(self, anchors: Tensor, gts: Tensor) -> Tensor:
"""Calculate the centerness between anchors and gts.
Only calculate pos centerness targets, otherwise there may be nan.
Args:
anchors (Tensor): Anchors with shape (N, 4), "xyxy" format.
gts (Tensor): Ground truth bboxes with shape (N, 4), "xyxy" format.
Returns:
Tensor: Centerness between anchors and gts.
"""
anchors_cx = (anchors[:, 2] + anchors[:, 0]) / 2
anchors_cy = (anchors[:, 3] + anchors[:, 1]) / 2
l_ = anchors_cx - gts[:, 0]
t_ = anchors_cy - gts[:, 1]
r_ = gts[:, 2] - anchors_cx
b_ = gts[:, 3] - anchors_cy
left_right = torch.stack([l_, r_], dim=1)
top_bottom = torch.stack([t_, b_], dim=1)
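        # FCOS-style centerness:
        # sqrt(min(l, r) / max(l, r) * min(t, b) / max(t, b)).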
centerness = torch.sqrt(
(left_right.min(dim=-1)[0] / left_right.max(dim=-1)[0]) *
(top_bottom.min(dim=-1)[0] / top_bottom.max(dim=-1)[0]))
# assert not torch.isnan(centerness).any()
return centerness
def predict(self,
visual_feats: Tuple[Tensor],
language_feats: dict,
batch_data_samples,
rescale: bool = True):
"""Perform forward propagation of the detection head and predict
detection results on the features of the upstream network.
Args:
visual_feats (tuple[Tensor]): Multi-level visual features from the
upstream network, each is a 4D-tensor.
language_feats (dict): Language features from the upstream network.
batch_data_samples (List[:obj:`DetDataSample`]): The Data
Samples. It usually includes information such as
`gt_instance`, `gt_panoptic_seg` and `gt_sem_seg`.
            rescale (bool, optional): Whether to rescale the results.
                Defaults to True.
Returns:
list[obj:`InstanceData`]: Detection results of each image
after the post process.
"""
batch_img_metas = [
data_samples.metainfo for data_samples in batch_data_samples
]
batch_token_positive_maps = [
data_samples.token_positive_map
for data_samples in batch_data_samples
]
outs = self(visual_feats, language_feats)
predictions = self.predict_by_feat(
*outs,
batch_img_metas=batch_img_metas,
batch_token_positive_maps=batch_token_positive_maps,
rescale=rescale)
return predictions
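    # Illustrative usage sketch (variable names are assumptions, not from
    # this file):
    #     results = head.predict(visual_feats, language_feats,
    #                            batch_data_samples, rescale=True)
    #     bboxes, scores, labels = (results[0].bboxes, results[0].scores,
    #                               results[0].labels)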
def predict_by_feat(self,
cls_logits: List[Tensor],
bbox_preds: List[Tensor],
score_factors: List[Tensor],
batch_img_metas: Optional[List[dict]] = None,
batch_token_positive_maps: Optional[List[dict]] = None,
cfg: Optional[ConfigDict] = None,
rescale: bool = False,
with_nms: bool = True) -> InstanceList:
"""Transform a batch of output features extracted from the head into
bbox results.
        Note: When ``score_factors`` is not None, the ``cls_scores`` are
        usually multiplied by it to obtain the real scores used in NMS,
        e.g. centerness in FCOS or the IoU branch in ATSS.
Args:
cls_logits (list[Tensor]): Classification scores for all
scale levels, each is a 4D-tensor, has shape
(batch_size, num_priors * num_classes, H, W).
bbox_preds (list[Tensor]): Box energies / deltas for all
scale levels, each is a 4D-tensor, has shape
(batch_size, num_priors * 4, H, W).
score_factors (list[Tensor], optional): Score factor for
all scale level, each is a 4D-tensor, has shape
(batch_size, num_priors * 1, H, W). Defaults to None.
batch_img_metas (list[dict], Optional): Batch image meta info.
Defaults to None.
batch_token_positive_maps (list[dict], Optional): Batch token
positive map. Defaults to None.
            cfg (ConfigDict, optional): Test / postprocessing configuration.
                If None, ``test_cfg`` is used.
                Defaults to None.
rescale (bool): If True, return boxes in original image space.
Defaults to False.
with_nms (bool): If True, do nms before return boxes.
Defaults to True.
Returns:
list[:obj:`InstanceData`]: Object detection results of each image
after the post process. Each item usually contains following keys.
            - scores (Tensor): Classification scores, has a shape
              (num_instances, ).
            - labels (Tensor): Labels of bboxes, has a shape
              (num_instances, ).
            - bboxes (Tensor): Has a shape (num_instances, 4),
              the last dimension 4 arranged as (x1, y1, x2, y2).
"""
assert len(bbox_preds) == len(score_factors)
num_levels = len(bbox_preds)
featmap_sizes = [bbox_preds[i].shape[-2:] for i in range(num_levels)]
mlvl_priors = self.prior_generator.grid_priors(
featmap_sizes,
dtype=bbox_preds[0].dtype,
device=bbox_preds[0].device)
result_list = []
for img_id in range(len(batch_img_metas)):
img_meta = batch_img_metas[img_id]
token_positive_maps = batch_token_positive_maps[img_id]
bbox_pred_list = select_single_mlvl(
bbox_preds, img_id, detach=True)
score_factor_list = select_single_mlvl(
score_factors, img_id, detach=True)
cls_logit_list = select_single_mlvl(
cls_logits, img_id, detach=True)
results = self._predict_by_feat_single(
bbox_pred_list=bbox_pred_list,
score_factor_list=score_factor_list,
cls_logit_list=cls_logit_list,
mlvl_priors=mlvl_priors,
token_positive_maps=token_positive_maps,
img_meta=img_meta,
cfg=cfg,
rescale=rescale,
with_nms=with_nms)
result_list.append(results)
return result_list
def _predict_by_feat_single(self,
bbox_pred_list: List[Tensor],
score_factor_list: List[Tensor],
cls_logit_list: List[Tensor],
mlvl_priors: List[Tensor],
token_positive_maps: dict,
img_meta: dict,
cfg: ConfigDict,
rescale: bool = True,
with_nms: bool = True) -> InstanceData:
"""Transform a single image's features extracted from the head into
bbox results.
Args:
bbox_pred_list (list[Tensor]): Box energies / deltas from
all scale levels of a single image, each item has shape
(num_priors * 4, H, W).
score_factor_list (list[Tensor]): Score factor from all scale
levels of a single image, each item has shape
(num_priors * 1, H, W).
cls_logit_list (list[Tensor]): Box scores from all scale
levels of a single image, each item has shape
(num_priors * num_classes, H, W).
mlvl_priors (list[Tensor]): Each element in the list is
the priors of a single level in feature pyramid. In all
anchor-based methods, it has shape (num_priors, 4). In
all anchor-free methods, it has shape (num_priors, 2)
when `with_stride=True`, otherwise it still has shape
(num_priors, 4).
token_positive_maps (dict): Token positive map.
img_meta (dict): Image meta info.
            cfg (mmengine.Config): Test / postprocessing configuration.
                If None, ``test_cfg`` is used.
            rescale (bool): If True, return boxes in original image space.
                Defaults to True.
with_nms (bool): If True, do nms before return boxes.
Defaults to True.
Returns:
:obj:`InstanceData`: Detection results of each image
after the post process.
Each item usually contains following keys.
            - scores (Tensor): Classification scores, has a shape
              (num_instances, ).
            - labels (Tensor): Labels of bboxes, has a shape
              (num_instances, ).
            - bboxes (Tensor): Has a shape (num_instances, 4),
              the last dimension 4 arranged as (x1, y1, x2, y2).
"""
cfg = self.test_cfg if cfg is None else cfg
cfg = copy.deepcopy(cfg)
img_shape = img_meta['img_shape']
nms_pre = cfg.get('nms_pre', -1)
score_thr = cfg.get('score_thr', 0)
mlvl_bbox_preds = []
mlvl_valid_priors = []
mlvl_scores = []
mlvl_labels = []
for level_idx, (bbox_pred, score_factor, cls_logit, priors) in \
enumerate(zip(bbox_pred_list,
score_factor_list, cls_logit_list, mlvl_priors)):
bbox_pred = bbox_pred.permute(1, 2, 0).reshape(
-1, self.bbox_coder.encode_size)
score_factor = score_factor.permute(1, 2, 0).reshape(-1).sigmoid()
scores = convert_grounding_to_cls_scores(
logits=cls_logit.sigmoid()[None],
positive_maps=[token_positive_maps])[0]
results = filter_scores_and_topk(
scores, score_thr, nms_pre,
dict(bbox_pred=bbox_pred, priors=priors))
scores, labels, keep_idxs, filtered_results = results
bbox_pred = filtered_results['bbox_pred']
priors = filtered_results['priors']
score_factor = score_factor[keep_idxs]
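            # Fuse classification score and centerness via a geometric mean.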
scores = torch.sqrt(scores * score_factor)
mlvl_bbox_preds.append(bbox_pred)
mlvl_valid_priors.append(priors)
mlvl_scores.append(scores)
mlvl_labels.append(labels)
bbox_pred = torch.cat(mlvl_bbox_preds)
priors = cat_boxes(mlvl_valid_priors)
bboxes = self.bbox_coder.decode(priors, bbox_pred, max_shape=img_shape)
results = InstanceData()
results.bboxes = bboxes
results.scores = torch.cat(mlvl_scores)
results.labels = torch.cat(mlvl_labels)
predictions = self._bbox_post_process(
results=results,
cfg=cfg,
rescale=rescale,
with_nms=with_nms,
img_meta=img_meta)
if len(predictions) > 0:
            # Note: GLIP uses an off-by-one bbox decoding convention; if 1 is
            # not added back here, the results will not align with the
            # official mAP.
predictions.bboxes[:, 2:] = predictions.bboxes[:, 2:] + 1
return predictions