|
|
|
from typing import Optional |
|
|
|
import torch |
|
from mmengine.structures import InstanceData |
|
|
|
from mmdet.registry import TASK_UTILS |
|
from .assign_result import AssignResult |
|
from .max_iou_assigner import MaxIoUAssigner |
|
|
|
|
|
@TASK_UTILS.register_module()
class MultiInstanceAssigner(MaxIoUAssigner):
    """Assign a corresponding gt bbox or background to each proposal bbox.

    If one proposal box is used to generate multiple predicted boxes,
    `MultiInstanceAssigner` can assign multiple gts to each proposal box.

    Args:
        num_instance (int): How many bboxes are predicted by each proposal
            box. Defaults to 2.
        **kwargs: Forwarded unchanged to ``MaxIoUAssigner.__init__``.
    """

    def __init__(self, num_instance: int = 2, **kwargs):
        super().__init__(**kwargs)
        self.num_instance = num_instance

    def _top_instances(self, overlaps: torch.Tensor) -> tuple:
        """Pick the ``num_instance`` best columns of every row.

        Args:
            overlaps (Tensor): 2-D overlap matrix, one column per target box.

        Returns:
            tuple: ``(values, indices)``, both flattened to 1-D with
            ``rows * num_instance`` elements. ``indices`` are column indices
            into ``overlaps``.
        """
        sorted_overlaps, sorted_inds = overlaps.sort(descending=True, dim=1)
        top_overlaps = sorted_overlaps[:, :self.num_instance]
        top_inds = sorted_inds[:, :self.num_instance]

        # With fewer targets than ``num_instance`` the slice above is too
        # narrow, and the later ``reshape(-1, num_instance)`` would either
        # fail or mix values belonging to different rows. Fill the missing
        # slots with zero overlap (index 0), i.e. the same thing a slot with
        # no overlapping target looks like.
        num_missing = self.num_instance - top_overlaps.shape[1]
        if num_missing > 0:
            pad_shape = (top_overlaps.shape[0], num_missing)
            top_overlaps = torch.cat(
                [top_overlaps, top_overlaps.new_zeros(pad_shape)], dim=1)
            top_inds = torch.cat(
                [top_inds, top_inds.new_zeros(pad_shape)], dim=1)
        return top_overlaps.flatten(), top_inds.flatten()

    def assign(self,
               pred_instances: InstanceData,
               gt_instances: InstanceData,
               gt_instances_ignore: Optional[InstanceData] = None,
               **kwargs) -> AssignResult:
        """Assign gt to bboxes.

        This method assigns gt bboxes to every bbox (proposal/anchor). Each
        bbox is assigned a set of gts, and the number of gts in this set is
        defined by `self.num_instance`.

        The gt and ignore boxes are appended to the priors before matching,
        so the result has one row per entry of
        ``cat([priors, gt_bboxes, gt_bboxes_ignore])``.

        Args:
            pred_instances (:obj:`InstanceData`): Instances of model
                predictions. It includes ``priors``, and the priors can
                be anchors or points, or the bboxes predicted by the
                previous stage, has shape (n, 4). The bboxes predicted by
                the current model or stage will be named ``bboxes``,
                ``labels``, and ``scores``, the same as the ``InstanceData``
                in other places.
            gt_instances (:obj:`InstanceData`): Ground truth of instance
                annotations. It usually includes ``bboxes``, with shape (k, 4),
                and ``labels``, with shape (k, ).
            gt_instances_ignore (:obj:`InstanceData`, optional): Instances
                to be ignored during training. It includes ``bboxes``
                attribute data that is ignored during training and testing.
                If it has no ``labels`` attribute, every ignore box gets the
                label -1. Defaults to None.

        Returns:
            :obj:`AssignResult`: The assign result. ``gt_inds``,
            ``max_overlaps`` and ``labels`` all have ``num_instance``
            columns. ``gt_inds`` holds 0-based indices into
            ``cat([gt_bboxes, gt_bboxes_ignore])``, and ``num_gts`` is the
            length of that concatenation. Labels are 1 for foreground and 0
            for background; slots matching neither rule keep the looked-up
            label (shifted gt label or -1).
        """
        gt_bboxes = gt_instances.bboxes
        priors = pred_instances.priors

        # Real labels are shifted by one; -1 is the value that marks a target
        # as "ignore" in the mask computed further down.
        gt_labels = gt_instances.labels + 1
        if gt_instances_ignore is not None:
            gt_bboxes_ignore = gt_instances_ignore.bboxes
            if hasattr(gt_instances_ignore, 'labels'):
                gt_labels_ignore = gt_instances_ignore.labels
            else:
                # Build the fallback with the dtype of ``gt_labels`` so the
                # concatenation below does not promote the labels to float.
                gt_labels_ignore = gt_labels.new_full(
                    (gt_bboxes_ignore.shape[0], ), -1)
        else:
            gt_bboxes_ignore = None
            gt_labels_ignore = None

        # Compute overlaps on CPU when the number of gts exceeds the
        # configured threshold.
        assign_on_cpu = (self.gpu_assign_thr > 0
                         and gt_bboxes.shape[0] > self.gpu_assign_thr)
        if assign_on_cpu:
            device = priors.device
            priors = priors.cpu()
            gt_bboxes = gt_bboxes.cpu()
            gt_labels = gt_labels.cpu()
            if gt_bboxes_ignore is not None:
                gt_bboxes_ignore = gt_bboxes_ignore.cpu()
                gt_labels_ignore = gt_labels_ignore.cpu()

        if gt_bboxes_ignore is not None:
            all_bboxes = torch.cat([gt_bboxes, gt_bboxes_ignore], dim=0)
            all_labels = torch.cat([gt_labels, gt_labels_ignore], dim=0)
        else:
            all_bboxes = gt_bboxes
            all_labels = gt_labels
        all_priors = torch.cat([priors, all_bboxes], dim=0)
        num_targets = all_bboxes.size(0)

        overlaps_normal = self.iou_calculator(
            all_priors, all_bboxes, mode='iou')
        overlaps_ignore = self.iou_calculator(
            all_priors, all_bboxes, mode='iof')

        # Split the columns: regular targets keep their IoU and are zeroed
        # in the IoF matrix, ignore targets the other way round.
        ignore_columns = all_labels.eq(-1).unsqueeze(0)
        overlaps_normal = overlaps_normal * ~ignore_columns
        overlaps_ignore = overlaps_ignore * ignore_columns

        # Best ``num_instance`` candidates per row, flattened to 1-D.
        max_overlaps_normal, gt_assignment_normal = self._top_instances(
            overlaps_normal)
        max_overlaps_ignore, gt_assignment_ignore = self._top_instances(
            overlaps_ignore)

        # Prefer the ignore candidate only when the regular one is not a
        # positive match and the ignore overlap is strictly larger.
        ignore_assign_mask = (max_overlaps_normal < self.pos_iou_thr) & (
            max_overlaps_ignore > max_overlaps_normal)
        overlaps = torch.where(ignore_assign_mask, max_overlaps_ignore,
                               max_overlaps_normal)
        gt_assignment = torch.where(ignore_assign_mask, gt_assignment_ignore,
                                    gt_assignment_normal)

        if num_targets > 0:
            # Advanced indexing returns a copy, so the in-place writes below
            # do not touch ``all_labels``.
            assigned_labels = all_labels[gt_assignment]
        else:
            # No gt and no ignore boxes: there is nothing to look up, so
            # start from all-zero labels.
            # NOTE(review): this path assumes ``self.iou_calculator`` accepts
            # an empty target set - confirm against the configured calculator.
            assigned_labels = all_labels.new_zeros(gt_assignment.shape)
        fg_mask = (overlaps >= self.pos_iou_thr) & (assigned_labels != -1)
        bg_mask = (overlaps < self.neg_iou_thr) & (overlaps >= 0)
        assigned_labels[fg_mask] = 1
        assigned_labels[bg_mask] = 0

        # Back to one row per prior, one column per instance slot.
        overlaps = overlaps.reshape(-1, self.num_instance)
        gt_assignment = gt_assignment.reshape(-1, self.num_instance)
        assigned_labels = assigned_labels.reshape(-1, self.num_instance)

        assign_result = AssignResult(
            num_gts=num_targets,
            gt_inds=gt_assignment,
            max_overlaps=overlaps,
            labels=assigned_labels)

        if assign_on_cpu:
            # Move the results back to the device the priors came from.
            assign_result.gt_inds = assign_result.gt_inds.to(device)
            assign_result.max_overlaps = assign_result.max_overlaps.to(device)
            if assign_result.labels is not None:
                assign_result.labels = assign_result.labels.to(device)
        return assign_result
|
|