|
|
|
from typing import Optional |
|
|
|
import torch |
|
from mmengine.structures import InstanceData |
|
|
|
from mmdet.registry import TASK_UTILS |
|
from .assign_result import AssignResult |
|
from .base_assigner import BaseAssigner |
|
|
|
|
|
@TASK_UTILS.register_module()
class PointAssigner(BaseAssigner):
    """Assign a corresponding gt bbox or background to each point.

    Each point will be assigned with `0`, or a positive integer
    indicating the ground truth index.

    - 0: negative sample, no assigned gt
    - positive integer: positive sample, index (1-based) of assigned gt

    Args:
        scale (int): Divisor applied to the gt width and height before
            taking ``log2`` when mapping a gt bbox to a point level.
            Defaults to 4.
        pos_num (int): Maximum number of closest points (on the matching
            level) considered as positive candidates for each gt bbox.
            Defaults to 3.
    """

    def __init__(self, scale: int = 4, pos_num: int = 3) -> None:
        self.scale = scale
        self.pos_num = pos_num

    def assign(self,
               pred_instances: InstanceData,
               gt_instances: InstanceData,
               gt_instances_ignore: Optional[InstanceData] = None,
               **kwargs) -> AssignResult:
        """Assign gt to points.

        Every point receives a gt index and a label. In the returned
        ``gt_inds``, 0 means background and a positive number is the
        1-based index of the assigned gt. In the returned ``labels``,
        -1 means background, otherwise it is the label of the assigned gt.

        The assignment is done in following steps, the order matters.

        1. assign every point to the background (gt index 0)
        2. A point is assigned to some gt bbox if
            (i) the point is on the level the gt bbox is mapped to and is
                within the ``pos_num`` closest points to the gt center
            (ii) the normalized distance between this point and the gt is
                strictly smaller than the distance to any gt it was
                assigned to earlier

        Args:
            pred_instances (:obj:`InstanceData`): Instances of model
                predictions. Only ``priors`` is read here; the first two
                columns are used as the point (x, y) and the third column
                as the point stride.
            gt_instances (:obj:`InstanceData`): Ground truth of instance
                annotations. ``bboxes`` is read as (x1, y1, x2, y2) with
                shape (k, 4), and ``labels`` with shape (k, ).
            gt_instances_ignore (:obj:`InstanceData`, optional): Instances
                to be ignored during training. It is not used by this
                assigner. Defaults to None.

        Returns:
            :obj:`AssignResult`: The assign result. ``max_overlaps`` is
            always None.
        """
        gt_bboxes = gt_instances.bboxes
        gt_labels = gt_instances.labels
        points = pred_instances.priors

        num_points = points.shape[0]
        num_gts = gt_bboxes.shape[0]

        if num_gts == 0 or num_points == 0:
            # Nothing to match: every point is background.
            assigned_gt_inds = points.new_full((num_points, ),
                                               0,
                                               dtype=torch.long)
            assigned_labels = points.new_full((num_points, ),
                                              -1,
                                              dtype=torch.long)
            return AssignResult(
                num_gts=num_gts,
                gt_inds=assigned_gt_inds,
                max_overlaps=None,
                labels=assigned_labels)

        points_xy = points[:, :2]
        points_stride = points[:, 2]
        # The level of a point is log2 of its stride.
        points_lvl = torch.log2(points_stride).int()
        lvl_min, lvl_max = points_lvl.min(), points_lvl.max()

        # Map every gt bbox to a level from its (scaled) width and height,
        # restricted to the range of levels that actually have points.
        gt_bboxes_xy = (gt_bboxes[:, :2] + gt_bboxes[:, 2:]) / 2
        # Clamp so that degenerate boxes do not yield log2(0) or a
        # division by zero further down.
        gt_bboxes_wh = (gt_bboxes[:, 2:] - gt_bboxes[:, :2]).clamp(min=1e-6)
        scale = self.scale
        gt_bboxes_lvl = ((torch.log2(gt_bboxes_wh[:, 0] / scale) +
                          torch.log2(gt_bboxes_wh[:, 1] / scale)) / 2).int()
        gt_bboxes_lvl = torch.clamp(gt_bboxes_lvl, min=lvl_min, max=lvl_max)

        # Index (1-based) of the gt each point is assigned to; 0 = none.
        assigned_gt_inds = points.new_zeros((num_points, ), dtype=torch.long)
        # Smallest normalized distance recorded so far for each point.
        assigned_gt_dist = points.new_full((num_points, ), float('inf'))
        # Created on the same device as ``points`` so that all the
        # indexing below stays on a single device.
        points_range = torch.arange(num_points, device=points.device)

        for idx in range(num_gts):
            gt_lvl = gt_bboxes_lvl[idx]
            # Select the points that live on the level of this gt.
            lvl_idx = gt_lvl == points_lvl
            points_index = points_range[lvl_idx]
            lvl_points = points_xy[lvl_idx, :]

            # The level may have fewer than ``pos_num`` points (or none if
            # the point levels are not contiguous); ``torch.topk`` raises
            # when k exceeds the number of candidates, so cap it.
            num_candidates = lvl_points.shape[0]
            k = min(self.pos_num, num_candidates)
            if k <= 0:
                continue

            # Keep a leading dim of 1 so both broadcast against lvl_points.
            gt_point = gt_bboxes_xy[[idx], :]
            gt_wh = gt_bboxes_wh[[idx], :]

            # Distance from each candidate point to the gt center,
            # normalized by the gt width and height.
            points_gt_dist = ((lvl_points - gt_point) / gt_wh).norm(dim=1)

            # The k closest candidates on this level.
            min_dist, min_dist_index = torch.topk(
                points_gt_dist, k, largest=False)
            # Convert indices within the level to indices over all points.
            min_dist_points_index = points_index[min_dist_index]

            # A point may already belong to an earlier gt; it is only
            # re-assigned when the current gt is strictly closer.
            less_than_recorded_index = min_dist < assigned_gt_dist[
                min_dist_points_index]
            min_dist_points_index = min_dist_points_index[
                less_than_recorded_index]

            assigned_gt_inds[min_dist_points_index] = idx + 1
            assigned_gt_dist[min_dist_points_index] = min_dist[
                less_than_recorded_index]

        assigned_labels = assigned_gt_inds.new_full((num_points, ), -1)
        # ``squeeze(1)`` keeps the result 1-D even with a single positive.
        pos_inds = torch.nonzero(
            assigned_gt_inds > 0, as_tuple=False).squeeze(1)
        if pos_inds.numel() > 0:
            # gt indices are 1-based, hence the ``- 1``.
            assigned_labels[pos_inds] = gt_labels[assigned_gt_inds[pos_inds] -
                                                  1]

        return AssignResult(
            num_gts=num_gts,
            gt_inds=assigned_gt_inds,
            max_overlaps=None,
            labels=assigned_labels)
|
|