import warnings
from typing import List, Optional, Tuple, Union

import torch
import torch.nn as nn
from mmengine.structures import InstanceData
from torch import Tensor

from mmdet.registry import MODELS, TASK_UTILS
from mmdet.structures.bbox import BaseBoxes, cat_boxes, get_box_tensor
from mmdet.utils import (ConfigType, InstanceList, OptConfigType,
                         OptInstanceList, OptMultiConfig)
from ..task_modules.prior_generators import (AnchorGenerator,
                                             anchor_inside_flags)
from ..task_modules.samplers import PseudoSampler
from ..utils import images_to_levels, multi_apply, unmap
from .base_dense_head import BaseDenseHead


@MODELS.register_module()
class AnchorHead(BaseDenseHead):
    """Anchor-based head (RPN, RetinaNet, SSD, etc.).

    Args:
        num_classes (int): Number of categories excluding the background
            category.
        in_channels (int): Number of channels in the input feature map.
        feat_channels (int): Number of hidden channels. Used in child classes.
        anchor_generator (dict): Config dict for the anchor generator.
        bbox_coder (dict): Config of bounding box coder.
        reg_decoded_bbox (bool): If True, the regression loss is applied
            directly on decoded bounding boxes, converting both the
            predicted boxes and regression targets to absolute coordinate
            format. Defaults to False. It should be ``True`` when using
            ``IoULoss``, ``GIoULoss``, or ``DIoULoss`` in the bbox head.
        loss_cls (dict): Config of classification loss.
        loss_bbox (dict): Config of localization loss.
        train_cfg (dict): Training config of the anchor head.
        test_cfg (dict): Testing config of the anchor head.
        init_cfg (dict or list[dict], optional): Initialization config dict.
    """

    def __init__(
        self,
        num_classes: int,
        in_channels: int,
        feat_channels: int = 256,
        anchor_generator: ConfigType = dict(
            type='AnchorGenerator',
            scales=[8, 16, 32],
            ratios=[0.5, 1.0, 2.0],
            strides=[4, 8, 16, 32, 64]),
        bbox_coder: ConfigType = dict(
            type='DeltaXYWHBBoxCoder',
            clip_border=True,
            target_means=(.0, .0, .0, .0),
            target_stds=(1.0, 1.0, 1.0, 1.0)),
        reg_decoded_bbox: bool = False,
        loss_cls: ConfigType = dict(
            type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0),
        loss_bbox: ConfigType = dict(
            type='SmoothL1Loss', beta=1.0 / 9.0, loss_weight=1.0),
        train_cfg: OptConfigType = None,
        test_cfg: OptConfigType = None,
        init_cfg: OptMultiConfig = dict(
            type='Normal', layer='Conv2d', std=0.01)
    ) -> None:
        super().__init__(init_cfg=init_cfg)
        self.in_channels = in_channels
        self.num_classes = num_classes
        self.feat_channels = feat_channels
        self.use_sigmoid_cls = loss_cls.get('use_sigmoid', False)
        if self.use_sigmoid_cls:
            self.cls_out_channels = num_classes
        else:
            self.cls_out_channels = num_classes + 1

        if self.cls_out_channels <= 0:
            raise ValueError(f'num_classes={num_classes} is too small')
        self.reg_decoded_bbox = reg_decoded_bbox

        self.bbox_coder = TASK_UTILS.build(bbox_coder)
        self.loss_cls = MODELS.build(loss_cls)
        self.loss_bbox = MODELS.build(loss_bbox)
        self.train_cfg = train_cfg
        self.test_cfg = test_cfg
        if self.train_cfg:
            self.assigner = TASK_UTILS.build(self.train_cfg['assigner'])
            if self.train_cfg.get('sampler', None) is not None:
                self.sampler = TASK_UTILS.build(
                    self.train_cfg['sampler'],
                    default_args=dict(context=self))
            else:
                self.sampler = PseudoSampler(context=self)

        self.fp16_enabled = False

        self.prior_generator = TASK_UTILS.build(anchor_generator)

        # Usually the numbers of anchors for each level are the same except
        # for SSD detectors, so `num_base_priors` is an int in most dense
        # heads but a list of ints in SSDHead.
        self.num_base_priors = self.prior_generator.num_base_priors[0]
        self._init_layers()

    @property
    def num_anchors(self) -> int:
        warnings.warn('DeprecationWarning: `num_anchors` is deprecated, '
                      'please use `num_base_priors` instead')
        return self.prior_generator.num_base_priors[0]

    @property
    def anchor_generator(self) -> AnchorGenerator:
        warnings.warn('DeprecationWarning: `anchor_generator` is deprecated, '
                      'please use `prior_generator` instead')
        return self.prior_generator

    def _init_layers(self) -> None:
        """Initialize layers of the head."""
        self.conv_cls = nn.Conv2d(self.in_channels,
                                  self.num_base_priors * self.cls_out_channels,
                                  1)
        reg_dim = self.bbox_coder.encode_size
        self.conv_reg = nn.Conv2d(self.in_channels,
                                  self.num_base_priors * reg_dim, 1)
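
    # A worked sketch of the channel arithmetic above (hedged; the numbers
    # follow from this class's defaults, not from any particular config):
    # the default anchor generator places 3 scales x 3 ratios = 9 base
    # priors per location, so with num_classes=80 and a sigmoid classifier,
    #
    #   head = AnchorHead(num_classes=80, in_channels=256)
    #   head.conv_cls  # Conv2d(256, 9 * 80 = 720, kernel_size=1)
    #   head.conv_reg  # Conv2d(256, 9 * 4 = 36, kernel_size=1)
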
    def forward_single(self, x: Tensor) -> Tuple[Tensor, Tensor]:
        """Forward feature of a single scale level.

        Args:
            x (Tensor): Features of a single scale level.

        Returns:
            tuple:
                cls_score (Tensor): Classification scores for a single scale \
                    level, with num_base_priors * num_classes channels.
                bbox_pred (Tensor): Box energies / deltas for a single scale \
                    level, with num_base_priors * 4 channels.
        """
        cls_score = self.conv_cls(x)
        bbox_pred = self.conv_reg(x)
        return cls_score, bbox_pred

    def forward(self, x: Tuple[Tensor]) -> Tuple[List[Tensor]]:
        """Forward features from the upstream network.

        Args:
            x (tuple[Tensor]): Features from the upstream network, each a
                4D tensor.

        Returns:
            tuple: A tuple of classification scores and bbox predictions.

            - cls_scores (list[Tensor]): Classification scores for all \
                scale levels, each a 4D tensor with \
                num_base_priors * num_classes channels.
            - bbox_preds (list[Tensor]): Box energies / deltas for all \
                scale levels, each a 4D tensor with \
                num_base_priors * 4 channels.
        """
        return multi_apply(self.forward_single, x)
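
    # A minimal forward sketch (hedged; assumes an environment where the
    # mmdet registries resolve, and FPN-like 256-channel features):
    #
    #   head = AnchorHead(num_classes=80, in_channels=256)
    #   feats = tuple(torch.rand(2, 256, s, s) for s in (64, 32, 16, 8, 4))
    #   cls_scores, bbox_preds = head(feats)
    #   cls_scores[0].shape  # (2, 720, 64, 64) with the default 9 priors
    #   bbox_preds[0].shape  # (2, 36, 64, 64)
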
    def get_anchors(self,
                    featmap_sizes: List[tuple],
                    batch_img_metas: List[dict],
                    device: Union[torch.device, str] = 'cuda') \
            -> Tuple[List[List[Tensor]], List[List[Tensor]]]:
        """Get anchors according to feature map sizes.

        Args:
            featmap_sizes (list[tuple]): Multi-level feature map sizes.
            batch_img_metas (list[dict]): Image meta info.
            device (torch.device | str): Device for returned tensors.
                Defaults to cuda.

        Returns:
            tuple:

            - anchor_list (list[list[Tensor]]): Anchors of each image.
            - valid_flag_list (list[list[Tensor]]): Valid flags of each
              image.
        """
        num_imgs = len(batch_img_metas)

        # Since the feature map sizes of all images in a batch are the same,
        # anchors are computed once and shared across images.
        multi_level_anchors = self.prior_generator.grid_priors(
            featmap_sizes, device=device)
        anchor_list = [multi_level_anchors for _ in range(num_imgs)]

        # For each image, compute the valid flags of multi-level anchors.
        valid_flag_list = []
        for img_id, img_meta in enumerate(batch_img_metas):
            multi_level_flags = self.prior_generator.valid_flags(
                featmap_sizes, img_meta['pad_shape'], device)
            valid_flag_list.append(multi_level_flags)

        return anchor_list, valid_flag_list
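
    # Anchor-count sketch (illustrative): for a single 64x64 feature map and
    # the default 9 base priors, `grid_priors` yields 64 * 64 * 9 = 36864
    # anchors on that level, so `anchor_list[img_id][level]` has shape
    # (36864, 4) and the matching `valid_flag_list` entry has shape (36864, ).
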
    def _get_targets_single(self,
                            flat_anchors: Union[Tensor, BaseBoxes],
                            valid_flags: Tensor,
                            gt_instances: InstanceData,
                            img_meta: dict,
                            gt_instances_ignore: Optional[InstanceData] = None,
                            unmap_outputs: bool = True) -> tuple:
        """Compute regression and classification targets for anchors in a
        single image.

        Args:
            flat_anchors (Tensor or :obj:`BaseBoxes`): Multi-level anchors
                of the image, which are concatenated into a single tensor
                or box type of shape (num_anchors, 4).
            valid_flags (Tensor): Multi-level valid flags of the image,
                which are concatenated into a single tensor of
                shape (num_anchors, ).
            gt_instances (:obj:`InstanceData`): Ground truth of instance
                annotations. It should include ``bboxes`` and ``labels``
                attributes.
            img_meta (dict): Meta information for current image.
            gt_instances_ignore (:obj:`InstanceData`, optional): Instances
                to be ignored during training. It includes ``bboxes``
                attribute data that is ignored during training and testing.
                Defaults to None.
            unmap_outputs (bool): Whether to map outputs back to the original
                set of anchors. Defaults to True.

        Returns:
            tuple:

            - labels (Tensor): Labels of each level.
            - label_weights (Tensor): Label weights of each level.
            - bbox_targets (Tensor): BBox targets of each level.
            - bbox_weights (Tensor): BBox weights of each level.
            - pos_inds (Tensor): Indices of positive samples.
            - neg_inds (Tensor): Indices of negative samples.
            - sampling_result (:obj:`SamplingResult`): Sampling results.
        """
        inside_flags = anchor_inside_flags(flat_anchors, valid_flags,
                                           img_meta['img_shape'][:2],
                                           self.train_cfg['allowed_border'])
        if not inside_flags.any():
            raise ValueError(
                'There is no valid anchor inside the image boundary. Please '
                'check the image and anchor sizes, or set '
                '``allowed_border`` to -1 to skip the condition.')

        # assign gt and sample anchors
        anchors = flat_anchors[inside_flags]

        pred_instances = InstanceData(priors=anchors)
        assign_result = self.assigner.assign(pred_instances, gt_instances,
                                             gt_instances_ignore)
        # No sampling is required except for RPN and
        # Guided Anchoring algorithms.
        sampling_result = self.sampler.sample(assign_result, pred_instances,
                                              gt_instances)

        num_valid_anchors = anchors.shape[0]
        target_dim = gt_instances.bboxes.size(-1) if self.reg_decoded_bbox \
            else self.bbox_coder.encode_size
        bbox_targets = anchors.new_zeros(num_valid_anchors, target_dim)
        bbox_weights = anchors.new_zeros(num_valid_anchors, target_dim)

        # Labels default to the background class index (num_classes).
        labels = anchors.new_full((num_valid_anchors, ),
                                  self.num_classes,
                                  dtype=torch.long)
        label_weights = anchors.new_zeros(num_valid_anchors, dtype=torch.float)

        pos_inds = sampling_result.pos_inds
        neg_inds = sampling_result.neg_inds

        # `bbox_coder.encode` accepts tensor or box type inputs and generates
        # tensor targets. When regressing decoded boxes, box type
        # `pos_bbox_targets` is converted to a plain tensor.
        if len(pos_inds) > 0:
            if not self.reg_decoded_bbox:
                pos_bbox_targets = self.bbox_coder.encode(
                    sampling_result.pos_priors, sampling_result.pos_gt_bboxes)
            else:
                pos_bbox_targets = sampling_result.pos_gt_bboxes
                pos_bbox_targets = get_box_tensor(pos_bbox_targets)
            bbox_targets[pos_inds, :] = pos_bbox_targets
            bbox_weights[pos_inds, :] = 1.0

            labels[pos_inds] = sampling_result.pos_gt_labels
            if self.train_cfg['pos_weight'] <= 0:
                label_weights[pos_inds] = 1.0
            else:
                label_weights[pos_inds] = self.train_cfg['pos_weight']
        if len(neg_inds) > 0:
            label_weights[neg_inds] = 1.0

        # map the targets back to the original set of anchors
        if unmap_outputs:
            num_total_anchors = flat_anchors.size(0)
            labels = unmap(
                labels, num_total_anchors, inside_flags,
                fill=self.num_classes)  # fill background label
            label_weights = unmap(label_weights, num_total_anchors,
                                  inside_flags)
            bbox_targets = unmap(bbox_targets, num_total_anchors, inside_flags)
            bbox_weights = unmap(bbox_weights, num_total_anchors, inside_flags)

        return (labels, label_weights, bbox_targets, bbox_weights, pos_inds,
                neg_inds, sampling_result)
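
    # Target-shape sketch (illustrative, assuming `unmap_outputs=True` and
    # plain tensor boxes): with 36864 flat anchors in an image, the returned
    # `labels` has shape (36864, ), filled with `num_classes` (background)
    # everywhere except positive samples, while `bbox_targets` has shape
    # (36864, 4) and is non-zero only at positive indices.
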
    def get_targets(self,
                    anchor_list: List[List[Tensor]],
                    valid_flag_list: List[List[Tensor]],
                    batch_gt_instances: InstanceList,
                    batch_img_metas: List[dict],
                    batch_gt_instances_ignore: OptInstanceList = None,
                    unmap_outputs: bool = True,
                    return_sampling_results: bool = False) -> tuple:
        """Compute regression and classification targets for anchors in
        multiple images.

        Args:
            anchor_list (list[list[Tensor]]): Multi-level anchors of each
                image. The outer list indicates images, and the inner list
                corresponds to feature levels of the image. Each element of
                the inner list is a tensor of shape (num_anchors, 4).
            valid_flag_list (list[list[Tensor]]): Multi-level valid flags of
                each image. The outer list indicates images, and the inner
                list corresponds to feature levels of the image. Each element
                of the inner list is a tensor of shape (num_anchors, ).
            batch_gt_instances (list[:obj:`InstanceData`]): Batch of
                gt_instance. It usually includes ``bboxes`` and ``labels``
                attributes.
            batch_img_metas (list[dict]): Meta information of each image,
                e.g., image size, scaling factor, etc.
            batch_gt_instances_ignore (list[:obj:`InstanceData`], optional):
                Batch of gt_instances_ignore. It includes ``bboxes`` attribute
                data that is ignored during training and testing.
                Defaults to None.
            unmap_outputs (bool): Whether to map outputs back to the original
                set of anchors. Defaults to True.
            return_sampling_results (bool): Whether to return the sampling
                results. Defaults to False.

        Returns:
            tuple: Usually returns a tuple containing learning targets.

            - labels_list (list[Tensor]): Labels of each level.
            - label_weights_list (list[Tensor]): Label weights of each
              level.
            - bbox_targets_list (list[Tensor]): BBox targets of each level.
            - bbox_weights_list (list[Tensor]): BBox weights of each level.
            - avg_factor (int): Average factor that is used to average
              the loss. When using sampling method, avg_factor is usually
              the sum of positive and negative priors. When using
              `PseudoSampler`, `avg_factor` is usually equal to the number
              of positive priors.

            additional_returns: This function enables user-defined returns
                from `self._get_targets_single`. These returns are currently
                refined to per-level properties (i.e. with an HxW dimension)
                and are concatenated to the end of the returned tuple.
        """
        num_imgs = len(batch_img_metas)
        assert len(anchor_list) == len(valid_flag_list) == num_imgs

        if batch_gt_instances_ignore is None:
            batch_gt_instances_ignore = [None] * num_imgs

        # anchor number of each level
        num_level_anchors = [anchors.size(0) for anchors in anchor_list[0]]

        # concat all level anchors and flags into single tensors per image
        concat_anchor_list = []
        concat_valid_flag_list = []
        for i in range(num_imgs):
            assert len(anchor_list[i]) == len(valid_flag_list[i])
            concat_anchor_list.append(cat_boxes(anchor_list[i]))
            concat_valid_flag_list.append(torch.cat(valid_flag_list[i]))

        # compute targets for each image
        results = multi_apply(
            self._get_targets_single,
            concat_anchor_list,
            concat_valid_flag_list,
            batch_gt_instances,
            batch_img_metas,
            batch_gt_instances_ignore,
            unmap_outputs=unmap_outputs)
        (all_labels, all_label_weights, all_bbox_targets, all_bbox_weights,
         pos_inds_list, neg_inds_list, sampling_results_list) = results[:7]
        rest_results = list(results[7:])  # user-added returns

        # Get `avg_factor` of all images, which is computed in
        # `SamplingResult`. With a sampling method it is usually the sum of
        # positive and negative priors; with `PseudoSampler` it is usually
        # the number of positive priors.
        avg_factor = sum(
            [sr.avg_factor for sr in sampling_results_list])

        # update `_raw_positive_infos`, which will be used when calling
        # `get_positive_infos()`
        self._raw_positive_infos.update(sampling_results=sampling_results_list)

        # split targets into per-level lists
        labels_list = images_to_levels(all_labels, num_level_anchors)
        label_weights_list = images_to_levels(all_label_weights,
                                              num_level_anchors)
        bbox_targets_list = images_to_levels(all_bbox_targets,
                                             num_level_anchors)
        bbox_weights_list = images_to_levels(all_bbox_weights,
                                             num_level_anchors)
        res = (labels_list, label_weights_list, bbox_targets_list,
               bbox_weights_list, avg_factor)
        if return_sampling_results:
            res = res + (sampling_results_list, )
        for i, r in enumerate(rest_results):  # user-added returns
            rest_results[i] = images_to_levels(r, num_level_anchors)

        return res + tuple(rest_results)
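
    # Per-level shape sketch (illustrative): after `images_to_levels`, each
    # `labels_list[lvl]` has shape (num_imgs, H_lvl * W_lvl * num_base_priors)
    # and each `bbox_targets_list[lvl]` has shape
    # (num_imgs, H_lvl * W_lvl * num_base_priors, 4), i.e. targets are
    # regrouped from per-image tensors into per-level tensors.
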
    def loss_by_feat_single(self, cls_score: Tensor, bbox_pred: Tensor,
                            anchors: Tensor, labels: Tensor,
                            label_weights: Tensor, bbox_targets: Tensor,
                            bbox_weights: Tensor, avg_factor: int) -> tuple:
        """Calculate the loss of a single scale level based on the features
        extracted by the detection head.

        Args:
            cls_score (Tensor): Box scores for each scale level
                with shape (N, num_anchors * num_classes, H, W).
            bbox_pred (Tensor): Box energies / deltas for each scale
                level with shape (N, num_anchors * 4, H, W).
            anchors (Tensor): Box reference for each scale level with shape
                (N, num_total_anchors, 4).
            labels (Tensor): Labels of each anchor with shape
                (N, num_total_anchors).
            label_weights (Tensor): Label weights of each anchor with shape
                (N, num_total_anchors).
            bbox_targets (Tensor): BBox regression targets of each anchor
                with shape (N, num_total_anchors, 4).
            bbox_weights (Tensor): BBox regression loss weights of each
                anchor with shape (N, num_total_anchors, 4).
            avg_factor (int): Average factor that is used to average the loss.

        Returns:
            tuple: loss components.
        """
        # classification loss
        labels = labels.reshape(-1)
        label_weights = label_weights.reshape(-1)
        cls_score = cls_score.permute(0, 2, 3,
                                      1).reshape(-1, self.cls_out_channels)
        loss_cls = self.loss_cls(
            cls_score, labels, label_weights, avg_factor=avg_factor)

        # regression loss
        target_dim = bbox_targets.size(-1)
        bbox_targets = bbox_targets.reshape(-1, target_dim)
        bbox_weights = bbox_weights.reshape(-1, target_dim)
        bbox_pred = bbox_pred.permute(0, 2, 3,
                                      1).reshape(-1,
                                                 self.bbox_coder.encode_size)
        if self.reg_decoded_bbox:
            # When the regression loss (e.g. `IoULoss`, `GIoULoss`) is
            # applied directly on the decoded bounding boxes, decode the
            # already encoded coordinates to absolute format.
            anchors = anchors.reshape(-1, anchors.size(-1))
            bbox_pred = self.bbox_coder.decode(anchors, bbox_pred)
            bbox_pred = get_box_tensor(bbox_pred)
        loss_bbox = self.loss_bbox(
            bbox_pred, bbox_targets, bbox_weights, avg_factor=avg_factor)
        return loss_cls, loss_bbox
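
    # Layout note (illustrative): `cls_score` arrives as (N, A*C, H, W);
    # `permute(0, 2, 3, 1).reshape(-1, C)` flattens it to (N*H*W*A, C), which
    # matches the per-anchor order of `labels` and `label_weights` produced
    # by `images_to_levels`, so each row pairs with exactly one anchor.
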
    def loss_by_feat(
            self,
            cls_scores: List[Tensor],
            bbox_preds: List[Tensor],
            batch_gt_instances: InstanceList,
            batch_img_metas: List[dict],
            batch_gt_instances_ignore: OptInstanceList = None) -> dict:
        """Calculate the loss based on the features extracted by the detection
        head.

        Args:
            cls_scores (list[Tensor]): Box scores for each scale level
                with shape (N, num_anchors * num_classes, H, W).
            bbox_preds (list[Tensor]): Box energies / deltas for each scale
                level with shape (N, num_anchors * 4, H, W).
            batch_gt_instances (list[:obj:`InstanceData`]): Batch of
                gt_instance. It usually includes ``bboxes`` and ``labels``
                attributes.
            batch_img_metas (list[dict]): Meta information of each image,
                e.g., image size, scaling factor, etc.
            batch_gt_instances_ignore (list[:obj:`InstanceData`], optional):
                Batch of gt_instances_ignore. It includes ``bboxes`` attribute
                data that is ignored during training and testing.
                Defaults to None.

        Returns:
            dict: A dictionary of loss components.
        """
        featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
        assert len(featmap_sizes) == self.prior_generator.num_levels

        device = cls_scores[0].device

        anchor_list, valid_flag_list = self.get_anchors(
            featmap_sizes, batch_img_metas, device=device)
        cls_reg_targets = self.get_targets(
            anchor_list,
            valid_flag_list,
            batch_gt_instances,
            batch_img_metas,
            batch_gt_instances_ignore=batch_gt_instances_ignore)
        (labels_list, label_weights_list, bbox_targets_list,
         bbox_weights_list, avg_factor) = cls_reg_targets

        # anchor number of each level
        num_level_anchors = [anchors.size(0) for anchors in anchor_list[0]]

        # concat all level anchors into a single tensor per image, then
        # regroup them by level for the per-level loss computation
        concat_anchor_list = []
        for i in range(len(anchor_list)):
            concat_anchor_list.append(cat_boxes(anchor_list[i]))
        all_anchor_list = images_to_levels(concat_anchor_list,
                                           num_level_anchors)

        losses_cls, losses_bbox = multi_apply(
            self.loss_by_feat_single,
            cls_scores,
            bbox_preds,
            all_anchor_list,
            labels_list,
            label_weights_list,
            bbox_targets_list,
            bbox_weights_list,
            avg_factor=avg_factor)
        return dict(loss_cls=losses_cls, loss_bbox=losses_bbox)
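
    # End-to-end loss sketch (hedged; assumes an mmdet runtime and a
    # `train_cfg` dict providing 'assigner', 'allowed_border' and
    # 'pos_weight', e.g. a MaxIoUAssigner config, with
    # `batch_gt_instances`/`batch_img_metas` prepared elsewhere):
    #
    #   head = AnchorHead(num_classes=80, in_channels=256,
    #                     train_cfg=train_cfg)
    #   feats = tuple(torch.rand(1, 256, s, s) for s in (64, 32, 16, 8, 4))
    #   cls_scores, bbox_preds = head(feats)
    #   losses = head.loss_by_feat(cls_scores, bbox_preds,
    #                              batch_gt_instances, batch_img_metas)
    #   # -> dict(loss_cls=[...one per level...], loss_bbox=[...])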