from typing import Union

import torch
from torch import Tensor

from mmdet.registry import TASK_UTILS
from mmdet.structures.bbox import BaseBoxes, HorizontalBoxes, get_box_tensor
from .base_bbox_coder import BaseBBoxCoder


@TASK_UTILS.register_module()
class YOLOBBoxCoder(BaseBBoxCoder):
    """YOLO BBox coder.

    Following `YOLO <https://arxiv.org/abs/1506.02640>`_, this coder divides
    the image into grids and encodes bboxes (x1, y1, x2, y2) into
    (cx, cy, dw, dh). cx, cy in [0., 1.] denote the relative center position
    w.r.t. the center of ``bboxes``, and dw, dh are the same as in
    :obj:`DeltaXYWHBBoxCoder`.

    Args:
        eps (float): Minimum value of cx, cy when encoding.
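
    Example:
        A minimal round-trip sketch; the tensor values are arbitrary and
        purely illustrative of the expected call signatures.

        >>> import torch
        >>> coder = YOLOBBoxCoder(eps=1e-6)
        >>> anchors = torch.tensor([[0., 0., 8., 8.]])
        >>> gt_bboxes = torch.tensor([[1., 1., 9., 9.]])
        >>> deltas = coder.encode(anchors, gt_bboxes, stride=8)
        >>> decoded = coder.decode(anchors, deltas, stride=8)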
    """

    def __init__(self, eps: float = 1e-6, **kwargs):
        super().__init__(**kwargs)
        self.eps = eps

    def encode(self, bboxes: Union[Tensor, BaseBoxes],
               gt_bboxes: Union[Tensor, BaseBoxes],
               stride: Union[Tensor, int]) -> Tensor:
        """Get box regression transformation deltas that can be used to
        transform the ``bboxes`` into the ``gt_bboxes``.

        Args:
            bboxes (torch.Tensor or :obj:`BaseBoxes`): Source boxes,
                e.g., anchors.
            gt_bboxes (torch.Tensor or :obj:`BaseBoxes`): Target of the
                transformation, e.g., ground-truth boxes.
            stride (torch.Tensor or int): Stride of ``bboxes``.

        Returns:
            torch.Tensor: Box transformation deltas.
        """
        bboxes = get_box_tensor(bboxes)
        gt_bboxes = get_box_tensor(gt_bboxes)
        assert bboxes.size(0) == gt_bboxes.size(0)
        assert bboxes.size(-1) == gt_bboxes.size(-1) == 4
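        # Centers and sizes of the ground-truth and source boxes, derived
        # from the (x1, y1, x2, y2) corner representation.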
        x_center_gt = (gt_bboxes[..., 0] + gt_bboxes[..., 2]) * 0.5
        y_center_gt = (gt_bboxes[..., 1] + gt_bboxes[..., 3]) * 0.5
        w_gt = gt_bboxes[..., 2] - gt_bboxes[..., 0]
        h_gt = gt_bboxes[..., 3] - gt_bboxes[..., 1]
        x_center = (bboxes[..., 0] + bboxes[..., 2]) * 0.5
        y_center = (bboxes[..., 1] + bboxes[..., 3]) * 0.5
        w = bboxes[..., 2] - bboxes[..., 0]
        h = bboxes[..., 3] - bboxes[..., 1]
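        # dw, dh: log of the gt/source size ratio, with the ratio clamped at
        # eps so the log stays finite.
        # cx, cy: offset of the gt center from the source box center,
        # normalized by the stride and shifted by 0.5, then clamped to
        # [eps, 1 - eps].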
        w_target = torch.log((w_gt / w).clamp(min=self.eps))
        h_target = torch.log((h_gt / h).clamp(min=self.eps))
        x_center_target = ((x_center_gt - x_center) / stride + 0.5).clamp(
            self.eps, 1 - self.eps)
        y_center_target = ((y_center_gt - y_center) / stride + 0.5).clamp(
            self.eps, 1 - self.eps)
        encoded_bboxes = torch.stack(
            [x_center_target, y_center_target, w_target, h_target], dim=-1)
        return encoded_bboxes

    def decode(self, bboxes: Union[Tensor, BaseBoxes], pred_bboxes: Tensor,
               stride: Union[Tensor, int]) -> Union[Tensor, BaseBoxes]:
        """Apply transformation ``pred_bboxes`` to ``bboxes``.

        Args:
            bboxes (torch.Tensor or :obj:`BaseBoxes`): Basic boxes,
                e.g. anchors.
            pred_bboxes (torch.Tensor): Encoded boxes with shape (..., 4),
                i.e. (cx, cy, dw, dh).
            stride (torch.Tensor or int): Strides of ``bboxes``.

        Returns:
            Union[torch.Tensor, :obj:`BaseBoxes`]: Decoded boxes.
        """
        bboxes = get_box_tensor(bboxes)
        assert pred_bboxes.size(-1) == bboxes.size(-1) == 4
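        # Decoded centers: source box centers shifted by (cx - 0.5) * stride.
        # whs holds half-widths/half-heights: half the source box size scaled
        # by exp(dw) and exp(dh); corners are then center -/+ half size.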
        xy_centers = (bboxes[..., :2] + bboxes[..., 2:]) * 0.5 + (
            pred_bboxes[..., :2] - 0.5) * stride
        whs = (bboxes[..., 2:] -
               bboxes[..., :2]) * 0.5 * pred_bboxes[..., 2:].exp()
        decoded_bboxes = torch.stack(
            (xy_centers[..., 0] - whs[..., 0], xy_centers[..., 1] -
             whs[..., 1], xy_centers[..., 0] + whs[..., 0],
             xy_centers[..., 1] + whs[..., 1]),
            dim=-1)

        if self.use_box_type:
            decoded_bboxes = HorizontalBoxes(decoded_bboxes)
        return decoded_bboxes