import os
import os.path as osp
from typing import Optional, Sequence

from mmengine.dist import is_main_process
from mmengine.evaluator import BaseMetric
from mmengine.fileio import dump
from mmengine.logging import MMLogger
from mmengine.structures import InstanceData

from mmdet.registry import METRICS


@METRICS.register_module()
class DumpProposals(BaseMetric):
"""Dump proposals pseudo metric. |
|
|
|
Args: |
|
output_dir (str): The root directory for ``proposals_file``. |
|
Defaults to ''. |
|
proposals_file (str): Proposals file path. Defaults to 'proposals.pkl'. |
|
num_max_proposals (int, optional): Maximum number of proposals to dump. |
|
If not specified, all proposals will be dumped. |
|
file_client_args (dict, optional): Arguments to instantiate the |
|
corresponding backend in mmdet <= 3.0.0rc6. Defaults to None. |
|
backend_args (dict, optional): Arguments to instantiate the |
|
corresponding backend. Defaults to None. |
|
collect_device (str): Device name used for collecting results from |
|
different ranks during distributed training. Must be 'cpu' or |
|
'gpu'. Defaults to 'cpu'. |
|
prefix (str, optional): The prefix that will be added in the metric |
|
names to disambiguate homonymous metrics of different evaluators. |
|
If prefix is not provided in the argument, self.default_prefix |
|
will be used instead. Defaults to None. |
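
    Examples:
        A minimal, illustrative config sketch for ``test_evaluator``; the
        output directory and file name below are placeholders rather than
        project defaults::

            test_evaluator = dict(
                type='DumpProposals',
                output_dir='work_dirs/proposals',
                proposals_file='rpn_proposals.pkl',
                num_max_proposals=None)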
    """

    default_prefix: Optional[str] = 'dump_proposals'

    def __init__(self,
                 output_dir: str = '',
                 proposals_file: str = 'proposals.pkl',
                 num_max_proposals: Optional[int] = None,
                 file_client_args: Optional[dict] = None,
                 backend_args: Optional[dict] = None,
                 collect_device: str = 'cpu',
                 prefix: Optional[str] = None) -> None:
        super().__init__(collect_device=collect_device, prefix=prefix)
        self.num_max_proposals = num_max_proposals

        self.backend_args = backend_args
        if file_client_args is not None:
            raise RuntimeError(
                'The `file_client_args` is deprecated, please use '
                '`backend_args` instead. Please refer to '
                'https://github.com/open-mmlab/mmdetection/blob/main/configs/_base_/datasets/coco_detection.py'
            )
        self.output_dir = output_dir
        assert proposals_file.endswith(('.pkl', '.pickle')), \
            'The output file must be a pkl file.'

        self.proposals_file = osp.join(self.output_dir, proposals_file)
        # ``os.makedirs('')`` raises an error, so only create the output
        # directory when one is actually given.
        if is_main_process() and self.output_dir:
            os.makedirs(self.output_dir, exist_ok=True)

    def process(self, data_batch: Sequence[dict],
                data_samples: Sequence[dict]) -> None:
        """Process one batch of data samples and predictions. The processed
        results should be stored in ``self.results``, which will be used to
        compute the metrics when all batches have been processed.

        Args:
            data_batch (dict): A batch of data from the dataloader.
            data_samples (Sequence[dict]): A batch of data samples that
                contain annotations and predictions.
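
        Examples:
            A minimal, hand-built ``data_samples`` entry that this method
            accepts; the image path and tensor values are illustrative::

                import torch

                data_samples = [
                    dict(
                        img_path='data/coco/val2017/000000000001.jpg',
                        pred_instances=dict(
                            bboxes=torch.rand(100, 4),
                            scores=torch.rand(100)))
                ]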
        """
        for data_sample in data_samples:
            pred = data_sample['pred_instances']

            # Sort predicted boxes by score, highest first.
            ranked_scores, rank_inds = pred['scores'].sort(descending=True)
            ranked_bboxes = pred['bboxes'][rank_inds, :]

            ranked_bboxes = ranked_bboxes.cpu().numpy()
            ranked_scores = ranked_scores.cpu().numpy()

            pred_instance = InstanceData()
            pred_instance.bboxes = ranked_bboxes
            pred_instance.scores = ranked_scores
            # Keep only the top ``num_max_proposals`` proposals if a limit
            # is set.
            if self.num_max_proposals is not None:
                pred_instance = pred_instance[:self.num_max_proposals]

            img_path = data_sample['img_path']
            # Key each entry by ``<parent_dir>/<file_name>`` so that dumped
            # proposals can be looked up by relative image path.
            file_name = osp.join(
                osp.split(osp.split(img_path)[0])[-1],
                osp.split(img_path)[-1])
            result = {file_name: pred_instance}
            self.results.append(result)

    def compute_metrics(self, results: list) -> dict:
        """Dump the processed results.

        Args:
            results (list): The processed results of each batch.

        Returns:
            dict: An empty dict.
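
        Examples:
            An illustrative sketch of reading the dumped file back; the
            path and key below are placeholders::

                from mmengine.fileio import load

                proposals = load('work_dirs/proposals/proposals.pkl')
                inst = proposals['val2017/000000000001.jpg']
                print(inst.bboxes.shape, inst.scores.shape)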
        """
        logger: MMLogger = MMLogger.get_current_instance()
        # Merge the per-image result dicts collected from all ranks into a
        # single mapping and dump it to ``self.proposals_file``.
        dump_results = {}
        for result in results:
            dump_results.update(result)
        dump(
            dump_results,
            file=self.proposals_file,
            backend_args=self.backend_args)
        logger.info(f'Results are saved at {self.proposals_file}')
        return {}