import copy
from typing import Dict, List, Tuple

import torch
import torch.nn as nn
from mmcv.cnn import Linear
from mmengine.model import bias_init_with_prob, constant_init
from torch import Tensor

from mmdet.registry import MODELS
from mmdet.structures import SampleList
from mmdet.utils import InstanceList, OptInstanceList
from ..layers import inverse_sigmoid
from .detr_head import DETRHead


@MODELS.register_module()
class DeformableDETRHead(DETRHead):
    r"""Head of DeformDETR: Deformable DETR: Deformable Transformers for
    End-to-End Object Detection.

    Code is modified from the `official github repo
    <https://github.com/fundamentalvision/Deformable-DETR>`_.

    More details can be found in the `paper
    <https://arxiv.org/abs/2010.04159>`_ .

    Args:
        share_pred_layer (bool): Whether to share parameters for all the
            prediction layers. Defaults to `False`.
        num_pred_layer (int): The number of the prediction layers.
            Defaults to 6.
        as_two_stage (bool, optional): Whether to generate proposals
            from the outputs of the encoder. Defaults to `False`.
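
    Note:
        In the reference two-stage configuration these fields are usually
        filled in by the detector rather than written by hand:
        `share_pred_layer` is disabled when box refinement is on, and one
        extra prediction layer is appended for the encoder proposals. A
        hedged config sketch (values illustrative, not from this file)::

            bbox_head = dict(
                type='DeformableDETRHead',
                num_classes=80,         # assumed dataset setting
                share_pred_layer=False,
                num_pred_layer=7,       # 6 decoder layers + 1 proposal head
                as_two_stage=True)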
    """

    def __init__(self,
                 *args,
                 share_pred_layer: bool = False,
                 num_pred_layer: int = 6,
                 as_two_stage: bool = False,
                 **kwargs) -> None:
        self.share_pred_layer = share_pred_layer
        self.num_pred_layer = num_pred_layer
        self.as_two_stage = as_two_stage

        super().__init__(*args, **kwargs)

    def _init_layers(self) -> None:
        """Initialize classification branch and regression branch of head.
        """
        fc_cls = Linear(self.embed_dims, self.cls_out_channels)
        reg_branch = []
        for _ in range(self.num_reg_fcs):
            reg_branch.append(Linear(self.embed_dims, self.embed_dims))
            reg_branch.append(nn.ReLU())
        reg_branch.append(Linear(self.embed_dims, 4))
        reg_branch = nn.Sequential(*reg_branch)

        if self.share_pred_layer:
            self.cls_branches = nn.ModuleList(
                [fc_cls for _ in range(self.num_pred_layer)])
            self.reg_branches = nn.ModuleList(
                [reg_branch for _ in range(self.num_pred_layer)])
        else:
            self.cls_branches = nn.ModuleList(
                [copy.deepcopy(fc_cls) for _ in range(self.num_pred_layer)])
            self.reg_branches = nn.ModuleList([
                copy.deepcopy(reg_branch) for _ in range(self.num_pred_layer)
            ])

    def init_weights(self) -> None:
        """Initialize weights of the Deformable DETR head.
        """
        if self.loss_cls.use_sigmoid:
            bias_init = bias_init_with_prob(0.01)
            for m in self.cls_branches:
                nn.init.constant_(m.bias, bias_init)
        for m in self.reg_branches:
            constant_init(m[-1], 0, bias=0)
        nn.init.constant_(self.reg_branches[0][-1].bias.data[2:], -2.0)
        if self.as_two_stage:
            for m in self.reg_branches:
                nn.init.constant_(m[-1].bias.data[2:], 0.0)

    def forward(self, hidden_states: Tensor,
                references: List[Tensor]) -> Tuple[Tensor]:
        """Forward function.

        Args:
            hidden_states (Tensor): Hidden states output from each decoder
                layer, has shape (num_decoder_layers, bs, num_queries, dim).
            references (list[Tensor]): List of references from the decoder.
                The first reference is the `init_reference` (initial) and
                the other num_decoder_layers references are
                `inter_references` (intermediate). The `init_reference` has
                shape (bs, num_queries, 4) when `as_two_stage` of the
                detector is `True`, otherwise (bs, num_queries, 2). Each
                `inter_reference` has shape (bs, num_queries, 4) when
                `with_box_refine` of the detector is `True`, otherwise
                (bs, num_queries, 2). The coordinates are arranged as
                (cx, cy) when the last dimension is 2, and (cx, cy, w, h)
                when it is 4.

        Returns:
            tuple[Tensor]: Results of the head containing the following
            tensors.

            - all_layers_outputs_classes (Tensor): Outputs from the
              classification head, has shape (num_decoder_layers, bs,
              num_queries, cls_out_channels).
            - all_layers_outputs_coords (Tensor): Sigmoid outputs from the
              regression head with normalized coordinate format
              (cx, cy, w, h), has shape (num_decoder_layers, bs,
              num_queries, 4).
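
        Each layer refines its reference by adding the raw regression
        output in inverse-sigmoid (logit) space and squashing back, which
        keeps coordinates inside (0, 1). A minimal sketch with
        illustrative tensors (pure ``torch``, independent of this head's
        config):

        Example:
            >>> import torch
            >>> from mmdet.models.layers import inverse_sigmoid
            >>> ref = torch.full((2, 100, 4), 0.5)  # reference boxes
            >>> delta = torch.zeros(2, 100, 4)      # regression output
            >>> out = (inverse_sigmoid(ref) + delta).sigmoid()
            >>> torch.allclose(out, ref, atol=1e-3)
            True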
        """
        all_layers_outputs_classes = []
        all_layers_outputs_coords = []

        for layer_id in range(hidden_states.shape[0]):
            reference = inverse_sigmoid(references[layer_id])

            hidden_state = hidden_states[layer_id]
            outputs_class = self.cls_branches[layer_id](hidden_state)
            tmp_reg_preds = self.reg_branches[layer_id](hidden_state)
            if reference.shape[-1] == 4:
                # When `layer_id == 0` and `as_two_stage` of the detector
                # is `True`, or when `layer_id > 0` and `with_box_refine`
                # of the detector is `True`: refine the full box.
                tmp_reg_preds += reference
            else:
                # Otherwise the reference only carries the center point,
                # so only (cx, cy) are refined.
                assert reference.shape[-1] == 2
                tmp_reg_preds[..., :2] += reference
            outputs_coord = tmp_reg_preds.sigmoid()
            all_layers_outputs_classes.append(outputs_class)
            all_layers_outputs_coords.append(outputs_coord)

        all_layers_outputs_classes = torch.stack(all_layers_outputs_classes)
        all_layers_outputs_coords = torch.stack(all_layers_outputs_coords)

        return all_layers_outputs_classes, all_layers_outputs_coords

    def loss(self, hidden_states: Tensor, references: List[Tensor],
             enc_outputs_class: Tensor, enc_outputs_coord: Tensor,
             batch_data_samples: SampleList) -> dict:
        """Perform forward propagation and loss calculation of the detection
        head on the queries of the upstream network.

        Args:
            hidden_states (Tensor): Hidden states output from each decoder
                layer, has shape (num_decoder_layers, bs, num_queries, dim).
            references (list[Tensor]): List of references from the decoder.
                The first reference is the `init_reference` (initial) and
                the other num_decoder_layers references are
                `inter_references` (intermediate). The `init_reference` has
                shape (bs, num_queries, 4) when `as_two_stage` of the
                detector is `True`, otherwise (bs, num_queries, 2). Each
                `inter_reference` has shape (bs, num_queries, 4) when
                `with_box_refine` of the detector is `True`, otherwise
                (bs, num_queries, 2). The coordinates are arranged as
                (cx, cy) when the last dimension is 2, and (cx, cy, w, h)
                when it is 4.
            enc_outputs_class (Tensor): The score of each point on the
                encoder feature map, has shape (bs, num_feat_points,
                cls_out_channels). It is only passed in when `as_two_stage`
                is `True`; otherwise it is `None`.
            enc_outputs_coord (Tensor): The proposals generated from the
                encoder feature map, has shape (bs, num_feat_points, 4)
                with the last dimension arranged as (cx, cy, w, h). It is
                only passed in when `as_two_stage` is `True`; otherwise it
                is `None`.
            batch_data_samples (list[:obj:`DetDataSample`]): The Data
                Samples. It usually includes information such as
                `gt_instance`, `gt_panoptic_seg` and `gt_sem_seg`.

        Returns:
            dict: A dictionary of loss components.
        """
        batch_gt_instances = []
        batch_img_metas = []
        for data_sample in batch_data_samples:
            batch_img_metas.append(data_sample.metainfo)
            batch_gt_instances.append(data_sample.gt_instances)

        outs = self(hidden_states, references)
        loss_inputs = outs + (enc_outputs_class, enc_outputs_coord,
                              batch_gt_instances, batch_img_metas)
        losses = self.loss_by_feat(*loss_inputs)
        return losses

    def loss_by_feat(
            self,
            all_layers_cls_scores: Tensor,
            all_layers_bbox_preds: Tensor,
            enc_cls_scores: Tensor,
            enc_bbox_preds: Tensor,
            batch_gt_instances: InstanceList,
            batch_img_metas: List[dict],
            batch_gt_instances_ignore: OptInstanceList = None
    ) -> Dict[str, Tensor]:
        """Loss function.

        Args:
            all_layers_cls_scores (Tensor): Classification scores of all
                decoder layers, has shape (num_decoder_layers, bs,
                num_queries, cls_out_channels).
            all_layers_bbox_preds (Tensor): Regression outputs of all
                decoder layers, has shape (num_decoder_layers, bs,
                num_queries, 4) with the last dimension in normalized
                coordinate format (cx, cy, w, h).
            enc_cls_scores (Tensor): The score of each point on the encoder
                feature map, has shape (bs, num_feat_points,
                cls_out_channels). It is only passed in when `as_two_stage`
                is `True`; otherwise it is `None`.
            enc_bbox_preds (Tensor): The proposals generated from the
                encoder feature map, has shape (bs, num_feat_points, 4)
                with the last dimension arranged as (cx, cy, w, h). It is
                only passed in when `as_two_stage` is `True`; otherwise it
                is `None`.
            batch_gt_instances (list[:obj:`InstanceData`]): Batch of
                gt_instance. It usually includes ``bboxes`` and ``labels``
                attributes.
            batch_img_metas (list[dict]): Meta information of each image,
                e.g., image size, scaling factor, etc.
            batch_gt_instances_ignore (list[:obj:`InstanceData`], optional):
                Batch of gt_instances_ignore. It includes ``bboxes``
                attribute data that is ignored during training and testing.
                Defaults to None.

        Returns:
            dict[str, Tensor]: A dictionary of loss components.
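
        For the encoder proposal loss, all ground-truth labels are reset
        to class 0, which turns classification into a binary
        foreground-only problem. A minimal sketch of that reset (pure
        ``torch``, labels illustrative):

        Example:
            >>> import torch
            >>> labels = torch.tensor([3, 17, 42])
            >>> torch.zeros_like(labels)
            tensor([0, 0, 0])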
        """
        loss_dict = super().loss_by_feat(all_layers_cls_scores,
                                         all_layers_bbox_preds,
                                         batch_gt_instances, batch_img_metas,
                                         batch_gt_instances_ignore)

        # Loss of the proposals generated from the encoder feature map,
        # computed with class-agnostic (all-zero) labels.
        if enc_cls_scores is not None:
            proposal_gt_instances = copy.deepcopy(batch_gt_instances)
            for i in range(len(proposal_gt_instances)):
                proposal_gt_instances[i].labels = torch.zeros_like(
                    proposal_gt_instances[i].labels)
            enc_loss_cls, enc_losses_bbox, enc_losses_iou = \
                self.loss_by_feat_single(
                    enc_cls_scores, enc_bbox_preds,
                    batch_gt_instances=proposal_gt_instances,
                    batch_img_metas=batch_img_metas)
            loss_dict['enc_loss_cls'] = enc_loss_cls
            loss_dict['enc_loss_bbox'] = enc_losses_bbox
            loss_dict['enc_loss_iou'] = enc_losses_iou
        return loss_dict

    def predict(self,
                hidden_states: Tensor,
                references: List[Tensor],
                batch_data_samples: SampleList,
                rescale: bool = True) -> InstanceList:
        """Perform forward propagation of the detection head and predict
        detection results on the queries of the upstream network.

        Args:
            hidden_states (Tensor): Hidden states output from each decoder
                layer, has shape (num_decoder_layers, bs, num_queries, dim).
            references (list[Tensor]): List of references from the decoder.
                The first reference is the `init_reference` (initial) and
                the other num_decoder_layers references are
                `inter_references` (intermediate). The `init_reference` has
                shape (bs, num_queries, 4) when `as_two_stage` of the
                detector is `True`, otherwise (bs, num_queries, 2). Each
                `inter_reference` has shape (bs, num_queries, 4) when
                `with_box_refine` of the detector is `True`, otherwise
                (bs, num_queries, 2). The coordinates are arranged as
                (cx, cy) when the last dimension is 2, and (cx, cy, w, h)
                when it is 4.
            batch_data_samples (list[:obj:`DetDataSample`]): The Data
                Samples. It usually includes information such as
                `gt_instance`, `gt_panoptic_seg` and `gt_sem_seg`.
            rescale (bool, optional): If `True`, return boxes in original
                image space. Defaults to `True`.

        Returns:
            list[:obj:`InstanceData`]: Detection results of each image
            after the post process.
        """
        batch_img_metas = [
            data_samples.metainfo for data_samples in batch_data_samples
        ]

        outs = self(hidden_states, references)

        predictions = self.predict_by_feat(
            *outs, batch_img_metas=batch_img_metas, rescale=rescale)
        return predictions

    def predict_by_feat(self,
                        all_layers_cls_scores: Tensor,
                        all_layers_bbox_preds: Tensor,
                        batch_img_metas: List[Dict],
                        rescale: bool = False) -> InstanceList:
        """Transform a batch of output features extracted from the head into
        bbox results.

        Args:
            all_layers_cls_scores (Tensor): Classification scores of all
                decoder layers, has shape (num_decoder_layers, bs,
                num_queries, cls_out_channels).
            all_layers_bbox_preds (Tensor): Regression outputs of all
                decoder layers, has shape (num_decoder_layers, bs,
                num_queries, 4) with the last dimension in normalized
                coordinate format (cx, cy, w, h).
            batch_img_metas (list[dict]): Meta information of each image.
            rescale (bool, optional): If `True`, return boxes in original
                image space. Defaults to `False`.

        Returns:
            list[:obj:`InstanceData`]: Detection results of each image
            after the post process.
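
        Only the last decoder layer's outputs are decoded here; with
        auxiliary losses on the earlier layers, the final layer is the
        most refined. A minimal sketch of the layer selection (pure
        ``torch``, shapes illustrative):

        Example:
            >>> import torch
            >>> all_layers_cls_scores = torch.rand(6, 2, 100, 80)
            >>> cls_scores = all_layers_cls_scores[-1]
            >>> tuple(cls_scores.shape)
            (2, 100, 80)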
        """
        cls_scores = all_layers_cls_scores[-1]
        bbox_preds = all_layers_bbox_preds[-1]

        result_list = []
        for img_id in range(len(batch_img_metas)):
            cls_score = cls_scores[img_id]
            bbox_pred = bbox_preds[img_id]
            img_meta = batch_img_metas[img_id]
            results = self._predict_by_feat_single(cls_score, bbox_pred,
                                                   img_meta, rescale)
            result_list.append(results)
        return result_list