from typing import Dict, List, Sequence, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import Scale
from mmengine.model import bias_init_with_prob, normal_init
from mmengine.structures import InstanceData
from torch import Tensor

from mmdet.registry import MODELS
from mmdet.structures.bbox import bbox_overlaps
from mmdet.utils import InstanceList, OptInstanceList, reduce_mean
from ..task_modules.prior_generators import MlvlPointGenerator
from ..utils import levels_to_images, multi_apply
from .fcos_head import FCOSHead

EPS = 1e-12


class CenterPrior(nn.Module):
    """Center Weighting module to adjust the category-specific prior
    distributions.

    Args:
        force_topk (bool): When no point falls into gt_bbox, forcibly
            select the k points closest to the center to calculate
            the center prior. Defaults to False.
        topk (int): The number of points used to calculate the
            center prior when no point falls in gt_bbox. Only works when
            ``force_topk`` is True. Defaults to 9.
        num_classes (int): The number of classes in the dataset.
            Defaults to 80.
        strides (Sequence[int]): The stride of each input feature map.
            Defaults to (8, 16, 32, 64, 128).
    """

    def __init__(
        self,
        force_topk: bool = False,
        topk: int = 9,
        num_classes: int = 80,
        strides: Sequence[int] = (8, 16, 32, 64, 128)
    ) -> None:
        super().__init__()
        # Learnable per-category Gaussian parameters: one (x, y) mean and
        # sigma per class, in stride-normalized coordinates.
        self.mean = nn.Parameter(torch.zeros(num_classes, 2))
        self.sigma = nn.Parameter(torch.ones(num_classes, 2))
        self.strides = strides
        self.force_topk = force_topk
        self.topk = topk

    def forward(self, anchor_points_list: List[Tensor],
                gt_instances: InstanceData,
                inside_gt_bbox_mask: Tensor) -> Tuple[Tensor, Tensor]:
        """Get the center prior of each point on the feature map for each
        instance.

        Args:
            anchor_points_list (list[Tensor]): List of coordinates
                of points on the feature map. Each has shape
                (num_points, 2).
            gt_instances (:obj:`InstanceData`): Ground truth of instance
                annotations. It should include ``bboxes`` and ``labels``
                attributes.
            inside_gt_bbox_mask (Tensor): Tensor of bool type,
                with shape of (num_points, num_gt), each
                value is used to mark whether this point falls
                within a certain gt.

        Returns:
            tuple[Tensor, Tensor]:

            - center_prior_weights (Tensor): Float tensor with shape of \
            (num_points, num_gt). Each value represents the center \
            weighting coefficient.
            - inside_gt_bbox_mask (Tensor): Tensor of bool type, with shape \
            of (num_points, num_gt), each value is used to mark whether this \
            point falls within a certain gt or is among the topk nearest \
            points for a specific gt_bbox.
        """
        gt_bboxes = gt_instances.bboxes
        labels = gt_instances.labels

        inside_gt_bbox_mask = inside_gt_bbox_mask.clone()
        num_gts = len(labels)
        num_points = sum([len(item) for item in anchor_points_list])
        if num_gts == 0:
            return gt_bboxes.new_zeros(num_points,
                                       num_gts), inside_gt_bbox_mask
        center_prior_list = []
        for slvl_points, stride in zip(anchor_points_list, self.strides):
            # slvl_points: points from a single FPN level, shape (h*w, 2)
            # single_level_points has shape (h*w, num_gt, 2)
            single_level_points = slvl_points[:, None, :].expand(
                (slvl_points.size(0), len(gt_bboxes), 2))
            gt_center_x = ((gt_bboxes[:, 0] + gt_bboxes[:, 2]) / 2)
            gt_center_y = ((gt_bboxes[:, 1] + gt_bboxes[:, 3]) / 2)
            gt_center = torch.stack((gt_center_x, gt_center_y), dim=1)
            gt_center = gt_center[None]
            # instance_center has shape (1, num_gt, 2)
            instance_center = self.mean[labels][None]
            # instance_sigma has shape (1, num_gt, 2)
            instance_sigma = self.sigma[labels][None]
            # distance has shape (num_points, num_gt, 2)
            distance = (((single_level_points - gt_center) / float(stride) -
                         instance_center)**2)
            center_prior = torch.exp(-distance /
                                     (2 * instance_sigma**2)).prod(dim=-1)
            center_prior_list.append(center_prior)
        center_prior_weights = torch.cat(center_prior_list, dim=0)

        if self.force_topk:
            gt_inds_no_points_inside = torch.nonzero(
                inside_gt_bbox_mask.sum(0) == 0).reshape(-1)
            if gt_inds_no_points_inside.numel():
                topk_center_index = \
                    center_prior_weights[:, gt_inds_no_points_inside].topk(
                        self.topk,
                        dim=0)[1]
                temp_mask = inside_gt_bbox_mask[:, gt_inds_no_points_inside]
                inside_gt_bbox_mask[:, gt_inds_no_points_inside] = \
                    torch.scatter(temp_mask,
                                  dim=0,
                                  index=topk_center_index,
                                  src=torch.ones_like(
                                      topk_center_index,
                                      dtype=torch.bool))

        center_prior_weights[~inside_gt_bbox_mask] = 0
        return center_prior_weights, inside_gt_bbox_mask
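
# A minimal standalone sketch (not used by the module) of the Gaussian
# center prior computed in ``CenterPrior.forward`` above: for each
# (point, gt) pair the weight is the product of two 1-D Gaussians over the
# stride-normalized offset between the point and the gt center. The helper
# name and toy shapes are illustrative assumptions, not part of the API.
def _center_prior_sketch(points: Tensor, gt_centers: Tensor, mean: Tensor,
                         sigma: Tensor, stride: int) -> Tensor:
    """Toy version: points (P, 2), gt_centers (G, 2), mean/sigma (G, 2)."""
    # Offset in units of the feature-map stride, shape (P, G, 2).
    offset = (points[:, None, :] - gt_centers[None]) / float(stride)
    distance = (offset - mean[None])**2
    # Product over x/y gives a weight in (0, 1] per (point, gt) pair.
    return torch.exp(-distance / (2 * sigma[None]**2)).prod(dim=-1)

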
@MODELS.register_module()
class AutoAssignHead(FCOSHead):
    """AutoAssignHead used in AutoAssign.

    More details can be found in the `paper
    <https://arxiv.org/abs/2007.03496>`_ .

    Args:
        force_topk (bool): Used in center prior initialization to
            handle extremely small gt. Defaults to False.
        topk (int): The number of points used to calculate the
            center prior when no point falls in gt_bbox. Only works when
            ``force_topk`` is True. Defaults to 9.
        pos_loss_weight (float): The loss weight of the positive loss.
            Defaults to 0.25.
        neg_loss_weight (float): The loss weight of the negative loss.
            Defaults to 0.75.
        center_loss_weight (float): The loss weight of the center prior
            loss. Defaults to 0.75.
    """

    def __init__(self,
                 *args,
                 force_topk: bool = False,
                 topk: int = 9,
                 pos_loss_weight: float = 0.25,
                 neg_loss_weight: float = 0.75,
                 center_loss_weight: float = 0.75,
                 **kwargs) -> None:
        super().__init__(*args, conv_bias=True, **kwargs)
        self.center_prior = CenterPrior(
            force_topk=force_topk,
            topk=topk,
            num_classes=self.num_classes,
            strides=self.strides)
        self.pos_loss_weight = pos_loss_weight
        self.neg_loss_weight = neg_loss_weight
        self.center_loss_weight = center_loss_weight
        self.prior_generator = MlvlPointGenerator(self.strides, offset=0)

    def init_weights(self) -> None:
        """Initialize weights of the head.

        In particular, we use special initialization for the classification
        conv's and the regression conv's biases.
        """
        super(AutoAssignHead, self).init_weights()
        # Initialize the classification bias so that initial sigmoid scores
        # are close to a small prior probability (0.02 here).
        bias_cls = bias_init_with_prob(0.02)
        normal_init(self.conv_cls, std=0.01, bias=bias_cls)
        # A positive regression bias keeps early bbox predictions nonzero.
        normal_init(self.conv_reg, std=0.01, bias=4.0)

    def forward_single(self, x: Tensor, scale: Scale,
                       stride: int) -> Tuple[Tensor, Tensor, Tensor]:
        """Forward features of a single scale level.

        Args:
            x (Tensor): FPN feature maps of the specified stride.
            scale (:obj:`mmcv.cnn.Scale`): Learnable scale module to resize
                the bbox prediction.
            stride (int): The corresponding stride for feature maps, only
                used to normalize the bbox prediction when self.norm_on_bbox
                is True.

        Returns:
            tuple[Tensor, Tensor, Tensor]: scores for each class, bbox
            predictions and centerness predictions of input feature maps.
        """
        cls_score, bbox_pred, cls_feat, reg_feat = super(
            FCOSHead, self).forward_single(x)
        centerness = self.conv_centerness(reg_feat)
        # Scale the bbox_pred of different levels; cast to float to avoid
        # overflow when training with FP16.
        bbox_pred = scale(bbox_pred).float()
        # Clamp instead of F.relu: recent PyTorch versions modify the tensor
        # needed for gradient computation in-place under F.relu.
        bbox_pred = bbox_pred.clamp(min=0)
        bbox_pred *= stride
        return cls_score, bbox_pred, centerness
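
    # Note: ``scale`` here is mmcv's ``Scale``, a learnable per-level scalar
    # multiplier, so each level's prediction is effectively
    # ``stride * clamp(scale_l * raw_pred, min=0)``.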
    def get_pos_loss_single(self, cls_score: Tensor, objectness: Tensor,
                            reg_loss: Tensor, gt_instances: InstanceData,
                            center_prior_weights: Tensor) -> Tuple[Tensor]:
        """Calculate the positive loss of all points in gt_bboxes.

        Args:
            cls_score (Tensor): All category scores for each point on
                the feature map. The shape is (num_points, num_class).
            objectness (Tensor): Foreground probability of all points,
                has shape (num_points, 1).
            reg_loss (Tensor): The regression loss of each gt_bbox and each
                prediction box, has shape of (num_points, num_gt).
            gt_instances (:obj:`InstanceData`): Ground truth of instance
                annotations. It should include ``bboxes`` and ``labels``
                attributes.
            center_prior_weights (Tensor): Float tensor with shape
                of (num_points, num_gt). Each value represents
                the center weighting coefficient.

        Returns:
            tuple[Tensor]:

            - pos_loss (Tensor): The positive loss of all points in the \
            gt_bboxes.
        """
        gt_labels = gt_instances.labels
        # p_loc: localization confidence
        p_loc = torch.exp(-reg_loss)
        # p_cls: classification confidence
        p_cls = (cls_score * objectness)[:, gt_labels]
        # p_pos: joint confidence indicator
        p_pos = p_cls * p_loc

        # 3 is a hyper-parameter to control the contributions of high and
        # low confidence locations towards positive losses.
        confidence_weight = torch.exp(p_pos * 3)
        p_pos_weight = (confidence_weight * center_prior_weights) / (
            (confidence_weight * center_prior_weights).sum(
                0, keepdim=True)).clamp(min=EPS)
        reweighted_p_pos = (p_pos * p_pos_weight).sum(0)
        pos_loss = F.binary_cross_entropy(
            reweighted_p_pos,
            torch.ones_like(reweighted_p_pos),
            reduction='none')
        pos_loss = pos_loss.sum() * self.pos_loss_weight
        return pos_loss,
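
    # Illustrative sketch of the reweighting above (our notation, not
    # executed): with joint confidences ``p_pos`` (num_points, num_gt) and
    # center priors ``c`` of the same shape, the loss for one gt j is
    #
    #   w_ij   = exp(3 * p_pos[i, j]) * c[i, j]
    #   p_j    = sum_i(p_pos[i, j] * w_ij / sum_k(w_kj))
    #   loss_j = -log(p_j)
    #
    # i.e. a confidence-weighted average of the joint confidences, pushed
    # towards 1 by binary cross entropy against an all-ones target.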
    def get_neg_loss_single(self, cls_score: Tensor, objectness: Tensor,
                            gt_instances: InstanceData, ious: Tensor,
                            inside_gt_bbox_mask: Tensor) -> Tuple[Tensor]:
        """Calculate the negative loss of all points in the feature map.

        Args:
            cls_score (Tensor): All category scores for each point on
                the feature map. The shape is (num_points, num_class).
            objectness (Tensor): Foreground probability of all points,
                has shape (num_points, 1).
            gt_instances (:obj:`InstanceData`): Ground truth of instance
                annotations. It should include ``bboxes`` and ``labels``
                attributes.
            ious (Tensor): Float tensor with shape of (num_points, num_gt).
                Each value represents the IoU of the pred_bbox and gt_bboxes.
            inside_gt_bbox_mask (Tensor): Tensor of bool type,
                with shape of (num_points, num_gt), each
                value is used to mark whether this point falls
                within a certain gt.

        Returns:
            tuple[Tensor]:

            - neg_loss (Tensor): The negative loss of all points in the \
            feature map.
        """
        gt_labels = gt_instances.labels
        num_gts = len(gt_labels)
        joint_conf = (cls_score * objectness)
        p_neg_weight = torch.ones_like(joint_conf)
        if num_gts > 0:
            # The order of dimensions below affects the value of
            # p_neg_weight; we strictly follow the original implementation.
            inside_gt_bbox_mask = inside_gt_bbox_mask.permute(1, 0)
            ious = ious.permute(1, 0)

            foreground_idxs = torch.nonzero(inside_gt_bbox_mask, as_tuple=True)
            temp_weight = (1 / (1 - ious[foreground_idxs]).clamp_(EPS))

            def normalize(x):
                return (x - x.min() + EPS) / (x.max() - x.min() + EPS)

            for instance_idx in range(num_gts):
                idxs = foreground_idxs[0] == instance_idx
                if idxs.any():
                    temp_weight[idxs] = normalize(temp_weight[idxs])

            p_neg_weight[foreground_idxs[1],
                         gt_labels[foreground_idxs[0]]] = 1 - temp_weight

        logits = (joint_conf * p_neg_weight)
        neg_loss = (
            logits**2 * F.binary_cross_entropy(
                logits, torch.zeros_like(logits), reduction='none'))
        neg_loss = neg_loss.sum() * self.neg_loss_weight
        return neg_loss,
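
    # Illustrative sketch (our notation, not executed): points inside a gt
    # are down-weighted as negatives according to their IoU with that gt,
    #
    #   f(iou) = 1 / (1 - iou)          # larger iou -> larger f
    #   w_neg  = 1 - normalize(f)       # per-gt min-max normalization
    #
    # and every location then receives a focal-style penalty
    # (score**2) * BCE(score, 0), so confident false positives dominate.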
    def loss_by_feat(
        self,
        cls_scores: List[Tensor],
        bbox_preds: List[Tensor],
        objectnesses: List[Tensor],
        batch_gt_instances: InstanceList,
        batch_img_metas: List[dict],
        batch_gt_instances_ignore: OptInstanceList = None
    ) -> Dict[str, Tensor]:
        """Calculate the loss based on the features extracted by the detection
        head.

        Args:
            cls_scores (list[Tensor]): Box scores for each scale level,
                each is a 4D-tensor, the channel number is
                num_points * num_classes.
            bbox_preds (list[Tensor]): Box energies / deltas for each scale
                level, each is a 4D-tensor, the channel number is
                num_points * 4.
            objectnesses (list[Tensor]): Objectness for each scale level, each
                is a 4D-tensor, the channel number is num_points * 1.
            batch_gt_instances (list[:obj:`InstanceData`]): Batch of
                gt_instance. It usually includes ``bboxes`` and ``labels``
                attributes.
            batch_img_metas (list[dict]): Meta information of each image, e.g.,
                image size, scaling factor, etc.
            batch_gt_instances_ignore (list[:obj:`InstanceData`], optional):
                Batch of gt_instances_ignore. It includes ``bboxes`` attribute
                data that is ignored during training and testing.
                Defaults to None.

        Returns:
            dict[str, Tensor]: A dictionary of loss components.
        """
        assert len(cls_scores) == len(bbox_preds) == len(objectnesses)
        all_num_gt = sum([len(item) for item in batch_gt_instances])
        featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
        all_level_points = self.prior_generator.grid_priors(
            featmap_sizes,
            dtype=bbox_preds[0].dtype,
            device=bbox_preds[0].device)
        inside_gt_bbox_mask_list, bbox_targets_list = self.get_targets(
            all_level_points, batch_gt_instances)

        center_prior_weight_list = []
        temp_inside_gt_bbox_mask_list = []
        for gt_instances, inside_gt_bbox_mask in zip(batch_gt_instances,
                                                     inside_gt_bbox_mask_list):
            center_prior_weight, inside_gt_bbox_mask = \
                self.center_prior(all_level_points, gt_instances,
                                  inside_gt_bbox_mask)
            center_prior_weight_list.append(center_prior_weight)
            temp_inside_gt_bbox_mask_list.append(inside_gt_bbox_mask)
        inside_gt_bbox_mask_list = temp_inside_gt_bbox_mask_list
        mlvl_points = torch.cat(all_level_points, dim=0)
        bbox_preds = levels_to_images(bbox_preds)
        cls_scores = levels_to_images(cls_scores)
        objectnesses = levels_to_images(objectnesses)

        reg_loss_list = []
        ious_list = []
        num_points = len(mlvl_points)

        for bbox_pred, encoded_targets, inside_gt_bbox_mask in zip(
                bbox_preds, bbox_targets_list, inside_gt_bbox_mask_list):
            temp_num_gt = encoded_targets.size(1)
            expand_mlvl_points = mlvl_points[:, None, :].expand(
                num_points, temp_num_gt, 2).reshape(-1, 2)
            encoded_targets = encoded_targets.reshape(-1, 4)
            expand_bbox_pred = bbox_pred[:, None, :].expand(
                num_points, temp_num_gt, 4).reshape(-1, 4)
            decoded_bbox_preds = self.bbox_coder.decode(
                expand_mlvl_points, expand_bbox_pred)
            decoded_target_preds = self.bbox_coder.decode(
                expand_mlvl_points, encoded_targets)
            with torch.no_grad():
                ious = bbox_overlaps(
                    decoded_bbox_preds, decoded_target_preds, is_aligned=True)
                ious = ious.reshape(num_points, temp_num_gt)
                if temp_num_gt:
                    ious = ious.max(
                        dim=-1, keepdim=True).values.repeat(1, temp_num_gt)
                else:
                    ious = ious.new_zeros(num_points, temp_num_gt)
                ious[~inside_gt_bbox_mask] = 0
                ious_list.append(ious)
            loss_bbox = self.loss_bbox(
                decoded_bbox_preds,
                decoded_target_preds,
                weight=None,
                reduction_override='none')
            reg_loss_list.append(loss_bbox.reshape(num_points, temp_num_gt))

        cls_scores = [item.sigmoid() for item in cls_scores]
        objectnesses = [item.sigmoid() for item in objectnesses]
        pos_loss_list, = multi_apply(self.get_pos_loss_single, cls_scores,
                                     objectnesses, reg_loss_list,
                                     batch_gt_instances,
                                     center_prior_weight_list)
        pos_avg_factor = reduce_mean(
            bbox_pred.new_tensor(all_num_gt)).clamp_(min=1)
        pos_loss = sum(pos_loss_list) / pos_avg_factor

        neg_loss_list, = multi_apply(self.get_neg_loss_single, cls_scores,
                                     objectnesses, batch_gt_instances,
                                     ious_list, inside_gt_bbox_mask_list)
        neg_avg_factor = sum(item.data.sum()
                             for item in center_prior_weight_list)
        neg_avg_factor = reduce_mean(neg_avg_factor).clamp_(min=1)
        neg_loss = sum(neg_loss_list) / neg_avg_factor

        center_loss = []
        for i in range(len(batch_img_metas)):

            if inside_gt_bbox_mask_list[i].any():
                center_loss.append(
                    len(batch_gt_instances[i]) /
                    center_prior_weight_list[i].sum().clamp_(min=EPS))
            # when width or height of gt_bbox is smaller than the stride of p3
            else:
                center_loss.append(center_prior_weight_list[i].sum() * 0)

        center_loss = torch.stack(center_loss).mean() * self.center_loss_weight

        # avoid a deadlock in DDP when there is no gt in the whole batch
        if all_num_gt == 0:
            pos_loss = bbox_preds[0].sum() * 0
            dummy_center_prior_loss = self.center_prior.mean.sum(
            ) * 0 + self.center_prior.sigma.sum() * 0
            center_loss = objectnesses[0].sum() * 0 + dummy_center_prior_loss

        loss = dict(
            loss_pos=pos_loss, loss_neg=neg_loss, loss_center=center_loss)

        return loss
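
    # Sketch of the total objective assembled above (our notation):
    #
    #   L = pos_loss_weight * sum_j(BCE(p_j, 1)) / num_gt
    #     + neg_loss_weight * sum_i(s_i**2 * BCE(s_i, 0)) / sum(center_prior)
    #     + center_loss_weight * mean_img(num_gt_img / sum_img(center_prior))
    #
    # where p_j is the reweighted positive confidence of gt j (see
    # ``get_pos_loss_single``) and s_i the down-weighted joint confidence of
    # location i (see ``get_neg_loss_single``).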
    def get_targets(
        self, points: List[Tensor], batch_gt_instances: InstanceList
    ) -> Tuple[List[Tensor], List[Tensor]]:
        """Compute regression targets and whether each point is inside or
        outside a gt_bbox, for multiple images.

        Args:
            points (list[Tensor]): Points of all fpn levels, each has shape
                (num_points, 2).
            batch_gt_instances (list[:obj:`InstanceData`]): Batch of
                gt_instance. It usually includes ``bboxes`` and ``labels``
                attributes.

        Returns:
            tuple[list[Tensor], list[Tensor]]:

            - inside_gt_bbox_mask_list (list[Tensor]): Each Tensor is with \
            bool type and shape of (num_points, num_gt), each value is used \
            to mark whether this point falls within a certain gt.
            - bbox_targets_list (list[Tensor]): BBox targets of each \
            image. Each tensor has shape (num_points, num_gt, 4).
        """
        concat_points = torch.cat(points, dim=0)
        inside_gt_bbox_mask_list, bbox_targets_list = multi_apply(
            self._get_targets_single, batch_gt_instances, points=concat_points)
        return inside_gt_bbox_mask_list, bbox_targets_list

    def _get_targets_single(self, gt_instances: InstanceData,
                            points: Tensor) -> Tuple[Tensor, Tensor]:
        """Compute regression targets and whether each point is inside or
        outside a gt_bbox, for a single image.

        Args:
            gt_instances (:obj:`InstanceData`): Ground truth of instance
                annotations. It should include ``bboxes`` and ``labels``
                attributes.
            points (Tensor): Points of all fpn levels, has shape
                (num_points, 2).

        Returns:
            tuple[Tensor, Tensor]: Containing the following Tensors:

            - inside_gt_bbox_mask (Tensor): Bool tensor with shape \
            (num_points, num_gt), each value is used to mark whether this \
            point falls within a certain gt.
            - bbox_targets (Tensor): BBox targets of each point with each \
            gt_bboxes, has shape (num_points, num_gt, 4).
        """
        gt_bboxes = gt_instances.bboxes
        num_points = points.size(0)
        num_gts = gt_bboxes.size(0)
        gt_bboxes = gt_bboxes[None].expand(num_points, num_gts, 4)
        xs, ys = points[:, 0], points[:, 1]
        xs = xs[:, None]
        ys = ys[:, None]
        # Distances from each point to the four sides of each gt box
        # (the standard FCOS (l, t, r, b) encoding).
        left = xs - gt_bboxes[..., 0]
        right = gt_bboxes[..., 2] - xs
        top = ys - gt_bboxes[..., 1]
        bottom = gt_bboxes[..., 3] - ys
        bbox_targets = torch.stack((left, top, right, bottom), -1)
        if num_gts:
            # A point is inside a gt iff all four distances are positive.
            inside_gt_bbox_mask = bbox_targets.min(-1)[0] > 0
        else:
            inside_gt_bbox_mask = bbox_targets.new_zeros((num_points, num_gts),
                                                         dtype=torch.bool)

        return inside_gt_bbox_mask, bbox_targets
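
# A toy, self-contained check (not used by the module) of the (l, t, r, b)
# target encoding built in ``_get_targets_single``: decoding a point's
# distances back with ``x1 = x - l, y1 = y - t, x2 = x + r, y2 = y + b``
# recovers the gt box. The helper name is an illustrative assumption.
def _ltrb_round_trip_sketch() -> bool:
    point = torch.tensor([[50., 40.]])
    gt = torch.tensor([[10., 20., 110., 80.]])  # (x1, y1, x2, y2)
    left = point[:, 0] - gt[:, 0]
    top = point[:, 1] - gt[:, 1]
    right = gt[:, 2] - point[:, 0]
    bottom = gt[:, 3] - point[:, 1]
    decoded = torch.stack((point[:, 0] - left, point[:, 1] - top,
                           point[:, 0] + right, point[:, 1] + bottom),
                          dim=-1)
    return bool(torch.allclose(decoded, gt))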