|
|
|
from typing import Dict, List, Tuple |
|
|
|
import torch |
|
from mmengine.structures import InstanceData |
|
from torch import Tensor |
|
|
|
from mmdet.registry import MODELS |
|
from mmdet.structures import SampleList |
|
from mmdet.structures.bbox import bbox_cxcywh_to_xyxy, bbox_xyxy_to_cxcywh |
|
from mmdet.utils import InstanceList, OptInstanceList, reduce_mean |
|
from ..utils import multi_apply |
|
from .deformable_detr_head import DeformableDETRHead |
|
|
|
|
|
@MODELS.register_module() |
|
class DINOHead(DeformableDETRHead): |
|
r"""Head of the DINO: DETR with Improved DeNoising Anchor Boxes |
|
for End-to-End Object Detection |
|
|
|
Code is modified from the `official github repo |
|
<https://github.com/IDEA-Research/DINO>`_. |
|
|
|
More details can be found in the `paper |
|
<https://arxiv.org/abs/2203.03605>`_ . |
|
""" |
|
|
|
def loss(self, hidden_states: Tensor, references: List[Tensor],
         enc_outputs_class: Tensor, enc_outputs_coord: Tensor,
         batch_data_samples: SampleList, dn_meta: Dict[str, int]) -> dict:
    """Run the head on the decoder outputs and compute the training losses.

    Args:
        hidden_states (Tensor): Hidden states produced by every decoder
            layer, with shape (num_decoder_layers, bs, num_queries_total,
            dim). When `self.training` is `True`, `num_queries_total` is
            `num_denoising_queries` plus `num_matching_queries`; otherwise
            it is `num_matching_queries`.
        references (list[Tensor]): References coming from the decoder. The
            first item is the `init_reference` with shape
            (bs, num_queries_total, 4); the following num_decoder_layers
            items are the `inter_references`, each with shape
            (bs, num_queries, 4). The last dimension is ordered
            (cx, cy, w, h).
        enc_outputs_class (Tensor): Score of each point on the encoder
            feature map, with shape
            (bs, num_feat_points, cls_out_channels).
        enc_outputs_coord (Tensor): Proposals generated from the encoder
            feature map, with shape (bs, num_feat_points, 4) and the last
            dimension ordered (cx, cy, w, h).
        batch_data_samples (list[:obj:`DetDataSample`]): The data samples.
            Each one is expected to expose `metainfo` and `gt_instances`.
        dn_meta (Dict[str, int]): Information about the group collation,
            including 'num_denoising_queries' and 'num_denoising_groups'.
            It is used to split the outputs into denoising and matching
            parts and to compute the denoising loss.

    Returns:
        dict: A dictionary of loss components.
    """
    # Split every data sample into its meta information and its
    # ground-truth instances.
    batch_img_metas = [sample.metainfo for sample in batch_data_samples]
    batch_gt_instances = [
        sample.gt_instances for sample in batch_data_samples
    ]

    # Forward pass of the head; its outputs are followed by the encoder
    # outputs and the targets, in the positional order `loss_by_feat` takes.
    head_outputs = self(hidden_states, references)
    loss_inputs = head_outputs + (enc_outputs_class, enc_outputs_coord,
                                  batch_gt_instances, batch_img_metas,
                                  dn_meta)
    return self.loss_by_feat(*loss_inputs)
|
|
|
def loss_by_feat(
    self,
    all_layers_cls_scores: Tensor,
    all_layers_bbox_preds: Tensor,
    enc_cls_scores: Tensor,
    enc_bbox_preds: Tensor,
    batch_gt_instances: InstanceList,
    batch_img_metas: List[dict],
    dn_meta: Dict[str, int],
    batch_gt_instances_ignore: OptInstanceList = None
) -> Dict[str, Tensor]:
    """Compute matching, encoder and denoising losses from head outputs.

    Args:
        all_layers_cls_scores (Tensor): Classification scores of all
            decoder layers, with shape (num_decoder_layers, bs,
            num_queries_total, cls_out_channels), where
            `num_queries_total` is `num_denoising_queries` plus
            `num_matching_queries`.
        all_layers_bbox_preds (Tensor): Regression outputs of all decoder
            layers in normalized (cx, cy, w, h) format, with shape
            (num_decoder_layers, bs, num_queries_total, 4).
        enc_cls_scores (Tensor): Score of each point on the encoder feature
            map, with shape (bs, num_feat_points, cls_out_channels). The
            encoder losses are skipped when this is None.
        enc_bbox_preds (Tensor): Proposals generated from the encoder
            feature map, with shape (bs, num_feat_points, 4) and the last
            dimension ordered (cx, cy, w, h).
        batch_gt_instances (list[:obj:`InstanceData`]): Batch of
            gt_instance. It usually includes ``bboxes`` and ``labels``
            attributes.
        batch_img_metas (list[dict]): Meta information of each image,
            e.g., image size, scaling factor, etc.
        dn_meta (Dict[str, int]): Information about the group collation,
            including 'num_denoising_queries' and 'num_denoising_groups'.
            It is used to split the outputs into denoising and matching
            parts and to compute the denoising loss.
        batch_gt_instances_ignore (list[:obj:`InstanceData`], optional):
            Batch of gt_instances_ignore, forwarded unchanged to the
            parent loss computation. Defaults to None.

    Returns:
        dict[str, Tensor]: A dictionary of loss components. On top of the
        entries produced by the parent class it may contain
        'enc_loss_*' entries, 'dn_loss_*' entries for the last decoder
        layer and 'd{i}.dn_loss_*' entries for the earlier layers.
    """
    # Separate the denoising part of the outputs from the matching part.
    (matching_cls_scores, matching_bbox_preds, denoising_cls_scores,
     denoising_bbox_preds) = self.split_outputs(all_layers_cls_scores,
                                                all_layers_bbox_preds,
                                                dn_meta)

    # NOTE: `super(DeformableDETRHead, self)` starts the attribute lookup
    # *after* DeformableDETRHead in the MRO, so a `loss_by_feat` defined on
    # DeformableDETRHead itself is skipped and the next ancestor's
    # implementation handles the matching part.
    loss_dict = super(DeformableDETRHead, self).loss_by_feat(
        matching_cls_scores, matching_bbox_preds, batch_gt_instances,
        batch_img_metas, batch_gt_instances_ignore)

    # Loss of the proposals generated from the encoder feature map.
    if enc_cls_scores is not None:
        enc_cls, enc_bbox, enc_iou = self.loss_by_feat_single(
            enc_cls_scores,
            enc_bbox_preds,
            batch_gt_instances=batch_gt_instances,
            batch_img_metas=batch_img_metas)
        loss_dict['enc_loss_cls'] = enc_cls
        loss_dict['enc_loss_bbox'] = enc_bbox
        loss_dict['enc_loss_iou'] = enc_iou

    # Nothing more to add when the outputs carry no denoising part.
    if denoising_cls_scores is None:
        return loss_dict

    # Denoising losses, one entry per decoder layer.
    dn_cls_per_layer, dn_bbox_per_layer, dn_iou_per_layer = self.loss_dn(
        denoising_cls_scores,
        denoising_bbox_preds,
        batch_gt_instances=batch_gt_instances,
        batch_img_metas=batch_img_metas,
        dn_meta=dn_meta)

    # The last decoder layer is stored without a prefix ...
    loss_dict['dn_loss_cls'] = dn_cls_per_layer[-1]
    loss_dict['dn_loss_bbox'] = dn_bbox_per_layer[-1]
    loss_dict['dn_loss_iou'] = dn_iou_per_layer[-1]

    # ... and every earlier layer under a 'd{layer index}.' prefix.
    earlier_layers = zip(dn_cls_per_layer[:-1], dn_bbox_per_layer[:-1],
                         dn_iou_per_layer[:-1])
    for layer_idx, (cls_i, bbox_i, iou_i) in enumerate(earlier_layers):
        loss_dict[f'd{layer_idx}.dn_loss_cls'] = cls_i
        loss_dict[f'd{layer_idx}.dn_loss_bbox'] = bbox_i
        loss_dict[f'd{layer_idx}.dn_loss_iou'] = iou_i
    return loss_dict
|
|
|
def loss_dn(self, all_layers_denoising_cls_scores: Tensor,
            all_layers_denoising_bbox_preds: Tensor,
            batch_gt_instances: InstanceList, batch_img_metas: List[dict],
            dn_meta: Dict[str, int]) -> Tuple[List[Tensor]]:
    """Compute the denoising loss of every decoder layer.

    Args:
        all_layers_denoising_cls_scores (Tensor): Classification scores of
            the denoising part for all decoder layers, with shape
            (num_decoder_layers, bs, num_denoising_queries,
            cls_out_channels).
        all_layers_denoising_bbox_preds (Tensor): Regression outputs of the
            denoising part for all decoder layers in normalized
            (cx, cy, w, h) format, with shape
            (num_decoder_layers, bs, num_denoising_queries, 4).
        batch_gt_instances (list[:obj:`InstanceData`]): Batch of
            gt_instance. It usually includes ``bboxes`` and ``labels``
            attributes.
        batch_img_metas (list[dict]): Meta information of each image,
            e.g., image size, scaling factor, etc.
        dn_meta (Dict[str, int]): Information about the group collation,
            including 'num_denoising_queries' and 'num_denoising_groups'.

    Returns:
        Tuple[List[Tensor]]: The loss_dn_cls, loss_dn_bbox and loss_dn_iou
        of each decoder layer.
    """
    # Targets and metadata are identical for all layers, so they are handed
    # to `multi_apply` as keyword arguments, while the two score/prediction
    # tensors are the per-layer positional inputs of `_loss_dn_single`.
    shared_kwargs = dict(
        batch_gt_instances=batch_gt_instances,
        batch_img_metas=batch_img_metas,
        dn_meta=dn_meta)
    per_layer_losses = multi_apply(self._loss_dn_single,
                                   all_layers_denoising_cls_scores,
                                   all_layers_denoising_bbox_preds,
                                   **shared_kwargs)
    return per_layer_losses
|
|
|
def _loss_dn_single(self, dn_cls_scores: Tensor, dn_bbox_preds: Tensor,
                    batch_gt_instances: InstanceList,
                    batch_img_metas: List[dict],
                    dn_meta: Dict[str, int]) -> Tuple[Tensor]:
    """Compute the denoising loss for one decoder layer.

    Args:
        dn_cls_scores (Tensor): Classification scores of the denoising part
            of a single decoder layer, with shape
            (bs, num_denoising_queries, cls_out_channels).
        dn_bbox_preds (Tensor): Regression outputs of the denoising part of
            a single decoder layer in normalized (cx, cy, w, h) format,
            with shape (bs, num_denoising_queries, 4).
        batch_gt_instances (list[:obj:`InstanceData`]): Batch of
            gt_instance. It usually includes ``bboxes`` and ``labels``
            attributes.
        batch_img_metas (list[dict]): Meta information of each image. The
            'img_shape' entry is read as (height, width).
        dn_meta (Dict[str, int]): Information about the group collation,
            including 'num_denoising_queries' and 'num_denoising_groups'.

    Returns:
        Tuple[Tensor]: A tuple including `loss_cls`, `loss_box` and
        `loss_iou`.
    """
    (labels_list, label_weights_list, bbox_targets_list, bbox_weights_list,
     num_total_pos, num_total_neg) = self.get_dn_targets(
         batch_gt_instances, batch_img_metas, dn_meta)

    # Concatenate the per-image targets along the query dimension.
    labels, label_weights, bbox_targets, bbox_weights = (
        torch.cat(per_image, 0)
        for per_image in (labels_list, label_weights_list,
                          bbox_targets_list, bbox_weights_list))

    # ---- classification loss ----
    flat_cls_scores = dn_cls_scores.reshape(-1, self.cls_out_channels)

    # Weighted average factor: positives count fully, negatives are scaled
    # by `bg_cls_weight`; optionally averaged via `reduce_mean` and never
    # allowed to drop below 1.
    cls_avg_factor = \
        num_total_pos * 1.0 + num_total_neg * self.bg_cls_weight
    if self.sync_cls_avg_factor:
        cls_avg_factor = reduce_mean(
            flat_cls_scores.new_tensor([cls_avg_factor]))
    cls_avg_factor = max(cls_avg_factor, 1)

    if len(flat_cls_scores) == 0:
        # No denoising query at all: fall back to a zero loss.
        loss_cls = torch.zeros(
            1, dtype=flat_cls_scores.dtype, device=flat_cls_scores.device)
    else:
        loss_cls = self.loss_cls(
            flat_cls_scores,
            labels,
            label_weights,
            avg_factor=cls_avg_factor)

    # Number of positives, passed through `reduce_mean` and clamped to at
    # least 1, used to normalize the two regression losses.
    pos_count = loss_cls.new_tensor([num_total_pos])
    num_total_pos = torch.clamp(reduce_mean(pos_count), min=1).item()

    # ---- regression losses ----
    # One (w, h, w, h) scale row per query of every image.
    per_image_factors = []
    for img_meta, bbox_pred in zip(batch_img_metas, dn_bbox_preds):
        img_h, img_w = img_meta['img_shape']
        scale = bbox_pred.new_tensor([img_w, img_h, img_w, img_h])
        per_image_factors.append(
            scale.unsqueeze(0).repeat(bbox_pred.size(0), 1))
    factors = torch.cat(per_image_factors)

    # Predictions and targets are normalized (cx, cy, w, h); convert them
    # to xyxy and rescale by the image size for the IoU loss.
    bbox_preds = dn_bbox_preds.reshape(-1, 4)
    scaled_pred_boxes = bbox_cxcywh_to_xyxy(bbox_preds) * factors
    scaled_gt_boxes = bbox_cxcywh_to_xyxy(bbox_targets) * factors

    # IoU-based regression loss on the rescaled xyxy boxes.
    loss_iou = self.loss_iou(
        scaled_pred_boxes,
        scaled_gt_boxes,
        bbox_weights,
        avg_factor=num_total_pos)

    # Regression loss applied directly to the normalized cxcywh values.
    loss_bbox = self.loss_bbox(
        bbox_preds, bbox_targets, bbox_weights, avg_factor=num_total_pos)
    return loss_cls, loss_bbox, loss_iou
|
|
|
def get_dn_targets(self, batch_gt_instances: InstanceList,
                   batch_img_metas: dict, dn_meta: Dict[str, int]) -> tuple:
    """Build the denoising targets for a batch of images.

    Args:
        batch_gt_instances (list[:obj:`InstanceData`]): Batch of
            gt_instance. It usually includes ``bboxes`` and ``labels``
            attributes.
        batch_img_metas (list[dict]): Meta information of each image,
            e.g., image size, scaling factor, etc.
        dn_meta (Dict[str, int]): Information about the group collation,
            including 'num_denoising_queries' and 'num_denoising_groups'.

    Returns:
        tuple: a tuple containing the following targets.

        - labels_list (list[Tensor]): Labels for all images.
        - label_weights_list (list[Tensor]): Label weights for all images.
        - bbox_targets_list (list[Tensor]): BBox targets for all images.
        - bbox_weights_list (list[Tensor]): BBox weights for all images.
        - num_total_pos (int): Number of positive samples in all images.
        - num_total_neg (int): Number of negative samples in all images.
    """
    # Compute the targets image by image; `dn_meta` is shared by all images.
    per_image_targets = multi_apply(
        self._get_dn_targets_single,
        batch_gt_instances,
        batch_img_metas,
        dn_meta=dn_meta)
    (labels_list, label_weights_list, bbox_targets_list, bbox_weights_list,
     pos_inds_list, neg_inds_list) = per_image_targets

    # Only the sample counts are needed downstream, not the indices.
    num_total_pos = sum(inds.numel() for inds in pos_inds_list)
    num_total_neg = sum(inds.numel() for inds in neg_inds_list)
    return (labels_list, label_weights_list, bbox_targets_list,
            bbox_weights_list, num_total_pos, num_total_neg)
|
|
|
def _get_dn_targets_single(self, gt_instances: InstanceData,
                           img_meta: dict, dn_meta: Dict[str, int]) -> tuple:
    """Build the denoising targets for a single image.

    Args:
        gt_instances (:obj:`InstanceData`): Ground truth of instance
            annotations. It should include ``bboxes`` and ``labels``
            attributes.
        img_meta (dict): Meta information for one image. The 'img_shape'
            entry is read as (height, width).
        dn_meta (Dict[str, int]): Information about the group collation,
            including 'num_denoising_queries' and 'num_denoising_groups'.

    Returns:
        tuple[Tensor]: a tuple containing the following for one image.

        - labels (Tensor): Labels of each image.
        - label_weights (Tensor): Label weights of each image.
        - bbox_targets (Tensor): BBox targets of each image.
        - bbox_weights (Tensor): BBox weights of each image.
        - pos_inds (Tensor): Sampled positive indices for each image.
        - neg_inds (Tensor): Sampled negative indices for each image.
    """
    gt_bboxes, gt_labels = gt_instances.bboxes, gt_instances.labels
    num_groups = dn_meta['num_denoising_groups']
    num_denoising_queries = dn_meta['num_denoising_queries']
    num_queries_each_group = int(num_denoising_queries / num_groups)
    device = gt_bboxes.device
    num_gts = len(gt_labels)

    if num_gts == 0:
        # No ground truth: both index tensors are the same empty tensor.
        pos_inds = pos_assigned_gt_inds = \
            gt_bboxes.new_tensor([], dtype=torch.long)
    else:
        # Ground-truth indices 0..num_gts-1, repeated once per group.
        gt_inds = torch.arange(num_gts, dtype=torch.long, device=device)
        gt_inds = gt_inds.unsqueeze(0).repeat(num_groups, 1)
        pos_assigned_gt_inds = gt_inds.flatten()
        # Each group starts at `group index * num_queries_each_group`; the
        # positive query of a ground truth sits at that start plus the
        # ground-truth index.
        group_starts = torch.arange(
            num_groups, dtype=torch.long, device=device).unsqueeze(1)
        group_starts = group_starts * num_queries_each_group
        pos_inds = (group_starts + gt_inds).flatten()

    # The negative counterpart of a positive query lies half a group later.
    neg_inds = pos_inds + num_queries_each_group // 2

    # Label targets: `self.num_classes` everywhere except at the positive
    # queries, which receive the label of their ground truth.
    labels = gt_bboxes.new_full((num_denoising_queries, ),
                                self.num_classes,
                                dtype=torch.long)
    labels[pos_inds] = gt_labels[pos_assigned_gt_inds]
    label_weights = gt_bboxes.new_ones(num_denoising_queries)

    # BBox targets: only the positive queries carry a regression weight.
    bbox_targets = torch.zeros(num_denoising_queries, 4, device=device)
    bbox_weights = torch.zeros(num_denoising_queries, 4, device=device)
    bbox_weights[pos_inds] = 1.0

    # Normalize the ground-truth boxes by the image size and convert them
    # from xyxy to cxcywh, then tile them once per denoising group.
    img_h, img_w = img_meta['img_shape']
    scale = gt_bboxes.new_tensor([img_w, img_h, img_w, img_h]).unsqueeze(0)
    normalized_gt_cxcywh = bbox_xyxy_to_cxcywh(gt_bboxes / scale)
    bbox_targets[pos_inds] = normalized_gt_cxcywh.repeat([num_groups, 1])

    return (labels, label_weights, bbox_targets, bbox_weights, pos_inds,
            neg_inds)
|
|
|
@staticmethod |
|
def split_outputs(all_layers_cls_scores: Tensor,
                  all_layers_bbox_preds: Tensor,
                  dn_meta: Dict[str, int]) -> Tuple[Tensor]:
    """Split outputs of the denoising part and the matching part.

    For the total outputs of `num_queries_total` length, the former
    `num_denoising_queries` outputs are from denoising queries, and
    the rest `num_matching_queries` ones are from matching queries,
    where `num_queries_total` is the sum of `num_denoising_queries` and
    `num_matching_queries`.

    Args:
        all_layers_cls_scores (Tensor): Classification scores of all
            decoder layers, has shape (num_decoder_layers, bs,
            num_queries_total, cls_out_channels).
        all_layers_bbox_preds (Tensor): Regression outputs of all decoder
            layers. Each is a 4D-tensor with normalized coordinate format
            (cx, cy, w, h) and has shape (num_decoder_layers, bs,
            num_queries_total, 4).
        dn_meta (Dict[str, int]): The dictionary saves information about
            group collation, including 'num_denoising_queries' and
            'num_denoising_groups'. May be None, in which case the outputs
            are treated as containing matching queries only.

    Returns:
        Tuple[Tensor]: a tuple containing the following outputs.

        - all_layers_matching_cls_scores (Tensor): Classification scores
          of all decoder layers in matching part, has shape
          (num_decoder_layers, bs, num_matching_queries, cls_out_channels).
        - all_layers_matching_bbox_preds (Tensor): Regression outputs of
          all decoder layers in matching part. Each is a 4D-tensor with
          normalized coordinate format (cx, cy, w, h) and has shape
          (num_decoder_layers, bs, num_matching_queries, 4).
        - all_layers_denoising_cls_scores (Tensor): Classification scores
          of all decoder layers in denoising part, has shape
          (num_decoder_layers, bs, num_denoising_queries,
          cls_out_channels). None when `dn_meta` is None.
        - all_layers_denoising_bbox_preds (Tensor): Regression outputs of
          all decoder layers in denoising part. Each is a 4D-tensor with
          normalized coordinate format (cx, cy, w, h) and has shape
          (num_decoder_layers, bs, num_denoising_queries, 4). None when
          `dn_meta` is None.
    """
    # Without denoising metadata there is nothing to split off: everything
    # belongs to the matching part. This check must come before any lookup
    # in `dn_meta`, otherwise a None value raises TypeError.
    if dn_meta is None:
        return (all_layers_cls_scores, all_layers_bbox_preds, None, None)

    num_denoising_queries = dn_meta['num_denoising_queries']

    # The denoising queries occupy the leading slots of the query dimension.
    all_layers_denoising_cls_scores = \
        all_layers_cls_scores[:, :, :num_denoising_queries, :]
    all_layers_denoising_bbox_preds = \
        all_layers_bbox_preds[:, :, :num_denoising_queries, :]
    # The matching queries are whatever follows them.
    all_layers_matching_cls_scores = \
        all_layers_cls_scores[:, :, num_denoising_queries:, :]
    all_layers_matching_bbox_preds = \
        all_layers_bbox_preds[:, :, num_denoising_queries:, :]
    return (all_layers_matching_cls_scores, all_layers_matching_bbox_preds,
            all_layers_denoising_cls_scores,
            all_layers_denoising_bbox_preds)
|
|