import torch

from mmdet.registry import TASK_UTILS
from mmdet.structures.bbox import bbox2roi
from .base_sampler import BaseSampler


@TASK_UTILS.register_module()
class OHEMSampler(BaseSampler):
    r"""Online Hard Example Mining Sampler described in `Training Region-based
    Object Detectors with Online Hard Example Mining
    <https://arxiv.org/abs/1604.03540>`_.

    Args:
        num (int): Number of samples.
        pos_fraction (float): Fraction of positive samples.
        context (:obj:`BaseRoIHead`): The RoI head that provides the bbox
            head and the ``_bbox_forward`` method used for loss-based hard
            example mining.
        neg_pos_ub (int): Upper bound of the negative-to-positive sample
            ratio. Defaults to -1, i.e. no upper bound.
        add_gt_as_proposals (bool): Whether to add ground-truth boxes as
            proposals. Defaults to True.
        loss_key (str): Key of the classification loss used to rank hard
            examples. Defaults to 'loss_cls'.
    """

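    # A minimal config sketch for enabling this sampler in a two-stage
    # detector (illustrative values; the surrounding keys follow the standard
    # Faster R-CNN config layout and are not defined in this file;
    # ``context`` is injected by the RoI head at build time, not configured):
    #
    #     train_cfg = dict(
    #         rcnn=dict(
    #             sampler=dict(
    #                 type='OHEMSampler',
    #                 num=512,
    #                 pos_fraction=0.25,
    #                 neg_pos_ub=-1,
    #                 add_gt_as_proposals=True)))
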
    def __init__(self,
                 num,
                 pos_fraction,
                 context,
                 neg_pos_ub=-1,
                 add_gt_as_proposals=True,
                 loss_key='loss_cls',
                 **kwargs):
        super(OHEMSampler, self).__init__(num, pos_fraction, neg_pos_ub,
                                          add_gt_as_proposals)
        self.context = context
        # Cascade RoI heads expose ``num_stages`` and keep one bbox head per
        # stage; standard RoI heads expose a single bbox head.
        if not hasattr(self.context, 'num_stages'):
            self.bbox_head = self.context.bbox_head
        else:
            self.bbox_head = self.context.bbox_head[self.context.current_stage]

        self.loss_key = loss_key

    def hard_mining(self, inds, num_expected, bboxes, labels, feats):
        """Keep the ``num_expected`` candidates with the highest loss.

        The candidate boxes are re-scored by the current bbox head (without
        gradient) and ranked by their per-sample classification loss.

        Args:
            inds (torch.Tensor): Indices of the candidate samples.
            num_expected (int): Number of samples to keep.
            bboxes (torch.Tensor): Candidate boxes.
            labels (torch.Tensor): Labels assigned to the candidate boxes.
            feats (list[torch.Tensor]): Multi-level features.

        Returns:
            torch.Tensor: Indices of the hardest ``num_expected`` samples.
        """
        with torch.no_grad():
            rois = bbox2roi([bboxes])
            if not hasattr(self.context, 'num_stages'):
                bbox_results = self.context._bbox_forward(feats, rois)
            else:
                bbox_results = self.context._bbox_forward(
                    self.context.current_stage, feats, rois)
            cls_score = bbox_results['cls_score']
            loss = self.bbox_head.loss(
                cls_score=cls_score,
                bbox_pred=None,
                rois=rois,
                labels=labels,
                label_weights=cls_score.new_ones(cls_score.size(0)),
                bbox_targets=None,
                bbox_weights=None,
                reduction_override='none')[self.loss_key]
            _, topk_loss_inds = loss.topk(num_expected)
        return inds[topk_loss_inds]
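
    # A minimal, self-contained sketch of the ranking performed above (toy
    # numbers only; ``per_sample_loss`` and ``cand_inds`` are illustrative
    # and not produced by this module):
    #
    #     import torch
    #     per_sample_loss = torch.tensor([0.1, 2.3, 0.7, 1.5])
    #     cand_inds = torch.tensor([10, 11, 12, 13])
    #     _, topk = per_sample_loss.topk(2)   # positions 1 and 3
    #     hard_inds = cand_inds[topk]         # tensor([11, 13])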

    def _sample_pos(self,
                    assign_result,
                    num_expected,
                    bboxes=None,
                    feats=None,
                    **kwargs):
        """Sample positive boxes.

        Args:
            assign_result (:obj:`AssignResult`): Assigned results
            num_expected (int): Number of expected positive samples
            bboxes (torch.Tensor, optional): Boxes. Defaults to None.
            feats (list[torch.Tensor], optional): Multi-level features.
                Defaults to None.

        Returns:
            torch.Tensor: Indices of positive samples
        """
        pos_inds = torch.nonzero(assign_result.gt_inds > 0, as_tuple=False)
        if pos_inds.numel() != 0:
            pos_inds = pos_inds.squeeze(1)
        if pos_inds.numel() <= num_expected:
            return pos_inds
        else:
            return self.hard_mining(pos_inds, num_expected, bboxes[pos_inds],
                                    assign_result.labels[pos_inds], feats)

    def _sample_neg(self,
                    assign_result,
                    num_expected,
                    bboxes=None,
                    feats=None,
                    **kwargs):
        """Sample negative boxes.

        Args:
            assign_result (:obj:`AssignResult`): Assigned results
            num_expected (int): Number of expected negative samples
            bboxes (torch.Tensor, optional): Boxes. Defaults to None.
            feats (list[torch.Tensor], optional): Multi-level features.
                Defaults to None.

        Returns:
            torch.Tensor: Indices of negative samples
        """
        neg_inds = torch.nonzero(assign_result.gt_inds == 0, as_tuple=False)
        if neg_inds.numel() != 0:
            neg_inds = neg_inds.squeeze(1)
        if len(neg_inds) <= num_expected:
            return neg_inds
        else:
            # Negatives are scored against the background class, which is
            # indexed by ``num_classes`` in mmdet classification heads.
            neg_labels = assign_result.labels.new_empty(
                neg_inds.size(0)).fill_(self.bbox_head.num_classes)
            return self.hard_mining(neg_inds, num_expected, bboxes[neg_inds],
                                    neg_labels, feats)
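
# How this sampler typically receives ``context`` (a sketch, not part of this
# module): two-stage RoI heads in mmdet build the sampler from the train
# config and inject themselves as the context, roughly as follows
# (``roi_head`` and ``sampler_cfg`` below are assumed names):
#
#     from mmdet.registry import TASK_UTILS
#
#     sampler_cfg = dict(
#         type='OHEMSampler', num=512, pos_fraction=0.25,
#         neg_pos_ub=-1, add_gt_as_proposals=True)
#     bbox_sampler = TASK_UTILS.build(
#         sampler_cfg, default_args=dict(context=roi_head))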