|
|
|
from abc import ABCMeta, abstractmethod |
|
|
|
import torch |
|
from mmengine.structures import InstanceData |
|
|
|
from mmdet.structures.bbox import BaseBoxes, cat_boxes |
|
from ..assigners import AssignResult |
|
from .sampling_result import SamplingResult |
|
|
|
|
|
class BaseSampler(metaclass=ABCMeta):
    """Base class of samplers.

    Subclasses implement :meth:`_sample_pos` and :meth:`_sample_neg`;
    :meth:`sample` drives them and packs the outcome into a
    :obj:`SamplingResult`.

    Args:
        num (int): Total number of samples to draw (positives plus
            negatives).
        pos_fraction (float): Fraction of ``num`` requested as positive
            samples.
        neg_pos_ub (int): Upper bound on the ratio of negative to positive
            samples: at most ``int(neg_pos_ub * max(1, num_pos))`` negatives
            are requested. A negative value disables the bound.
            Defaults to -1.
        add_gt_as_proposals (bool): Whether to add ground truth
            boxes as proposals. Defaults to True.
        **kwargs: Accepted and ignored by this base class.
    """

    def __init__(self,
                 num: int,
                 pos_fraction: float,
                 neg_pos_ub: int = -1,
                 add_gt_as_proposals: bool = True,
                 **kwargs) -> None:
        self.num = num
        self.pos_fraction = pos_fraction
        self.neg_pos_ub = neg_pos_ub
        self.add_gt_as_proposals = add_gt_as_proposals
        # ``sample`` dispatches positive/negative sampling through these
        # attributes; by default the sampler handles both itself
        # (presumably so subclasses can plug in dedicated samplers).
        self.pos_sampler = self
        self.neg_sampler = self

    @abstractmethod
    def _sample_pos(self, assign_result: AssignResult, num_expected: int,
                    **kwargs):
        """Sample positive samples.

        Args:
            assign_result (:obj:`AssignResult`): Assigning results.
            num_expected (int): Number of positive samples requested.
            **kwargs: Extra arguments forwarded from :meth:`sample`, which
                passes the candidate boxes as ``bboxes``.

        Returns:
            Indices of the sampled positives. :meth:`sample` calls
            ``.unique()`` and ``.numel()`` on the result, so a tensor is
            expected.
        """
        pass

    @abstractmethod
    def _sample_neg(self, assign_result: AssignResult, num_expected: int,
                    **kwargs):
        """Sample negative samples.

        Args:
            assign_result (:obj:`AssignResult`): Assigning results.
            num_expected (int): Number of negative samples requested.
            **kwargs: Extra arguments forwarded from :meth:`sample`, which
                passes the candidate boxes as ``bboxes``.

        Returns:
            Indices of the sampled negatives. :meth:`sample` calls
            ``.unique()`` on the result, so a tensor is expected.
        """
        pass

    def sample(self, assign_result: AssignResult, pred_instances: InstanceData,
               gt_instances: InstanceData, **kwargs) -> SamplingResult:
        """Sample positive and negative bboxes.

        This is a simple implementation of bbox sampling given candidates,
        assigning results and ground truth bboxes.

        If ``add_gt_as_proposals`` is True and there is at least one ground
        truth box, the GT boxes are prepended to the priors and
        ``assign_result.add_gt_(gt_labels)`` is called on the passed-in
        ``assign_result`` (the trailing underscore suggests an in-place
        update).

        Args:
            assign_result (:obj:`AssignResult`): Assigning results.
            pred_instances (:obj:`InstanceData`): Instances of model
                predictions. It includes ``priors``, and the priors can
                be anchors or points, or the bboxes predicted by the
                previous stage, has shape (n, 4). A single prior given
                without the leading dimension is promoted to shape (1, 4).
                The bboxes predicted by the current model or stage will be
                named ``bboxes``, ``labels``, and ``scores``, the same as
                the ``InstanceData`` in other places.
            gt_instances (:obj:`InstanceData`): Ground truth of instance
                annotations. It usually includes ``bboxes``, with shape
                (k, 4), and ``labels``, with shape (k, ).
            **kwargs: Forwarded to :meth:`_sample_pos` and
                :meth:`_sample_neg`.

        Returns:
            :obj:`SamplingResult`: Sampling result.

        Example:
            >>> import torch
            >>> from mmengine.structures import InstanceData
            >>> from mmdet.models.task_modules.samplers import RandomSampler
            >>> from mmdet.models.task_modules.assigners import AssignResult
            >>> from mmdet.models.task_modules.samplers.sampling_result \\
            ...     import ensure_rng, random_boxes
            >>> rng = ensure_rng(None)
            >>> assign_result = AssignResult.random(rng=rng)
            >>> pred_instances = InstanceData()
            >>> pred_instances.priors = random_boxes(assign_result.num_preds,
            ...                                      rng=rng)
            >>> gt_instances = InstanceData()
            >>> gt_instances.bboxes = random_boxes(assign_result.num_gts,
            ...                                    rng=rng)
            >>> gt_instances.labels = torch.randint(
            ...     0, 5, (assign_result.num_gts,), dtype=torch.long)
            >>> sampler = RandomSampler(num=32, pos_fraction=0.5,
            ...                         neg_pos_ub=-1,
            ...                         add_gt_as_proposals=False)
            >>> sampling_result = sampler.sample(
            ...     assign_result, pred_instances, gt_instances)
        """
        gt_bboxes = gt_instances.bboxes
        priors = pred_instances.priors
        gt_labels = gt_instances.labels
        # Promote a single prior to a batch of one.
        if len(priors.shape) < 2:
            priors = priors[None, :]

        # gt_flags[i] == 1 marks rows of ``priors`` that are ground truth.
        gt_flags = priors.new_zeros((priors.shape[0], ), dtype=torch.uint8)
        if self.add_gt_as_proposals and len(gt_bboxes) > 0:
            # When both ``gt_bboxes`` and ``priors`` are box types, convert
            # ``gt_bboxes`` to the type of ``priors`` before concatenating.
            if (isinstance(gt_bboxes, BaseBoxes)
                    and isinstance(priors, BaseBoxes)):
                gt_bboxes_ = gt_bboxes.convert_to(type(priors))
            else:
                gt_bboxes_ = gt_bboxes
            # GT boxes are prepended, so they occupy the first rows of
            # ``priors`` and the matching leading entries of ``gt_flags``.
            priors = cat_boxes([gt_bboxes_, priors], dim=0)
            assign_result.add_gt_(gt_labels)
            gt_ones = priors.new_ones(gt_bboxes_.shape[0], dtype=torch.uint8)
            gt_flags = torch.cat([gt_ones, gt_flags])

        num_expected_pos = int(self.num * self.pos_fraction)
        pos_inds = self.pos_sampler._sample_pos(
            assign_result, num_expected_pos, bboxes=priors, **kwargs)
        # Sampled indices have been observed to contain duplicates
        # occasionally, so deduplicate before counting.
        pos_inds = pos_inds.unique()
        num_sampled_pos = pos_inds.numel()

        # Negatives fill the remaining budget, optionally capped at
        # ``neg_pos_ub`` negatives per positive. ``max(1, ...)`` keeps the
        # cap from collapsing to zero when no positives were sampled.
        num_expected_neg = self.num - num_sampled_pos
        if self.neg_pos_ub >= 0:
            num_pos_for_ratio = max(1, num_sampled_pos)
            neg_upper_bound = int(self.neg_pos_ub * num_pos_for_ratio)
            num_expected_neg = min(num_expected_neg, neg_upper_bound)
        neg_inds = self.neg_sampler._sample_neg(
            assign_result, num_expected_neg, bboxes=priors, **kwargs)
        neg_inds = neg_inds.unique()

        sampling_result = SamplingResult(
            pos_inds=pos_inds,
            neg_inds=neg_inds,
            priors=priors,
            gt_bboxes=gt_bboxes,
            assign_result=assign_result,
            gt_flags=gt_flags)
        return sampling_result
|
|