# Copyright (c) Facebook, Inc. and its affiliates.
import logging
from contextlib import contextmanager
from itertools import count
from typing import Dict, List, Optional, Tuple

import numpy as np
import torch
from torch import nn

from detectron2.config import configurable
from detectron2.structures import ImageList, Instances
from detectron2.utils.events import get_event_storage
from detectron2.modeling.backbone import Backbone, build_backbone
from detectron2.modeling.meta_arch.build import META_ARCH_REGISTRY
from detectron2.modeling.meta_arch import GeneralizedRCNN
from detectron2.modeling.postprocessing import detector_postprocess
from detectron2.modeling.roi_heads.fast_rcnn import fast_rcnn_inference_single_image

# Registered so the class can be selected via cfg.MODEL.META_ARCHITECTURE
# (the META_ARCH_REGISTRY import is otherwise unused in this snippet).
@META_ARCH_REGISTRY.register()
class VLGeneralizedRCNN(GeneralizedRCNN):
    """
    Generalized R-CNN. Any model that contains the following three components:
    1. Per-image feature extraction (aka backbone)
    2. Region proposal generation
    3. Per-region feature extraction and prediction
    """

    def forward(self, batched_inputs: List[Dict[str, torch.Tensor]]):
        """
        Args:
            batched_inputs: a list, batched outputs of :class:`DatasetMapper`.
                Each item in the list contains the inputs for one image.
                For now, each item in the list is a dict that contains:

                * image: Tensor, image in (C, H, W) format.
                * instances (optional): groundtruth :class:`Instances`
                * proposals (optional): :class:`Instances`, precomputed proposals.

                Other information that's included in the original dicts, such as:

                * "height", "width" (int): the output resolution of the model, used in inference.
                  See :meth:`postprocess` for details.

        Returns:
            list[dict]:
                Each dict is the output for one input image.
                The dict contains one key "instances" whose value is a :class:`Instances`.
                The :class:`Instances` object has the following keys:
                "pred_boxes", "pred_classes", "scores", "pred_masks", "pred_keypoints"
        """
        if not self.training:
            return self.inference(batched_inputs)

        images = self.preprocess_image(batched_inputs)
        if "instances" in batched_inputs[0]:
            gt_instances = [x["instances"].to(self.device) for x in batched_inputs]
        else:
            gt_instances = None

        # Unlike the parent class, the backbone consumes a dict assembled by
        # `get_batch` (image tensor plus any extra modalities) rather than the
        # raw image tensor.
        backbone_input = self.get_batch(batched_inputs, images)
        features = self.backbone(backbone_input)

        if self.proposal_generator is not None:
            proposals, proposal_losses = self.proposal_generator(images, features, gt_instances)
        else:
            assert "proposals" in batched_inputs[0]
            proposals = [x["proposals"].to(self.device) for x in batched_inputs]
            proposal_losses = {}

        _, detector_losses = self.roi_heads(images, features, proposals, gt_instances)
        if self.vis_period > 0:
            storage = get_event_storage()
            if storage.iter % self.vis_period == 0:
                self.visualize_training(batched_inputs, proposals)

        losses = {}
        losses.update(detector_losses)
        losses.update(proposal_losses)
        return losses
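
    # A minimal sketch (not from the original file) of one training-time entry
    # in `batched_inputs`, in the format :class:`DatasetMapper` produces; the
    # values are illustrative assumptions:
    #
    #   {
    #       "image": torch.zeros(3, 800, 1216),    # (C, H, W) tensor
    #       "instances": Instances((800, 1216)),   # carrying gt_boxes, gt_classes
    #       "height": 800,                         # desired output resolution
    #       "width": 1216,
    #   }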

    def inference(
        self,
        batched_inputs: List[Dict[str, torch.Tensor]],
        detected_instances: Optional[List[Instances]] = None,
        do_postprocess: bool = True,
    ):
        """
        Run inference on the given inputs.

        Args:
            batched_inputs (list[dict]): same as in :meth:`forward`
            detected_instances (None or list[Instances]): if not None, it
                contains an `Instances` object per image. The `Instances`
                object contains "pred_boxes" and "pred_classes" which are
                known boxes in the image.
                The inference will then skip the detection of bounding boxes,
                and only predict other per-ROI outputs.
            do_postprocess (bool): whether to apply post-processing on the outputs.

        Returns:
            When do_postprocess=True, same as in :meth:`forward`.
            Otherwise, a list[Instances] containing raw network outputs.
        """
        assert not self.training

        images = self.preprocess_image(batched_inputs)
        # As in `forward`, route the inputs through `get_batch` so the backbone
        # receives the full multimodal batch, not just the image tensor.
        backbone_input = self.get_batch(batched_inputs, images)
        features = self.backbone(backbone_input)

        if detected_instances is None:
            if self.proposal_generator is not None:
                proposals, _ = self.proposal_generator(images, features, None)
            else:
                assert "proposals" in batched_inputs[0]
                proposals = [x["proposals"].to(self.device) for x in batched_inputs]
            results, _ = self.roi_heads(images, features, proposals, None)
        else:
            detected_instances = [x.to(self.device) for x in detected_instances]
            results = self.roi_heads.forward_with_given_boxes(features, detected_instances)

        if do_postprocess:
            assert not torch.jit.is_scripting(), "Scripting is not supported for postprocess."
            return GeneralizedRCNN._postprocess(results, batched_inputs, images.image_sizes)
        else:
            return results
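
    # A minimal sketch (assumption, not original code) of the second inference
    # mode, with `model` a built VLGeneralizedRCNN, `inp` an input dict as in
    # `forward`, and `boxes`/`classes` precomputed tensors:
    #
    #   from detectron2.structures import Boxes
    #   known = Instances((480, 640))
    #   known.pred_boxes = Boxes(boxes)
    #   known.pred_classes = classes
    #   outputs = model.inference([inp], detected_instances=[known],
    #                             do_postprocess=False)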

    def get_batch(self, examples, images):
        """
        Assemble the backbone input from the raw dataset dicts and the
        preprocessed :class:`ImageList`.
        """
        if len(examples) >= 1 and "bbox" not in examples[0]:  # image_only
            return {"images": images.tensor}
        # The branch that assembled box/text inputs is missing from this
        # snippet; the original `return input` referenced an undefined name
        # (shadowing the `input` builtin). Fail loudly instead.
        raise NotImplementedError("get_batch only handles image-only batches in this snippet")

    def _batch_inference(self, batched_inputs, detected_instances=None):
        """
        Execute inference on a list of inputs, using a fixed batch size of 2
        instead of the length of the list.

        Inputs & outputs have the same format as :meth:`GeneralizedRCNN.inference`.
        """
        if detected_instances is None:
            detected_instances = [None] * len(batched_inputs)

        outputs = []
        inputs, instances = [], []
        for idx, single_input, instance in zip(count(), batched_inputs, detected_instances):
            inputs.append(single_input)
            instances.append(instance)
            # Flush a mini-batch once it holds 2 items or the list is exhausted.
            if len(inputs) == 2 or idx == len(batched_inputs) - 1:
                outputs.extend(
                    self.inference(
                        inputs,
                        instances if instances[0] is not None else None,
                        do_postprocess=True,
                    )
                )
                inputs, instances = [], []
        return outputs
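

# -----------------------------------------------------------------------------
# Usage sketch (not part of the original file): one hedged way to build and
# query a model with this meta-architecture. The config path and weight file
# below are placeholders, and the config is assumed to name
# "VLGeneralizedRCNN" as META_ARCHITECTURE and to define a compatible
# backbone, proposal generator, and ROI heads.
# -----------------------------------------------------------------------------
if __name__ == "__main__":
    from detectron2.checkpoint import DetectionCheckpointer
    from detectron2.config import get_cfg
    from detectron2.modeling import build_model

    cfg = get_cfg()
    cfg.merge_from_file("configs/vl_rcnn.yaml")  # hypothetical config file
    cfg.MODEL.META_ARCHITECTURE = "VLGeneralizedRCNN"
    model = build_model(cfg)
    DetectionCheckpointer(model).load("model_final.pth")  # hypothetical weights
    model.eval()

    # One image-only input dict, in the format DatasetMapper produces.
    sample = {"image": torch.zeros(3, 480, 640), "height": 480, "width": 640}
    with torch.no_grad():
        outputs = model([sample])  # eval mode routes forward() to inference()
    print(outputs[0]["instances"])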