|
|
|
from typing import Optional, Sequence, Union |
|
|
|
import torch |
|
from torch import Tensor |
|
|
|
from mmdet.registry import TASK_UTILS |
|
from mmdet.structures.bbox import BaseBoxes, HorizontalBoxes, get_box_tensor |
|
from .base_bbox_coder import BaseBBoxCoder |
|
|
|
|
|
@TASK_UTILS.register_module()
class TBLRBBoxCoder(BaseBBoxCoder):
    """Box coder working with (top, bottom, left, right) offsets.

    As done in `FSAF <https://arxiv.org/abs/1903.00621>`_, ground truth
    boxes given as (x1, y1, x2, y2) are turned into distances
    (top, bottom, left, right) measured from the center of a prior box, and
    such distances can be turned back into (x1, y1, x2, y2) boxes.

    Args:
        normalizer (list | float): Factor the encoded offsets are divided
            by (and multiplied with again when decoding). A list must hold
            4 values, one per tblr dimension; a float is applied to all
            dimensions. Default: 4.0
        clip_border (bool, optional): Whether decoded boxes are clipped to
            the image border when ``max_shape`` is passed to
            :meth:`decode`. Defaults to True.
        **kwargs: Forwarded unchanged to ``BaseBBoxCoder.__init__``.
    """

    def __init__(self,
                 normalizer: Union[Sequence[float], float] = 4.0,
                 clip_border: bool = True,
                 **kwargs) -> None:
        super().__init__(**kwargs)
        # Both values are only stored here; they are consumed by
        # ``encode`` / ``decode``.
        self.clip_border = clip_border
        self.normalizer = normalizer

    def encode(self, bboxes: Union[Tensor, BaseBoxes],
               gt_bboxes: Union[Tensor, BaseBoxes]) -> Tensor:
        """Compute the tblr regression targets of ``gt_bboxes``.

        The targets describe ``gt_bboxes`` relative to ``bboxes`` in
        (top, bottom, left, right) order.

        Args:
            bboxes (torch.Tensor or :obj:`BaseBoxes`): Source boxes, e.g.
                object proposals.
            gt_bboxes (torch.Tensor or :obj:`BaseBoxes`): Boxes to encode,
                e.g. ground truth boxes. Must pair up one-to-one with
                ``bboxes``.

        Returns:
            torch.Tensor: Encoded tblr targets.
        """
        # Work on raw tensors regardless of the input box type.
        src = get_box_tensor(bboxes)
        dst = get_box_tensor(gt_bboxes)
        # Same number of boxes, and both in 4-value form.
        assert src.size(0) == dst.size(0)
        assert src.size(-1) == dst.size(-1) == 4
        return bboxes2tblr(src, dst, normalizer=self.normalizer)

    def decode(
        self,
        bboxes: Union[Tensor, BaseBoxes],
        pred_bboxes: Tensor,
        max_shape: Optional[Union[Sequence[int], Tensor,
                                  Sequence[Sequence[int]]]] = None
    ) -> Union[Tensor, BaseBoxes]:
        """Turn tblr predictions back into boxes around ``bboxes``.

        Args:
            bboxes (torch.Tensor or :obj:`BaseBoxes`): Basic (prior) boxes.
                Shape (B, N, 4) or (N, 4)
            pred_bboxes (torch.Tensor): Encoded tblr boxes with shape
                (B, N, 4) or (N, 4)
            max_shape (Sequence[int] or torch.Tensor or Sequence[
                Sequence[int]], optional): Upper bounds for the decoded
                boxes, given as (H, W, C) or (H, W). For ``bboxes`` of
                shape (B, N, 4) it should be a Sequence[Sequence[int]]
                of length B.

        Returns:
            Union[torch.Tensor, :obj:`BaseBoxes`]: Decoded boxes, wrapped
            in :obj:`HorizontalBoxes` when ``self.use_box_type`` is truthy.
        """
        priors = get_box_tensor(bboxes)
        boxes = tblr2bboxes(
            priors,
            pred_bboxes,
            normalizer=self.normalizer,
            max_shape=max_shape,
            clip_border=self.clip_border)

        # ``use_box_type`` is not set in this class; presumably it is
        # provided by ``BaseBBoxCoder``.
        return HorizontalBoxes(boxes) if self.use_box_type else boxes
|
|
|
|
|
def bboxes2tblr(priors: Tensor,
                gts: Tensor,
                normalizer: Union[Sequence[float], float] = 4.0,
                normalize_by_wh: bool = True) -> Tensor:
    """Encode ground truth boxes to tblr coordinate.

    The gt coordinates are first converted to tblr format,
    (top, bottom, left, right), i.e. distances from the prior box center to
    the four gt box sides. If ``normalize_by_wh`` is True the vertical
    distances are divided by the prior height and the horizontal ones by the
    prior width. Finally everything is divided by ``normalizer``.

    Args:
        priors (Tensor): Prior boxes in point form (x1, y1, x2, y2).
            Shape: (num_proposals, 4).
        gts (Tensor): Coords of ground truth for each prior in point form.
            Shape: (num_proposals, 4).
        normalizer (Sequence[float] | float): Normalization parameter of
            encoded boxes. A sequence must have length 4 (one factor per
            tblr dim); a scalar (``float`` or ``int``) is applied to all
            dims. Default: 4.0
        normalize_by_wh (bool): Whether to normalize tblr coordinate by the
            side length (wh) of prior bboxes. Default: True

    Return:
        Tensor: Encoded boxes in (top, bottom, left, right) order.
        Shape: (num_proposals, 4).

    Raises:
        AssertionError: If a sequence ``normalizer`` does not have length 4
            or ``priors`` and ``gts`` hold a different number of boxes.
    """
    # Scalars (including plain ints such as ``4``) are used as-is. Only
    # sequences are turned into a per-dimension tensor; sending an int down
    # that path would create a 0-dim tensor on which ``len()`` fails.
    if not isinstance(normalizer, (int, float)):
        normalizer = torch.tensor(normalizer, device=priors.device)
        assert len(normalizer) == 4, 'Normalizer must have length = 4'
    assert priors.size(0) == gts.size(0)

    # (cx, cy) of every prior box.
    prior_centers = (priors[:, 0:2] + priors[:, 2:4]) / 2
    center_x = prior_centers[:, 0].unsqueeze(1)
    center_y = prior_centers[:, 1].unsqueeze(1)

    # Distances from the prior center to each side of the gt box.
    xmin, ymin, xmax, ymax = gts.split(1, dim=1)
    top = center_y - ymin
    bottom = ymax - center_y
    left = center_x - xmin
    right = xmax - center_x
    loc = torch.cat((top, bottom, left, right), dim=1)

    if normalize_by_wh:
        # Vertical offsets scale with the prior height, horizontal ones with
        # the prior width. ``loc`` is a fresh tensor from ``torch.cat``, so
        # the in-place division does not touch the inputs.
        wh = priors[:, 2:4] - priors[:, 0:2]
        w, h = torch.split(wh, 1, dim=1)
        loc[:, :2] /= h
        loc[:, 2:] /= w

    return loc / normalizer
|
|
|
|
|
def tblr2bboxes(priors: Tensor,
                tblr: Tensor,
                normalizer: Union[Sequence[float], float] = 4.0,
                normalize_by_wh: bool = True,
                max_shape: Optional[Union[Sequence[int], Tensor,
                                          Sequence[Sequence[int]]]] = None,
                clip_border: bool = True) -> Tensor:
    """Decode tblr outputs to prediction boxes.

    The process includes 3 steps: 1) De-normalize tblr coordinates by
    multiplying it with `normalizer`; 2) De-normalize tblr coordinates by the
    prior bbox width and height if `normalize_by_wh` is `True`; 3) Convert
    tblr (top, bottom, left, right) pair relative to the center of priors back
    to (xmin, ymin, xmax, ymax) coordinate. The result is optionally clipped
    to ``max_shape``.

    Args:
        priors (Tensor): Prior boxes in point form (x0, y0, x1, y1)
            Shape: (N, 4) or (B, N, 4).
        tblr (Tensor): Coords of network output in tblr form
            Shape: (N, 4) or (B, N, 4).
        normalizer (Sequence[float] | float): Normalization parameter of
            encoded boxes. A sequence must have length 4 (one factor per
            tblr dim); a scalar (``float`` or ``int``) is the unified
            factor for all dims. Default: 4.0
        normalize_by_wh (bool): Whether the tblr coordinates have been
            normalized by the side length (wh) of prior bboxes.
            Default: True
        max_shape (Sequence[int] or torch.Tensor or Sequence[
            Sequence[int]], optional): Maximum bounds for boxes, specifies
            (H, W, C) or (H, W). If priors shape is (B, N, 4), then
            the max_shape should be a Sequence[Sequence[int]]
            and the length of max_shape should also be B.
        clip_border (bool, optional): Whether clip the objects outside the
            border of the image. Only has an effect when ``max_shape`` is
            given. Defaults to True.

    Return:
        Tensor: Decoded boxes (xmin, ymin, xmax, ymax) with shape (N, 4) or
        (B, N, 4).

    Raises:
        AssertionError: If a sequence ``normalizer`` does not have length
            4, if ``priors`` and ``tblr`` disagree in their leading
            dimensions, or if a per-image ``max_shape`` does not match the
            batch dimension of the boxes.
    """
    # Scalars (including plain ints such as ``4``) are used as-is. Only
    # sequences are turned into a per-dimension tensor; sending an int down
    # that path would create a 0-dim tensor on which ``len()`` fails.
    if not isinstance(normalizer, (int, float)):
        normalizer = torch.tensor(normalizer, device=priors.device)
        assert len(normalizer) == 4, 'Normalizer must have length = 4'
    assert priors.size(0) == tblr.size(0)
    if priors.ndim == 3:
        assert priors.size(1) == tblr.size(1)

    # Step 1: undo the division by ``normalizer``.
    loc_decode = tblr * normalizer
    prior_centers = (priors[..., 0:2] + priors[..., 2:4]) / 2
    if normalize_by_wh:
        # Step 2: (top, bottom) were divided by the prior height and
        # (left, right) by the prior width when encoding.
        wh = priors[..., 2:4] - priors[..., 0:2]
        w, h = torch.split(wh, 1, dim=-1)
        th = h * loc_decode[..., :2]
        tw = w * loc_decode[..., 2:]
        loc_decode = torch.cat([th, tw], dim=-1)

    # Step 3: offsets from the prior center back to corner coordinates.
    top, bottom, left, right = loc_decode.split((1, 1, 1, 1), dim=-1)
    center_x = prior_centers[..., 0].unsqueeze(-1)
    center_y = prior_centers[..., 1].unsqueeze(-1)
    xmin = center_x - left
    xmax = center_x + right
    ymin = center_y - top
    ymax = center_y + bottom

    bboxes = torch.cat((xmin, ymin, xmax, ymax), dim=-1)

    if clip_border and max_shape is not None:
        # Clipping is delegated to a dedicated helper during ONNX export.
        if torch.onnx.is_in_onnx_export():
            # NOTE(review): the other imports of this file use the
            # ``mmdet.registry`` / ``mmdet.structures`` layout; confirm that
            # ``mmdet.core.export`` is still importable in the targeted
            # mmdet version.
            from mmdet.core.export import dynamic_clip_for_onnx
            xmin, ymin, xmax, ymax = dynamic_clip_for_onnx(
                xmin, ymin, xmax, ymax, max_shape)
            bboxes = torch.cat([xmin, ymin, xmax, ymax], dim=-1)
            return bboxes
        if not isinstance(max_shape, torch.Tensor):
            max_shape = priors.new_tensor(max_shape)
        # Keep only (H, W); a trailing channel entry is dropped.
        max_shape = max_shape[..., :2].type_as(priors)
        if max_shape.ndim == 2:
            # One (H, W) per image requires batched boxes of matching size.
            assert bboxes.ndim == 3
            assert max_shape.size(0) == bboxes.size(0)

        min_xy = priors.new_tensor(0)
        # (H, W, H, W) flipped to (W, H, W, H) so the bounds line up with
        # (xmin, ymin, xmax, ymax); the extra dim broadcasts over the boxes.
        max_xy = torch.cat([max_shape, max_shape],
                           dim=-1).flip(-1).unsqueeze(-2)
        bboxes = torch.where(bboxes < min_xy, min_xy, bboxes)
        bboxes = torch.where(bboxes > max_xy, max_xy, bboxes)

    return bboxes
|
|