import logging
from contextlib import contextmanager
from itertools import count
from typing import Dict, List, Optional, Tuple

import numpy as np
import torch
from torch import nn

from detectron2.config import configurable
from detectron2.modeling.backbone import Backbone, build_backbone
from detectron2.modeling.meta_arch import GeneralizedRCNN
from detectron2.modeling.meta_arch.build import META_ARCH_REGISTRY
from detectron2.modeling.postprocessing import detector_postprocess
from detectron2.modeling.roi_heads.fast_rcnn import fast_rcnn_inference_single_image
from detectron2.structures import ImageList, Instances
from detectron2.utils.events import get_event_storage


@META_ARCH_REGISTRY.register()
class VLGeneralizedRCNN(GeneralizedRCNN):
    """
    Generalized R-CNN with vision-language inputs. Any model that contains
    the following three components:

    1. Per-image feature extraction (aka backbone)
    2. Region proposal generation
    3. Per-region feature extraction and prediction

    Unlike the base :class:`GeneralizedRCNN`, the backbone is fed a dict
    assembled by :meth:`get_batch` instead of the raw image tensor.
    """

    def forward(self, batched_inputs: List[Dict[str, torch.Tensor]]):
        """
        Args:
            batched_inputs: a list, batched outputs of :class:`DatasetMapper`.
                Each item in the list contains the inputs for one image.
                For now, each item in the list is a dict that contains:

                * image: Tensor, image in (C, H, W) format.
                * instances (optional): groundtruth :class:`Instances`
                * proposals (optional): :class:`Instances`, precomputed proposals.

                Other information that's included in the original dicts, such as:

                * "height", "width" (int): the output resolution of the model, used in inference.
                  See :meth:`postprocess` for details.

        Returns:
            list[dict]:
                Each dict is the output for one input image.
                The dict contains one key "instances" whose value is an :class:`Instances`.
                The :class:`Instances` object has the following keys:
                "pred_boxes", "pred_classes", "scores", "pred_masks", "pred_keypoints"
        """
        if not self.training:
            return self.inference(batched_inputs)

        images = self.preprocess_image(batched_inputs)
        if "instances" in batched_inputs[0]:
            gt_instances = [x["instances"].to(self.device) for x in batched_inputs]
        else:
            gt_instances = None

        # Assemble the dict the backbone expects (see get_batch) and extract features.
        batch = self.get_batch(batched_inputs, images)
        features = self.backbone(batch)

        if self.proposal_generator is not None:
            proposals, proposal_losses = self.proposal_generator(images, features, gt_instances)
        else:
            assert "proposals" in batched_inputs[0]
            proposals = [x["proposals"].to(self.device) for x in batched_inputs]
            proposal_losses = {}

        _, detector_losses = self.roi_heads(images, features, proposals, gt_instances)
        if self.vis_period > 0:
            storage = get_event_storage()
            if storage.iter % self.vis_period == 0:
                self.visualize_training(batched_inputs, proposals)

        losses = {}
        losses.update(detector_losses)
        losses.update(proposal_losses)
        return losses
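
    # Note: with detectron2's default RPN + StandardROIHeads setup, the dict
    # returned above typically carries "loss_rpn_cls"/"loss_rpn_loc" from the
    # proposal generator and "loss_cls"/"loss_box_reg" from the box head; the
    # exact keys depend on which heads are configured.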

    def inference(
        self,
        batched_inputs: List[Dict[str, torch.Tensor]],
        detected_instances: Optional[List[Instances]] = None,
        do_postprocess: bool = True,
    ):
        """
        Run inference on the given inputs.

        Args:
            batched_inputs (list[dict]): same as in :meth:`forward`
            detected_instances (None or list[Instances]): if not None, it
                contains an `Instances` object per image. The `Instances`
                object contains "pred_boxes" and "pred_classes" which are
                known boxes in the image.
                The inference will then skip the detection of bounding boxes,
                and only predict other per-ROI outputs.
            do_postprocess (bool): whether to apply post-processing on the outputs.

        Returns:
            When do_postprocess=True, same as in :meth:`forward`.
            Otherwise, a list[Instances] containing raw network outputs.
        """
        assert not self.training

        images = self.preprocess_image(batched_inputs)

        # Assemble the dict the backbone expects (see get_batch) and extract features.
        batch = self.get_batch(batched_inputs, images)
        features = self.backbone(batch)

        if detected_instances is None:
            if self.proposal_generator is not None:
                proposals, _ = self.proposal_generator(images, features, None)
            else:
                assert "proposals" in batched_inputs[0]
                proposals = [x["proposals"].to(self.device) for x in batched_inputs]

            results, _ = self.roi_heads(images, features, proposals, None)
        else:
            detected_instances = [x.to(self.device) for x in detected_instances]
            results = self.roi_heads.forward_with_given_boxes(features, detected_instances)

        if do_postprocess:
            assert not torch.jit.is_scripting(), "Scripting is not supported for postprocess."
            return GeneralizedRCNN._postprocess(results, batched_inputs, images.image_sizes)
        return results
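
    # Note: GeneralizedRCNN._postprocess rescales each Instances object to the
    # "height"/"width" requested in the corresponding input dict (via
    # detector_postprocess), so postprocessed boxes are expressed in the
    # original image resolution.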

    def get_batch(self, examples, images):
        """
        Assemble the dict that is fed to the backbone. For inputs without
        "bbox" annotations, the batch is images-only.
        """
        if len(examples) >= 1 and "bbox" not in examples[0]:
            return {"images": images.tensor}

        # Inputs carrying "bbox" annotations are not supported by this batch
        # builder; fail loudly rather than handing the backbone an undefined batch.
        raise NotImplementedError("get_batch does not handle inputs with 'bbox' annotations.")

    def _batch_inference(self, batched_inputs, detected_instances=None):
        """
        Execute inference on a list of inputs, using a fixed chunk size of 2
        instead of the length of the list.

        Inputs & outputs have the same format as :meth:`GeneralizedRCNN.inference`.
        """
        if detected_instances is None:
            detected_instances = [None] * len(batched_inputs)

        outputs = []
        inputs, instances = [], []
        for idx, single_input, instance in zip(count(), batched_inputs, detected_instances):
            inputs.append(single_input)
            instances.append(instance)
            # Flush a chunk once it holds 2 inputs, or when the list is exhausted.
            if len(inputs) == 2 or idx == len(batched_inputs) - 1:
                outputs.extend(
                    self.inference(
                        inputs,
                        instances if instances[0] is not None else None,
                        do_postprocess=True,
                    )
                )
                inputs, instances = [], []
        return outputs
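

if __name__ == "__main__":
    # Minimal smoke-test sketch: assumes a detectron2 config whose backbone
    # accepts the dict produced by get_batch; the config path below is
    # hypothetical.
    from detectron2.config import get_cfg
    from detectron2.modeling import build_model

    cfg = get_cfg()
    cfg.merge_from_file("configs/vl_rcnn.yaml")  # hypothetical config file
    cfg.MODEL.META_ARCHITECTURE = "VLGeneralizedRCNN"

    model = build_model(cfg)  # resolves the name registered via @META_ARCH_REGISTRY.register()
    model.eval()
    with torch.no_grad():
        # One dummy 3-channel image in (C, H, W) format, as forward() expects;
        # "height"/"width" set the output resolution used by postprocessing.
        dummy = [{"image": torch.zeros(3, 480, 640), "height": 480, "width": 640}]
        outputs = model(dummy)
    print(outputs[0]["instances"])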