|
|
|
from abc import abstractmethod |
|
from typing import Any, List, Sequence, Tuple, Union |
|
|
|
import torch.nn as nn |
|
from mmcv.cnn import ConvModule |
|
from numpy import ndarray |
|
from torch import Tensor |
|
|
|
from mmdet.registry import MODELS, TASK_UTILS |
|
from mmdet.utils import (ConfigType, InstanceList, MultiConfig, OptConfigType, |
|
OptInstanceList) |
|
from ..task_modules.prior_generators import MlvlPointGenerator |
|
from ..utils import multi_apply |
|
from .base_dense_head import BaseDenseHead |
|
|
|
StrideType = Union[Sequence[int], Sequence[Tuple[int, int]]] |
|
|
|
|
|
@MODELS.register_module() |
|
class AnchorFreeHead(BaseDenseHead): |
|
"""Anchor-free head (FCOS, Fovea, RepPoints, etc.). |
|
|
|
Args: |
|
num_classes (int): Number of categories excluding the background |
|
category. |
|
in_channels (int): Number of channels in the input feature map. |
|
feat_channels (int): Number of hidden channels. Used in child classes. |
|
stacked_convs (int): Number of stacking convs of the head. |
|
strides (Sequence[int] or Sequence[Tuple[int, int]]): Downsample |
|
factor of each feature map. |
|
        dcn_on_last_conv (bool): If True, use DCN in the last conv layer of
            the towers. Defaults to False.
|
conv_bias (bool or str): If specified as `auto`, it will be decided by |
|
the norm_cfg. Bias of conv will be set as True if `norm_cfg` is |
|
None, otherwise False. Default: "auto". |
|
loss_cls (:obj:`ConfigDict` or dict): Config of classification loss. |
|
loss_bbox (:obj:`ConfigDict` or dict): Config of localization loss. |
|
        bbox_coder (:obj:`ConfigDict` or dict): Config of bbox coder. Defaults
            to 'DistancePointBBoxCoder'.
|
conv_cfg (:obj:`ConfigDict` or dict, Optional): Config dict for |
|
convolution layer. Defaults to None. |
|
norm_cfg (:obj:`ConfigDict` or dict, Optional): Config dict for |
|
normalization layer. Defaults to None. |
|
train_cfg (:obj:`ConfigDict` or dict, Optional): Training config of |
|
anchor-free head. |
|
test_cfg (:obj:`ConfigDict` or dict, Optional): Testing config of |
|
anchor-free head. |
|
init_cfg (:obj:`ConfigDict` or dict or list[:obj:`ConfigDict` or \ |
|
dict]): Initialization config dict. |
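
    Example:
        A minimal construction sketch (the class and channel counts below are
        arbitrary illustration values, not taken from a real config):

        >>> head = AnchorFreeHead(num_classes=4, in_channels=8)
        >>> # with a sigmoid classification loss there is no background channel
        >>> head.cls_out_channels
        4
        >>> # each point acts as a single prior
        >>> head.num_base_priors
        1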
|
""" |
|
|
|
_version = 1 |
|
|
|
def __init__( |
|
self, |
|
num_classes: int, |
|
in_channels: int, |
|
feat_channels: int = 256, |
|
stacked_convs: int = 4, |
|
strides: StrideType = (4, 8, 16, 32, 64), |
|
dcn_on_last_conv: bool = False, |
|
conv_bias: Union[bool, str] = 'auto', |
|
loss_cls: ConfigType = dict( |
|
type='FocalLoss', |
|
use_sigmoid=True, |
|
gamma=2.0, |
|
alpha=0.25, |
|
loss_weight=1.0), |
|
loss_bbox: ConfigType = dict(type='IoULoss', loss_weight=1.0), |
|
bbox_coder: ConfigType = dict(type='DistancePointBBoxCoder'), |
|
conv_cfg: OptConfigType = None, |
|
norm_cfg: OptConfigType = None, |
|
train_cfg: OptConfigType = None, |
|
test_cfg: OptConfigType = None, |
|
init_cfg: MultiConfig = dict( |
|
type='Normal', |
|
layer='Conv2d', |
|
std=0.01, |
|
override=dict( |
|
type='Normal', name='conv_cls', std=0.01, bias_prob=0.01)) |
|
) -> None: |
|
super().__init__(init_cfg=init_cfg) |
|
self.num_classes = num_classes |
|
self.use_sigmoid_cls = loss_cls.get('use_sigmoid', False) |
|
if self.use_sigmoid_cls: |
|
self.cls_out_channels = num_classes |
|
else: |
|
self.cls_out_channels = num_classes + 1 |
|
self.in_channels = in_channels |
|
self.feat_channels = feat_channels |
|
self.stacked_convs = stacked_convs |
|
self.strides = strides |
|
self.dcn_on_last_conv = dcn_on_last_conv |
|
assert conv_bias == 'auto' or isinstance(conv_bias, bool) |
|
self.conv_bias = conv_bias |
|
self.loss_cls = MODELS.build(loss_cls) |
|
self.loss_bbox = MODELS.build(loss_bbox) |
|
self.bbox_coder = TASK_UTILS.build(bbox_coder) |
|
|
|
        self.prior_generator = MlvlPointGenerator(strides)

        # Each point acts as a single prior, so there is one base prior per
        # feature-map location; this keeps the interface consistent with
        # anchor-based heads.
        self.num_base_priors = self.prior_generator.num_base_priors[0]
|
|
|
self.train_cfg = train_cfg |
|
self.test_cfg = test_cfg |
|
self.conv_cfg = conv_cfg |
|
self.norm_cfg = norm_cfg |
|
self.fp16_enabled = False |
|
|
|
self._init_layers() |
|
|
|
def _init_layers(self) -> None: |
|
"""Initialize layers of the head.""" |
|
self._init_cls_convs() |
|
self._init_reg_convs() |
|
self._init_predictor() |
|
|
|
def _init_cls_convs(self) -> None: |
|
"""Initialize classification conv layers of the head.""" |
|
self.cls_convs = nn.ModuleList() |
|
for i in range(self.stacked_convs): |
|
chn = self.in_channels if i == 0 else self.feat_channels |
|
if self.dcn_on_last_conv and i == self.stacked_convs - 1: |
|
conv_cfg = dict(type='DCNv2') |
|
else: |
|
conv_cfg = self.conv_cfg |
|
self.cls_convs.append( |
|
ConvModule( |
|
chn, |
|
self.feat_channels, |
|
3, |
|
stride=1, |
|
padding=1, |
|
conv_cfg=conv_cfg, |
|
norm_cfg=self.norm_cfg, |
|
bias=self.conv_bias)) |
|
|
|
def _init_reg_convs(self) -> None: |
|
"""Initialize bbox regression conv layers of the head.""" |
|
self.reg_convs = nn.ModuleList() |
|
for i in range(self.stacked_convs): |
|
chn = self.in_channels if i == 0 else self.feat_channels |
|
if self.dcn_on_last_conv and i == self.stacked_convs - 1: |
|
conv_cfg = dict(type='DCNv2') |
|
else: |
|
conv_cfg = self.conv_cfg |
|
self.reg_convs.append( |
|
ConvModule( |
|
chn, |
|
self.feat_channels, |
|
3, |
|
stride=1, |
|
padding=1, |
|
conv_cfg=conv_cfg, |
|
norm_cfg=self.norm_cfg, |
|
bias=self.conv_bias)) |
|
|
|
def _init_predictor(self) -> None: |
|
"""Initialize predictor layers of the head.""" |
|
self.conv_cls = nn.Conv2d( |
|
self.feat_channels, self.cls_out_channels, 3, padding=1) |
|
self.conv_reg = nn.Conv2d(self.feat_channels, 4, 3, padding=1) |
|
|
|
def _load_from_state_dict(self, state_dict: dict, prefix: str, |
|
local_metadata: dict, strict: bool, |
|
missing_keys: Union[List[str], str], |
|
unexpected_keys: Union[List[str], str], |
|
error_msgs: Union[List[str], str]) -> None: |
|
"""Hack some keys of the model state dict so that can load checkpoints |
|
of previous version.""" |
|
version = local_metadata.get('version', None) |
|
        if version is None:
            # Predictor keys were named differently in earlier checkpoint
            # versions (e.g. 'fcos_cls' is now 'conv_cls'), so remap them
            # before loading.
            bbox_head_keys = [
|
k for k in state_dict.keys() if k.startswith(prefix) |
|
] |
|
ori_predictor_keys = [] |
|
new_predictor_keys = [] |
|
|
|
for key in bbox_head_keys: |
|
ori_predictor_keys.append(key) |
|
key = key.split('.') |
|
if len(key) < 2: |
|
conv_name = None |
|
elif key[1].endswith('cls'): |
|
conv_name = 'conv_cls' |
|
elif key[1].endswith('reg'): |
|
conv_name = 'conv_reg' |
|
elif key[1].endswith('centerness'): |
|
conv_name = 'conv_centerness' |
|
else: |
|
conv_name = None |
|
if conv_name is not None: |
|
key[1] = conv_name |
|
new_predictor_keys.append('.'.join(key)) |
|
else: |
|
ori_predictor_keys.pop(-1) |
|
for i in range(len(new_predictor_keys)): |
|
state_dict[new_predictor_keys[i]] = state_dict.pop( |
|
ori_predictor_keys[i]) |
|
super()._load_from_state_dict(state_dict, prefix, local_metadata, |
|
strict, missing_keys, unexpected_keys, |
|
error_msgs) |
|
|
|
def forward(self, x: Tuple[Tensor]) -> Tuple[List[Tensor], List[Tensor]]: |
|
"""Forward features from the upstream network. |
|
|
|
Args: |
|
            x (tuple[Tensor]): Features from the upstream network, each is
|
a 4D-tensor. |
|
|
|
Returns: |
|
tuple: Usually contain classification scores and bbox predictions. |
|
|
|
- cls_scores (list[Tensor]): Box scores for each scale level, \ |
|
each is a 4D-tensor, the channel number is \ |
|
num_points * num_classes. |
|
- bbox_preds (list[Tensor]): Box energies / deltas for each scale \ |
|
level, each is a 4D-tensor, the channel number is num_points * 4. |
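
        Example:
            An illustrative forward sketch (feature sizes and channel counts
            below are arbitrary, not taken from a real config):

            >>> import torch
            >>> head = AnchorFreeHead(num_classes=4, in_channels=8)
            >>> feats = tuple(torch.rand(1, 8, s, s) for s in [32, 16, 8])
            >>> cls_scores, bbox_preds = head(feats)
            >>> assert len(cls_scores) == len(bbox_preds) == len(feats)
            >>> assert cls_scores[0].shape[1] == head.cls_out_channels
            >>> assert bbox_preds[0].shape[1] == 4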
|
""" |
|
return multi_apply(self.forward_single, x)[:2] |
|
|
|
def forward_single(self, x: Tensor) -> Tuple[Tensor, ...]: |
|
"""Forward features of a single scale level. |
|
|
|
Args: |
|
x (Tensor): FPN feature maps of the specified stride. |
|
|
|
Returns: |
|
            tuple: Scores for each class, bbox predictions, and features
                after the classification and regression conv layers; some
                models, such as FCOS, need these features.
|
""" |
|
cls_feat = x |
|
reg_feat = x |
|
|
|
for cls_layer in self.cls_convs: |
|
cls_feat = cls_layer(cls_feat) |
|
cls_score = self.conv_cls(cls_feat) |
|
|
|
for reg_layer in self.reg_convs: |
|
reg_feat = reg_layer(reg_feat) |
|
bbox_pred = self.conv_reg(reg_feat) |
|
return cls_score, bbox_pred, cls_feat, reg_feat |
|
|
|
@abstractmethod |
|
def loss_by_feat( |
|
self, |
|
cls_scores: List[Tensor], |
|
bbox_preds: List[Tensor], |
|
batch_gt_instances: InstanceList, |
|
batch_img_metas: List[dict], |
|
batch_gt_instances_ignore: OptInstanceList = None) -> dict: |
|
"""Calculate the loss based on the features extracted by the detection |
|
head. |
|
|
|
Args: |
|
cls_scores (list[Tensor]): Box scores for each scale level, |
|
each is a 4D-tensor, the channel number is |
|
num_points * num_classes. |
|
bbox_preds (list[Tensor]): Box energies / deltas for each scale |
|
level, each is a 4D-tensor, the channel number is |
|
num_points * 4. |
|
batch_gt_instances (list[:obj:`InstanceData`]): Batch of |
|
gt_instance. It usually includes ``bboxes`` and ``labels`` |
|
attributes. |
|
batch_img_metas (list[dict]): Meta information of each image, e.g., |
|
image size, scaling factor, etc. |
|
batch_gt_instances_ignore (list[:obj:`InstanceData`], Optional): |
|
Batch of gt_instances_ignore. It includes ``bboxes`` attribute |
|
data that is ignored during training and testing. |
|
Defaults to None. |
|
""" |
|
|
|
raise NotImplementedError |
|
|
|
@abstractmethod |
|
def get_targets(self, points: List[Tensor], |
|
batch_gt_instances: InstanceList) -> Any: |
|
"""Compute regression, classification and centerness targets for points |
|
in multiple images. |
|
|
|
Args: |
|
points (list[Tensor]): Points of each fpn level, each has shape |
|
(num_points, 2). |
|
batch_gt_instances (list[:obj:`InstanceData`]): Batch of |
|
gt_instance. It usually includes ``bboxes`` and ``labels`` |
|
attributes. |
|
""" |
|
raise NotImplementedError |
|
|
|
|
|
def aug_test(self, |
|
aug_batch_feats: List[Tensor], |
|
                 aug_batch_img_metas: List[List[dict]],
|
rescale: bool = False) -> List[ndarray]: |
|
"""Test function with test time augmentation. |
|
|
|
Args: |
|
aug_batch_feats (list[Tensor]): the outer list indicates test-time |
|
augmentations and inner Tensor should have a shape NxCxHxW, |
|
which contains features for all images in the batch. |
|
aug_batch_img_metas (list[list[dict]]): the outer list indicates |
|
test-time augs (multiscale, flip, etc.) and the inner list |
|
                indicates images in a batch. Each dict has image information.
|
rescale (bool, optional): Whether to rescale the results. |
|
Defaults to False. |
|
|
|
Returns: |
|
list[ndarray]: bbox results of each class |
|
""" |
|
return self.aug_test_bboxes( |
|
aug_batch_feats, aug_batch_img_metas, rescale=rescale) |
|
|