|
|
|
import torch |
|
|
|
from mmdet.registry import TASK_UTILS |
|
from mmdet.structures.bbox import bbox_overlaps, get_box_tensor |
|
|
|
|
|
def cast_tensor_type(x, scale=1., dtype=None):
    """Optionally rescale a tensor and cast it to half precision.

    Args:
        x (Tensor): Input tensor.
        scale (float): Divisor applied to ``x`` before the cast. Only used
            when ``dtype == 'fp16'``. Defaults to 1.
        dtype (str, optional): Pass ``'fp16'`` to obtain
            ``(x / scale).half()``. Any other value returns ``x`` untouched.
            Defaults to None.

    Returns:
        Tensor: A new fp16 tensor when ``dtype == 'fp16'``, otherwise the
        very same ``x`` object.
    """
    if dtype != 'fp16':
        return x
    # Divide before casting, presumably so values (and quantities derived
    # from them downstream) stay inside fp16's limited range.
    return (x / scale).half()
|
|
|
|
|
@TASK_UTILS.register_module()
class BboxOverlaps2D:
    """2D Overlaps (e.g. IoUs, GIoUs) Calculator.

    Args:
        scale (float): Divisor applied to box coordinates before they are
            cast to fp16. Only used when ``dtype == 'fp16'``. Defaults to 1.
        dtype (str, optional): ``'fp16'`` to compute overlaps in half
            precision; any other value (including ``None``) computes them
            in the dtype of the input boxes. Defaults to None.
    """

    def __init__(self, scale=1., dtype=None):
        self.scale = scale
        self.dtype = dtype

    def __call__(self, bboxes1, bboxes2, mode='iou', is_aligned=False):
        """Calculate IoU between 2D bboxes.

        Args:
            bboxes1 (Tensor or :obj:`BaseBoxes`): bboxes have shape (m, 4)
                in <x1, y1, x2, y2> format, or shape (m, 5) in <x1, y1, x2,
                y2, score> format.
            bboxes2 (Tensor or :obj:`BaseBoxes`): bboxes have shape (n, 4)
                in <x1, y1, x2, y2> format, shape (n, 5) in <x1, y1, x2, y2,
                score> format, or be empty. If ``is_aligned`` is ``True``,
                then m and n must be equal.
            mode (str): "iou" (intersection over union), "iof" (intersection
                over foreground), or "giou" (generalized intersection over
                union).
            is_aligned (bool, optional): If True, then m and n must be equal.
                Default False.

        Returns:
            Tensor: shape (m, n) if ``is_aligned`` is False else shape (m,).
            With ``dtype == 'fp16'`` the result is returned as float32 on
            CPU but stays float16 on CUDA.

        Raises:
            AssertionError: If the last dimension of either input is not
                0, 4 or 5.
        """
        bboxes1 = get_box_tensor(bboxes1)
        bboxes2 = get_box_tensor(bboxes2)
        assert bboxes1.size(-1) in [0, 4, 5], (
            f'bboxes1 last dim must be 0, 4 or 5, got {bboxes1.size(-1)}')
        assert bboxes2.size(-1) in [0, 4, 5], (
            f'bboxes2 last dim must be 0, 4 or 5, got {bboxes2.size(-1)}')
        # A 5th column is a score; only the <x1, y1, x2, y2> coordinates
        # take part in the overlap computation.
        if bboxes2.size(-1) == 5:
            bboxes2 = bboxes2[..., :4]
        if bboxes1.size(-1) == 5:
            bboxes1 = bboxes1[..., :4]

        if self.dtype != 'fp16':
            return bbox_overlaps(bboxes1, bboxes2, mode, is_aligned)

        # Half-precision path: coordinates are divided by ``self.scale``
        # before ``.half()``, presumably to avoid fp16 overflow.
        bboxes1 = cast_tensor_type(bboxes1, self.scale, self.dtype)
        bboxes2 = cast_tensor_type(bboxes2, self.scale, self.dtype)
        overlaps = bbox_overlaps(bboxes1, bboxes2, mode, is_aligned)
        # Hand CPU results back as float32 (presumably because many CPU ops
        # lack fp16 support); CUDA results are left in fp16.
        if not overlaps.is_cuda and overlaps.dtype == torch.float16:
            overlaps = overlaps.float()
        return overlaps

    def __repr__(self):
        """str: a string describing the module"""
        return (f'{self.__class__.__name__}('
                f'scale={self.scale}, dtype={self.dtype})')
|
|