|
|
|
from typing import List, Tuple |
|
|
|
import torch |
|
from mmcv.ops import batched_nms |
|
from mmengine.model import BaseTTAModel |
|
from mmengine.registry import MODELS |
|
from mmengine.structures import InstanceData |
|
from torch import Tensor |
|
|
|
from mmdet.structures import DetDataSample |
|
from mmdet.structures.bbox import bbox_flip |
|
|
|
|
|
@MODELS.register_module()
class DetTTAModel(BaseTTAModel):
    """Merge detection results predicted on augmented views of an image.

    Only bboxes and their corresponding scores under flipping and
    multi-scale resizing can be processed now.

    NOTE(review): ``tta_cfg`` is read through attribute access
    (``tta_cfg.nms`` and ``tta_cfg.max_per_img``), so it must be an object
    supporting that -- presumably the config dict produced by the registry
    build rather than a plain ``dict``; confirm against the caller.

    Examples:
        >>> tta_model = dict(
        >>>     type='DetTTAModel',
        >>>     tta_cfg=dict(nms=dict(
        >>>                     type='nms',
        >>>                     iou_threshold=0.5),
        >>>                     max_per_img=100))
        >>>
        >>> tta_pipeline = [
        >>>     dict(type='LoadImageFromFile',
        >>>          backend_args=None),
        >>>     dict(
        >>>         type='TestTimeAug',
        >>>         transforms=[[
        >>>             dict(type='Resize',
        >>>                  scale=(1333, 800),
        >>>                  keep_ratio=True),
        >>>         ], [
        >>>             dict(type='RandomFlip', prob=1.),
        >>>             dict(type='RandomFlip', prob=0.)
        >>>         ], [
        >>>             dict(
        >>>                 type='PackDetInputs',
        >>>                 meta_keys=('img_id', 'img_path', 'ori_shape',
        >>>                         'img_shape', 'scale_factor', 'flip',
        >>>                         'flip_direction'))
        >>>         ]])]
    """

    def __init__(self, tta_cfg=None, **kwargs):
        """Store the TTA merging config.

        Args:
            tta_cfg: Merging config. ``_merge_single_sample`` reads its
                ``nms`` and ``max_per_img`` attributes whenever there is at
                least one box to merge, so leaving it as ``None`` only works
                for images without any prediction.
            **kwargs: Forwarded unchanged to ``BaseTTAModel.__init__``.
        """
        super().__init__(**kwargs)
        self.tta_cfg = tta_cfg

    def merge_aug_bboxes(self, aug_bboxes: List[Tensor],
                         aug_scores: List[Tensor],
                         img_metas: List[dict]) -> Tuple[Tensor, Tensor]:
        """Merge augmented detection bboxes and scores.

        Boxes of every view whose meta info has a truthy ``'flip'`` are
        passed through ``bbox_flip`` with that view's ``'ori_shape'`` and
        ``'flip_direction'``; afterwards the boxes (and scores) of all views
        are concatenated along dim 0.

        Args:
            aug_bboxes (list[Tensor]): Boxes of each view. Presumably each
                has shape (n, 4) -- TODO confirm against the caller.
            aug_scores (list[Tensor] or None): Scores of each view, or
                ``None`` to merge boxes only.
            img_metas (list[dict]): Meta info of each view, aligned with
                ``aug_bboxes``. Every entry must provide the keys
                ``'ori_shape'``, ``'flip'`` and ``'flip_direction'``.

        Returns:
            tuple[Tensor]: ``bboxes`` with shape (n,4), where
            4 represent (tl_x, tl_y, br_x, br_y)
            and ``scores`` with shape (n,). When ``aug_scores`` is ``None``
            only the ``bboxes`` tensor is returned (not a tuple).
        """
        recovered_bboxes = []
        for view_bboxes, img_info in zip(aug_bboxes, img_metas):
            # All three keys are looked up for every view, flipped or not,
            # so a missing key fails loudly instead of being skipped.
            ori_shape = img_info['ori_shape']
            flip = img_info['flip']
            flip_direction = img_info['flip_direction']
            if flip:
                view_bboxes = bbox_flip(
                    bboxes=view_bboxes,
                    img_shape=ori_shape,
                    direction=flip_direction)
            recovered_bboxes.append(view_bboxes)

        bboxes = torch.cat(recovered_bboxes, dim=0)
        if aug_scores is None:
            return bboxes
        scores = torch.cat(aug_scores, dim=0)
        return bboxes, scores

    def merge_preds(self, data_samples_list: List[List[DetDataSample]]):
        """Merge batch predictions of enhanced data.

        Args:
            data_samples_list (List[List[DetDataSample]]): List of predictions
                of all enhanced data. The outer list indicates images, and the
                inner list corresponds to the different views of one image.
                Each element of the inner list is a ``DetDataSample``.

        Returns:
            List[DetDataSample]: Merged batch prediction, one sample per
            image, in the same order as ``data_samples_list``.
        """
        return [
            self._merge_single_sample(data_samples)
            for data_samples in data_samples_list
        ]

    def _merge_single_sample(
            self, data_samples: List[DetDataSample]) -> DetDataSample:
        """Merge predictions which come from the different views of one image
        to one prediction.

        Args:
            data_samples (List[DetDataSample]): List of predictions
                of enhanced data which come from one image.

        Returns:
            DetDataSample: Merged prediction. This is ``data_samples[0]``
            itself: its ``pred_instances`` is replaced in place by the merged
            result, or left untouched when no view predicted any box.
        """
        # Only the first view is inspected for masks.
        assert data_samples[0].pred_instances.get('masks', None) is None, \
            'TTA of instance segmentation does not support now.'

        aug_bboxes = []
        aug_scores = []
        aug_labels = []
        img_metas = []
        for data_sample in data_samples:
            aug_bboxes.append(data_sample.pred_instances.bboxes)
            aug_scores.append(data_sample.pred_instances.scores)
            aug_labels.append(data_sample.pred_instances.labels)
            img_metas.append(data_sample.metainfo)

        merged_bboxes, merged_scores = self.merge_aug_bboxes(
            aug_bboxes, aug_scores, img_metas)
        merged_labels = torch.cat(aug_labels, dim=0)

        # Nothing to suppress: hand back the first view unchanged. This exit
        # happens before ``self.tta_cfg`` is read.
        if merged_bboxes.numel() == 0:
            return data_samples[0]

        det_bboxes, keep_idxs = batched_nms(merged_bboxes, merged_scores,
                                            merged_labels, self.tta_cfg.nms)

        # Truncate boxes and labels with the same limit so they stay aligned.
        max_per_img = self.tta_cfg.max_per_img
        det_bboxes = det_bboxes[:max_per_img]
        det_labels = merged_labels[keep_idxs][:max_per_img]

        results = InstanceData()
        # The last column of the NMS output is used as the score and the
        # leading columns as the box coordinates.
        _det_bboxes = det_bboxes.clone()
        results.bboxes = _det_bboxes[:, :-1]
        results.scores = _det_bboxes[:, -1]
        results.labels = det_labels

        # The first view's sample is reused as the container of the result.
        det_results = data_samples[0]
        det_results.pred_instances = results
        return det_results
|
|