|
|
|
import numpy as np |
|
import torch |
|
|
|
from mmdet.registry import TASK_UTILS |
|
from .random_sampler import RandomSampler |
|
|
|
|
|
@TASK_UTILS.register_module()
class IoUBalancedNegSampler(RandomSampler):
    """IoU Balanced Sampling.

    arXiv: https://arxiv.org/pdf/1904.02701.pdf (CVPR 2019)

    Sampling proposals according to their IoU. `floor_fraction` of needed RoIs
    are sampled from proposals whose IoU are lower than `floor_thr` randomly.
    The others are sampled from proposals whose IoU are higher than
    `floor_thr`. These proposals are sampled from some bins evenly, which are
    split by `num_bins` via IoU evenly.

    Args:
        num (int): number of proposals.
        pos_fraction (float): fraction of positive proposals.
        floor_thr (float): threshold (minimum) IoU for IoU balanced sampling,
            set to -1 if all using IoU balanced sampling.
        floor_fraction (float): sampling fraction of proposals under floor_thr.
        num_bins (int): number of bins in IoU balanced sampling.
    """

    def __init__(self,
                 num,
                 pos_fraction,
                 floor_thr=-1,
                 floor_fraction=0,
                 num_bins=3,
                 **kwargs):
        super().__init__(num, pos_fraction, **kwargs)
        assert floor_thr >= 0 or floor_thr == -1
        assert 0 <= floor_fraction <= 1
        assert num_bins >= 1

        self.floor_thr = floor_thr
        self.floor_fraction = floor_fraction
        self.num_bins = num_bins

    def sample_via_interval(self,
                            max_overlaps,
                            full_set,
                            num_expected,
                            floor_thr=None):
        """Sample according to the IoU interval.

        The range ``[floor_thr, max_overlaps.max())`` is split into
        ``self.num_bins`` equal-width bins and up to
        ``int(num_expected / self.num_bins)`` indices are drawn from each bin.
        If the bins together yield fewer than ``num_expected`` indices, the
        shortfall is filled randomly from the not-yet-selected part of
        ``full_set``.

        Args:
            max_overlaps (np.ndarray): IoU between bounding boxes and ground
                truth boxes (``_sample_neg`` passes a NumPy array).
            full_set (set(int)): A full set of indices of boxes.
            num_expected (int): Number of expected samples.
            floor_thr (float, optional): Lower bound of the first bin.
                Defaults to None, in which case ``self.floor_thr`` is used.

        Returns:
            np.ndarray: Indices of samples.
        """
        if floor_thr is None:
            floor_thr = self.floor_thr

        max_iou = max_overlaps.max()
        iou_interval = (max_iou - floor_thr) / self.num_bins
        per_num_expected = int(num_expected / self.num_bins)

        sampled_inds = []
        for i in range(self.num_bins):
            start_iou = floor_thr + i * iou_interval
            end_iou = floor_thr + (i + 1) * iou_interval
            # Indices whose IoU lies in the half-open bin [start, end).
            tmp_set = set(
                np.where(
                    np.logical_and(max_overlaps >= start_iou,
                                   max_overlaps < end_iou))[0])
            tmp_inds = list(tmp_set & full_set)
            if len(tmp_inds) > per_num_expected:
                tmp_sampled_set = self.random_choice(tmp_inds,
                                                     per_num_expected)
            else:
                tmp_sampled_set = np.array(tmp_inds, dtype=np.int64)
            sampled_inds.append(tmp_sampled_set)

        sampled_inds = np.concatenate(sampled_inds)
        if len(sampled_inds) < num_expected:
            num_extra = num_expected - len(sampled_inds)
            # Explicit integer dtype: an empty remainder would otherwise
            # become float64 and promote the concatenated index array.
            extra_inds = np.array(
                list(full_set - set(sampled_inds)), dtype=np.int64)
            if len(extra_inds) > num_extra:
                extra_inds = self.random_choice(extra_inds, num_extra)
            sampled_inds = np.concatenate([sampled_inds, extra_inds])

        return sampled_inds

    def _sample_neg(self, assign_result, num_expected, **kwargs):
        """Sample negative boxes.

        Boxes whose ``assign_result.gt_inds`` equals 0 are the candidates.
        When there are no more candidates than ``num_expected`` they are all
        returned; otherwise candidates are split by ``self.floor_thr`` into a
        "floor" group (sampled randomly) and an IoU-balanced group (sampled
        per IoU bin when ``self.num_bins >= 2``).

        Args:
            assign_result (:obj:`AssignResult`): The assigned results of boxes.
            num_expected (int): The number of expected negative samples.

        Returns:
            Tensor: sampled indices.
        """
        neg_inds = torch.nonzero(assign_result.gt_inds == 0, as_tuple=False)
        if neg_inds.numel() != 0:
            neg_inds = neg_inds.squeeze(1)
        if len(neg_inds) <= num_expected:
            return neg_inds

        max_overlaps = assign_result.max_overlaps.cpu().numpy()
        neg_set = set(neg_inds.cpu().numpy())

        # Lower bound used for the IoU bins. Kept in a local variable so the
        # configured ``self.floor_thr`` is never overwritten; mutating it
        # would make every call after the first take a different branch.
        interval_floor = self.floor_thr
        if self.floor_thr > 0:
            floor_set = set(
                np.where(
                    np.logical_and(max_overlaps >= 0,
                                   max_overlaps < self.floor_thr))[0])
            iou_sampling_set = set(
                np.where(max_overlaps >= self.floor_thr)[0])
        elif self.floor_thr == 0:
            floor_set = set(np.where(max_overlaps == 0)[0])
            iou_sampling_set = set(
                np.where(max_overlaps > self.floor_thr)[0])
        else:
            # floor_thr == -1: every candidate takes part in IoU balanced
            # sampling, with the bins starting at IoU 0.
            floor_set = set()
            iou_sampling_set = set(
                np.where(max_overlaps > self.floor_thr)[0])
            interval_floor = 0

        floor_neg_inds = list(floor_set & neg_set)
        iou_sampling_neg_inds = list(iou_sampling_set & neg_set)
        num_expected_iou_sampling = int(num_expected *
                                        (1 - self.floor_fraction))
        if len(iou_sampling_neg_inds) > num_expected_iou_sampling:
            if self.num_bins >= 2:
                iou_sampled_inds = self.sample_via_interval(
                    max_overlaps,
                    set(iou_sampling_neg_inds),
                    num_expected_iou_sampling,
                    floor_thr=interval_floor)
            else:
                iou_sampled_inds = self.random_choice(
                    iou_sampling_neg_inds, num_expected_iou_sampling)
        else:
            iou_sampled_inds = np.array(
                iou_sampling_neg_inds, dtype=np.int64)

        # Whatever the IoU-balanced part did not provide is requested from
        # the floor group.
        num_expected_floor = num_expected - len(iou_sampled_inds)
        if len(floor_neg_inds) > num_expected_floor:
            sampled_floor_inds = self.random_choice(floor_neg_inds,
                                                    num_expected_floor)
        else:
            sampled_floor_inds = np.array(floor_neg_inds, dtype=np.int64)
        sampled_inds = np.concatenate((sampled_floor_inds, iou_sampled_inds))

        if len(sampled_inds) < num_expected:
            # Still short: top up randomly from the unused candidates.
            num_extra = num_expected - len(sampled_inds)
            extra_inds = np.array(
                list(neg_set - set(sampled_inds)), dtype=np.int64)
            if len(extra_inds) > num_extra:
                extra_inds = self.random_choice(extra_inds, num_extra)
            sampled_inds = np.concatenate((sampled_inds, extra_inds))

        sampled_inds = torch.from_numpy(sampled_inds).long().to(
            assign_result.gt_inds.device)
        return sampled_inds
|
|