|
|
|
from typing import Union |
|
|
|
import torch |
|
from numpy import ndarray |
|
from torch import Tensor |
|
|
|
from mmdet.registry import TASK_UTILS |
|
from ..assigners import AssignResult |
|
from .base_sampler import BaseSampler |
|
|
|
|
|
@TASK_UTILS.register_module()
class RandomSampler(BaseSampler):
    """Random sampler.

    Args:
        num (int): Number of samples.
        pos_fraction (float): Fraction of positive samples.
        neg_pos_ub (int): Upper bound number of negative and
            positive samples. Defaults to -1.
        add_gt_as_proposals (bool): Whether to add ground truth
            boxes as proposals. Defaults to True.
        **kwargs: Extra options. Only ``rng`` is read here; it is passed
            through ``ensure_rng`` and stored as ``self.rng``.
    """

    def __init__(self,
                 num: int,
                 pos_fraction: float,
                 neg_pos_ub: int = -1,
                 add_gt_as_proposals: bool = True,
                 **kwargs):
        # Function-scope import; presumably avoids a circular import with
        # ``sampling_result`` (TODO confirm).
        from .sampling_result import ensure_rng
        super().__init__(
            num=num,
            pos_fraction=pos_fraction,
            neg_pos_ub=neg_pos_ub,
            add_gt_as_proposals=add_gt_as_proposals)
        # NOTE(review): ``random_choice`` below draws from torch's global
        # RNG via ``torch.randperm``; ``self.rng`` is not consulted there.
        self.rng = ensure_rng(kwargs.get('rng', None))

    def random_choice(self, gallery: Union[Tensor, ndarray, list],
                      num: int) -> Union[Tensor, ndarray]:
        """Randomly select ``num`` elements from the gallery.

        If ``gallery`` is a Tensor, the returned indices will be a Tensor.
        If ``gallery`` is an ndarray or a list, the returned indices will be
        an ndarray.

        Args:
            gallery (Tensor | ndarray | list): Indices pool.
            num (int): Expected sample number. Must not exceed
                ``len(gallery)``.

        Returns:
            Tensor or ndarray: Sampled indices.

        Raises:
            AssertionError: If ``gallery`` holds fewer than ``num`` elements.
        """
        assert len(gallery) >= num

        is_tensor = isinstance(gallery, torch.Tensor)
        if not is_tensor:
            # Non-tensor pools become a long tensor on the current CUDA
            # device when one is available, otherwise on the CPU.
            if torch.cuda.is_available():
                device = torch.cuda.current_device()
            else:
                device = 'cpu'
            gallery = torch.tensor(gallery, dtype=torch.long, device=device)

        # The permutation is generated on the CPU and then moved to the
        # gallery's device (presumably to sidestep CUDA ``randperm`` issues;
        # confirm before changing).
        perm = torch.randperm(gallery.numel())[:num].to(device=gallery.device)
        rand_inds = gallery[perm]
        if not is_tensor:
            rand_inds = rand_inds.cpu().numpy()
        return rand_inds

    def _sample_from_mask(self, mask: Tensor,
                          num_expected: int) -> Union[Tensor, ndarray]:
        """Return at most ``num_expected`` indices where ``mask`` is true.

        Shared by ``_sample_pos`` and ``_sample_neg``. Indices are always
        returned as a 1-D tensor, including when ``mask`` selects nothing.

        Args:
            mask (Tensor): Boolean mask over candidates.
            num_expected (int): Maximum number of indices to return.

        Returns:
            Tensor or ndarray: Sampled indices. All matching indices are
            returned unchanged when there are at most ``num_expected`` of
            them; otherwise a random subset is drawn via ``random_choice``.
        """
        # ``nonzero(as_tuple=False)`` yields shape (k, 1) for a 1-D mask.
        # Squeeze unconditionally so an empty result is (0,) rather than
        # (0, 1).
        inds = torch.nonzero(mask, as_tuple=False).squeeze(1)
        if inds.numel() <= num_expected:
            return inds
        return self.random_choice(inds, num_expected)

    def _sample_pos(self, assign_result: AssignResult, num_expected: int,
                    **kwargs) -> Union[Tensor, ndarray]:
        """Randomly sample some positive samples.

        Positives are the entries with ``assign_result.gt_inds > 0``.

        Args:
            assign_result (:obj:`AssignResult`): Bbox assigning results.
            num_expected (int): The number of expected positive samples.

        Returns:
            Tensor or ndarray: Sampled indices.
        """
        return self._sample_from_mask(assign_result.gt_inds > 0,
                                      num_expected)

    def _sample_neg(self, assign_result: AssignResult, num_expected: int,
                    **kwargs) -> Union[Tensor, ndarray]:
        """Randomly sample some negative samples.

        Negatives are the entries with ``assign_result.gt_inds == 0``.

        Args:
            assign_result (:obj:`AssignResult`): Bbox assigning results.
            num_expected (int): The number of expected negative samples.

        Returns:
            Tensor or ndarray: Sampled indices.
        """
        return self._sample_from_mask(assign_result.gt_inds == 0,
                                      num_expected)
|
|