|
|
|
from typing import Dict, List, Optional, Tuple, Union |
|
|
|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
from mmcv.cnn import Conv2d |
|
from mmengine.model import caffe2_xavier_init |
|
from mmengine.structures import InstanceData, PixelData |
|
from torch import Tensor |
|
|
|
from mmdet.models.layers.pixel_decoder import PixelDecoder |
|
from mmdet.registry import MODELS, TASK_UTILS |
|
from mmdet.structures import SampleList |
|
from mmdet.utils import (ConfigType, InstanceList, OptConfigType, |
|
OptMultiConfig, reduce_mean) |
|
from ..layers import DetrTransformerDecoder, SinePositionalEncoding |
|
from ..utils import multi_apply, preprocess_panoptic_gt |
|
from .anchor_free_head import AnchorFreeHead |
|
|
|
|
|
@MODELS.register_module() |
|
class MaskFormerHead(AnchorFreeHead): |
|
"""Implements the MaskFormer head. |
|
|
|
See `Per-Pixel Classification is Not All You Need for Semantic |
|
Segmentation <https://arxiv.org/pdf/2107.06278>`_ for details. |
|
|
|
Args: |
|
        in_channels (list[int]): Number of channels in each input feature
            map.
        feat_channels (int): Number of channels for features.
        out_channels (int): Number of channels for outputs.
        num_things_classes (int): Number of things classes. Defaults to 80.
        num_stuff_classes (int): Number of stuff classes. Defaults to 53.
        num_queries (int): Number of queries in the Transformer decoder.
            Defaults to 100.
|
pixel_decoder (:obj:`ConfigDict` or dict): Config for pixel |
|
decoder. |
|
        enforce_decoder_input_project (bool): Whether to add a layer that
            projects the embed_dim of the transformer encoder in the pixel
            decoder to the embed_dim of the transformer decoder. Defaults
            to False.
|
transformer_decoder (:obj:`ConfigDict` or dict): Config for |
|
transformer decoder. |
|
positional_encoding (:obj:`ConfigDict` or dict): Config for |
|
transformer decoder position encoding. |
|
loss_cls (:obj:`ConfigDict` or dict): Config of the classification |
|
loss. Defaults to `CrossEntropyLoss`. |
|
loss_mask (:obj:`ConfigDict` or dict): Config of the mask loss. |
|
Defaults to `FocalLoss`. |
|
loss_dice (:obj:`ConfigDict` or dict): Config of the dice loss. |
|
Defaults to `DiceLoss`. |
|
train_cfg (:obj:`ConfigDict` or dict, optional): Training config of |
|
MaskFormer head. |
|
test_cfg (:obj:`ConfigDict` or dict, optional): Testing config of |
|
MaskFormer head. |
|
init_cfg (:obj:`ConfigDict` or dict or list[:obj:`ConfigDict` or \ |
|
dict], optional): Initialization config dict. Defaults to None. |
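
    Example:
        A minimal, illustrative construction sketch. The ``pixel_decoder``
        and ``transformer_decoder`` configs below are placeholders (marked
        with ``...``) and must match components registered in your
        codebase::

            >>> head = MaskFormerHead(
            ...     in_channels=[256, 512, 1024, 2048],
            ...     feat_channels=256,
            ...     out_channels=256,
            ...     num_queries=100,
            ...     pixel_decoder=dict(type='PixelDecoder', ...),
            ...     transformer_decoder=dict(...))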
|
""" |
|
|
|
def __init__(self, |
|
in_channels: List[int], |
|
feat_channels: int, |
|
out_channels: int, |
|
num_things_classes: int = 80, |
|
num_stuff_classes: int = 53, |
|
num_queries: int = 100, |
|
pixel_decoder: ConfigType = ..., |
|
enforce_decoder_input_project: bool = False, |
|
transformer_decoder: ConfigType = ..., |
|
positional_encoding: ConfigType = dict( |
|
num_feats=128, normalize=True), |
|
loss_cls: ConfigType = dict( |
|
type='CrossEntropyLoss', |
|
use_sigmoid=False, |
|
loss_weight=1.0, |
|
class_weight=[1.0] * 133 + [0.1]), |
|
loss_mask: ConfigType = dict( |
|
type='FocalLoss', |
|
use_sigmoid=True, |
|
gamma=2.0, |
|
alpha=0.25, |
|
loss_weight=20.0), |
|
loss_dice: ConfigType = dict( |
|
type='DiceLoss', |
|
use_sigmoid=True, |
|
activate=True, |
|
naive_dice=True, |
|
loss_weight=1.0), |
|
train_cfg: OptConfigType = None, |
|
test_cfg: OptConfigType = None, |
|
init_cfg: OptMultiConfig = None, |
|
**kwargs) -> None: |
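        # Intentionally bypass AnchorFreeHead.__init__() below: this head
        # shares none of its conv/bbox branches, so initialization proceeds
        # from the next class in the MRO instead.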
|
super(AnchorFreeHead, self).__init__(init_cfg=init_cfg) |
|
self.num_things_classes = num_things_classes |
|
self.num_stuff_classes = num_stuff_classes |
|
self.num_classes = self.num_things_classes + self.num_stuff_classes |
|
self.num_queries = num_queries |
|
|
|
pixel_decoder.update( |
|
in_channels=in_channels, |
|
feat_channels=feat_channels, |
|
out_channels=out_channels) |
|
self.pixel_decoder = MODELS.build(pixel_decoder) |
|
self.transformer_decoder = DetrTransformerDecoder( |
|
**transformer_decoder) |
|
self.decoder_embed_dims = self.transformer_decoder.embed_dims |
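        # Only the plain PixelDecoder (checked by exact type, so subclasses
        # are excluded) may need a 1x1 conv to project its output to the
        # transformer decoder's embedding dims.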
|
if type(self.pixel_decoder) == PixelDecoder and ( |
|
self.decoder_embed_dims != in_channels[-1] |
|
or enforce_decoder_input_project): |
|
self.decoder_input_proj = Conv2d( |
|
in_channels[-1], self.decoder_embed_dims, kernel_size=1) |
|
else: |
|
self.decoder_input_proj = nn.Identity() |
|
self.decoder_pe = SinePositionalEncoding(**positional_encoding) |
|
self.query_embed = nn.Embedding(self.num_queries, out_channels) |
|
|
|
self.cls_embed = nn.Linear(feat_channels, self.num_classes + 1) |
|
self.mask_embed = nn.Sequential( |
|
nn.Linear(feat_channels, feat_channels), nn.ReLU(inplace=True), |
|
nn.Linear(feat_channels, feat_channels), nn.ReLU(inplace=True), |
|
nn.Linear(feat_channels, out_channels)) |
|
|
|
self.test_cfg = test_cfg |
|
self.train_cfg = train_cfg |
|
if train_cfg: |
|
self.assigner = TASK_UTILS.build(train_cfg['assigner']) |
|
self.sampler = TASK_UTILS.build( |
|
train_cfg['sampler'], default_args=dict(context=self)) |
|
|
|
        self.class_weight = loss_cls['class_weight']
|
self.loss_cls = MODELS.build(loss_cls) |
|
self.loss_mask = MODELS.build(loss_mask) |
|
self.loss_dice = MODELS.build(loss_dice) |
|
|
|
def init_weights(self) -> None: |
|
if isinstance(self.decoder_input_proj, Conv2d): |
|
caffe2_xavier_init(self.decoder_input_proj, bias=0) |
|
|
|
self.pixel_decoder.init_weights() |
|
|
|
for p in self.transformer_decoder.parameters(): |
|
if p.dim() > 1: |
|
nn.init.xavier_uniform_(p) |
|
|
|
def preprocess_gt( |
|
self, batch_gt_instances: InstanceList, |
|
batch_gt_semantic_segs: List[Optional[PixelData]]) -> InstanceList: |
|
"""Preprocess the ground truth for all images. |
|
|
|
Args: |
|
            batch_gt_instances (list[:obj:`InstanceData`]): Batch of
                gt_instance. It usually includes ``labels``, the ground
                truth labels of each bbox with shape (num_gts, ), and
                ``masks``, the ground truth masks of the instances in an
                image with shape (num_gts, h, w).
            batch_gt_semantic_segs (list[Optional[PixelData]]): Ground truth
                of semantic segmentation, each with the shape (1, h, w).
                [0, num_things_classes - 1] means things,
                [num_things_classes, num_classes - 1] means stuff, and 255
                means VOID. It is None when training instance segmentation.
|
|
|
Returns: |
|
list[obj:`InstanceData`]: each contains the following keys |
|
|
|
                - labels (Tensor): Ground truth class indices\
                    for an image, with shape (n, ), where n is the sum of\
                    the number of stuff types and the number of instances\
                    in the image.
                - masks (Tensor): Ground truth masks for an\
                    image, with shape (n, h, w).
|
""" |
|
num_things_list = [self.num_things_classes] * len(batch_gt_instances) |
|
num_stuff_list = [self.num_stuff_classes] * len(batch_gt_instances) |
|
gt_labels_list = [ |
|
gt_instances['labels'] for gt_instances in batch_gt_instances |
|
] |
|
gt_masks_list = [ |
|
gt_instances['masks'] for gt_instances in batch_gt_instances |
|
] |
|
gt_semantic_segs = [ |
|
None if gt_semantic_seg is None else gt_semantic_seg.sem_seg |
|
for gt_semantic_seg in batch_gt_semantic_segs |
|
] |
|
targets = multi_apply(preprocess_panoptic_gt, gt_labels_list, |
|
gt_masks_list, gt_semantic_segs, num_things_list, |
|
num_stuff_list) |
|
labels, masks = targets |
|
batch_gt_instances = [ |
|
InstanceData(labels=label, masks=mask) |
|
for label, mask in zip(labels, masks) |
|
] |
|
return batch_gt_instances |
|
|
|
def get_targets( |
|
self, |
|
cls_scores_list: List[Tensor], |
|
mask_preds_list: List[Tensor], |
|
batch_gt_instances: InstanceList, |
|
batch_img_metas: List[dict], |
|
return_sampling_results: bool = False |
|
) -> Tuple[List[Union[Tensor, int]]]: |
|
"""Compute classification and mask targets for all images for a decoder |
|
layer. |
|
|
|
Args: |
|
            cls_scores_list (list[Tensor]): Classification score logits from
                a single decoder layer for all images. Each with shape
                (num_queries, cls_out_channels).
|
mask_preds_list (list[Tensor]): Mask logits from a single decoder |
|
layer for all images. Each with shape (num_queries, h, w). |
|
batch_gt_instances (list[obj:`InstanceData`]): each contains |
|
``labels`` and ``masks``. |
|
batch_img_metas (list[dict]): List of image meta information. |
|
return_sampling_results (bool): Whether to return the sampling |
|
results. Defaults to False. |
|
|
|
Returns: |
|
tuple: a tuple containing the following targets. |
|
|
|
- labels_list (list[Tensor]): Labels of all images.\ |
|
Each with shape (num_queries, ). |
|
- label_weights_list (list[Tensor]): Label weights\ |
|
of all images. Each with shape (num_queries, ). |
|
- mask_targets_list (list[Tensor]): Mask targets of\ |
|
all images. Each with shape (num_queries, h, w). |
|
- mask_weights_list (list[Tensor]): Mask weights of\ |
|
all images. Each with shape (num_queries, ). |
|
- avg_factor (int): Average factor that is used to average\ |
|
the loss. When using sampling method, avg_factor is |
|
usually the sum of positive and negative priors. When |
|
using `MaskPseudoSampler`, `avg_factor` is usually equal |
|
to the number of positive priors. |
|
|
|
            additional_returns: This function enables user-defined returns
                from `self._get_targets_single`. These returns are currently
                refined to per-feature-map properties (i.e. having an HxW
                dimension). The results will be concatenated at the end.
|
""" |
|
results = multi_apply(self._get_targets_single, cls_scores_list, |
|
mask_preds_list, batch_gt_instances, |
|
batch_img_metas) |
|
(labels_list, label_weights_list, mask_targets_list, mask_weights_list, |
|
pos_inds_list, neg_inds_list, sampling_results_list) = results[:7] |
|
rest_results = list(results[7:]) |
|
|
|
        avg_factor = sum([
            sampling_result.avg_factor
            for sampling_result in sampling_results_list
        ])
|
|
|
res = (labels_list, label_weights_list, mask_targets_list, |
|
mask_weights_list, avg_factor) |
|
        if return_sampling_results:
            res = res + (sampling_results_list, )
|
|
|
return res + tuple(rest_results) |
|
|
|
def _get_targets_single(self, cls_score: Tensor, mask_pred: Tensor, |
|
gt_instances: InstanceData, |
|
img_meta: dict) -> Tuple[Tensor]: |
|
"""Compute classification and mask targets for one image. |
|
|
|
Args: |
|
            cls_score (Tensor): Classification score logits from a single
                decoder layer for one image. Shape (num_queries,
                cls_out_channels).
|
mask_pred (Tensor): Mask logits for a single decoder layer for one |
|
image. Shape (num_queries, h, w). |
|
gt_instances (:obj:`InstanceData`): It contains ``labels`` and |
|
``masks``. |
|
            img_meta (dict): Image information.
|
|
|
Returns: |
|
tuple: a tuple containing the following for one image. |
|
|
|
- labels (Tensor): Labels of each image. |
|
shape (num_queries, ). |
|
- label_weights (Tensor): Label weights of each image. |
|
shape (num_queries, ). |
|
- mask_targets (Tensor): Mask targets of each image. |
|
shape (num_queries, h, w). |
|
- mask_weights (Tensor): Mask weights of each image. |
|
shape (num_queries, ). |
|
- pos_inds (Tensor): Sampled positive indices for each image. |
|
- neg_inds (Tensor): Sampled negative indices for each image. |
|
- sampling_result (:obj:`SamplingResult`): Sampling results. |
|
""" |
|
gt_masks = gt_instances.masks |
|
gt_labels = gt_instances.labels |
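        # Downsample the GT masks to the mask logits' spatial size so that
        # assignment compares predictions and targets at the same resolution.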
|
|
|
target_shape = mask_pred.shape[-2:] |
|
if gt_masks.shape[0] > 0: |
|
gt_masks_downsampled = F.interpolate( |
|
gt_masks.unsqueeze(1).float(), target_shape, |
|
mode='nearest').squeeze(1).long() |
|
else: |
|
gt_masks_downsampled = gt_masks |
|
|
|
pred_instances = InstanceData(scores=cls_score, masks=mask_pred) |
|
downsampled_gt_instances = InstanceData( |
|
labels=gt_labels, masks=gt_masks_downsampled) |
|
|
|
assign_result = self.assigner.assign( |
|
pred_instances=pred_instances, |
|
gt_instances=downsampled_gt_instances, |
|
img_meta=img_meta) |
|
sampling_result = self.sampler.sample( |
|
assign_result=assign_result, |
|
pred_instances=pred_instances, |
|
gt_instances=gt_instances) |
|
pos_inds = sampling_result.pos_inds |
|
neg_inds = sampling_result.neg_inds |
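        # Label targets: every query defaults to the background class
        # (index self.num_classes); matched queries then receive the label
        # of their assigned GT instance.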
|
|
|
|
|
labels = gt_labels.new_full((self.num_queries, ), |
|
self.num_classes, |
|
dtype=torch.long) |
|
labels[pos_inds] = gt_labels[sampling_result.pos_assigned_gt_inds] |
|
label_weights = gt_labels.new_ones(self.num_queries) |
|
|
|
|
|
mask_targets = gt_masks[sampling_result.pos_assigned_gt_inds] |
|
mask_weights = mask_pred.new_zeros((self.num_queries, )) |
|
mask_weights[pos_inds] = 1.0 |
|
|
|
return (labels, label_weights, mask_targets, mask_weights, pos_inds, |
|
neg_inds, sampling_result) |
|
|
|
def loss_by_feat(self, all_cls_scores: Tensor, all_mask_preds: Tensor, |
|
batch_gt_instances: List[InstanceData], |
|
batch_img_metas: List[dict]) -> Dict[str, Tensor]: |
|
"""Loss function. |
|
|
|
Args: |
|
all_cls_scores (Tensor): Classification scores for all decoder |
|
layers with shape (num_decoder, batch_size, num_queries, |
|
                cls_out_channels). Note `cls_out_channels` should include
|
background. |
|
all_mask_preds (Tensor): Mask scores for all decoder layers with |
|
shape (num_decoder, batch_size, num_queries, h, w). |
|
batch_gt_instances (list[obj:`InstanceData`]): each contains |
|
``labels`` and ``masks``. |
|
batch_img_metas (list[dict]): List of image meta information. |
|
|
|
Returns: |
|
dict[str, Tensor]: A dictionary of loss components. |
|
""" |
|
num_dec_layers = len(all_cls_scores) |
|
batch_gt_instances_list = [ |
|
batch_gt_instances for _ in range(num_dec_layers) |
|
] |
|
img_metas_list = [batch_img_metas for _ in range(num_dec_layers)] |
|
losses_cls, losses_mask, losses_dice = multi_apply( |
|
self._loss_by_feat_single, all_cls_scores, all_mask_preds, |
|
batch_gt_instances_list, img_metas_list) |
|
|
|
loss_dict = dict() |
|
|
|
loss_dict['loss_cls'] = losses_cls[-1] |
|
loss_dict['loss_mask'] = losses_mask[-1] |
|
loss_dict['loss_dice'] = losses_dice[-1] |
|
|
|
num_dec_layer = 0 |
|
for loss_cls_i, loss_mask_i, loss_dice_i in zip( |
|
losses_cls[:-1], losses_mask[:-1], losses_dice[:-1]): |
|
loss_dict[f'd{num_dec_layer}.loss_cls'] = loss_cls_i |
|
loss_dict[f'd{num_dec_layer}.loss_mask'] = loss_mask_i |
|
loss_dict[f'd{num_dec_layer}.loss_dice'] = loss_dice_i |
|
num_dec_layer += 1 |
|
return loss_dict |
|
|
|
def _loss_by_feat_single(self, cls_scores: Tensor, mask_preds: Tensor, |
|
batch_gt_instances: List[InstanceData], |
|
batch_img_metas: List[dict]) -> Tuple[Tensor]: |
|
"""Loss function for outputs from a single decoder layer. |
|
|
|
Args: |
|
            cls_scores (Tensor): Classification score logits from a single
                decoder layer for all images. Shape (batch_size, num_queries,
                cls_out_channels). Note `cls_out_channels` should include
                background.
            mask_preds (Tensor): Mask logits from a single decoder layer for
                all images. Shape (batch_size, num_queries, h, w).
|
batch_gt_instances (list[obj:`InstanceData`]): each contains |
|
``labels`` and ``masks``. |
|
batch_img_metas (list[dict]): List of image meta information. |
|
|
|
Returns: |
|
tuple[Tensor]: Loss components for outputs from a single decoder\ |
|
layer. |
|
""" |
|
num_imgs = cls_scores.size(0) |
|
cls_scores_list = [cls_scores[i] for i in range(num_imgs)] |
|
mask_preds_list = [mask_preds[i] for i in range(num_imgs)] |
|
|
|
(labels_list, label_weights_list, mask_targets_list, mask_weights_list, |
|
avg_factor) = self.get_targets(cls_scores_list, mask_preds_list, |
|
batch_gt_instances, batch_img_metas) |
|
|
|
labels = torch.stack(labels_list, dim=0) |
|
|
|
label_weights = torch.stack(label_weights_list, dim=0) |
|
|
|
mask_targets = torch.cat(mask_targets_list, dim=0) |
|
|
|
mask_weights = torch.stack(mask_weights_list, dim=0) |
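        # Classification loss: flatten the batch and query dims and average
        # by the summed class weights of the assigned labels.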
|
|
|
|
|
|
|
cls_scores = cls_scores.flatten(0, 1) |
|
labels = labels.flatten(0, 1) |
|
label_weights = label_weights.flatten(0, 1) |
|
|
|
class_weight = cls_scores.new_tensor(self.class_weight) |
|
loss_cls = self.loss_cls( |
|
cls_scores, |
|
labels, |
|
label_weights, |
|
avg_factor=class_weight[labels].sum()) |
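        # Average the number of positive masks across GPUs so the mask and
        # dice losses are normalized consistently under distributed training.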
|
|
|
num_total_masks = reduce_mean(cls_scores.new_tensor([avg_factor])) |
|
num_total_masks = max(num_total_masks, 1) |
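        # Keep only the predictions matched to a GT instance,
        # shape (num_pos, h, w) after boolean indexing.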
|
|
|
|
|
|
|
mask_preds = mask_preds[mask_weights > 0] |
|
target_shape = mask_targets.shape[-2:] |
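        # With no GT masks in the batch, return zero-valued losses that
        # still keep the computation graph connected to mask_preds.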
|
|
|
if mask_targets.shape[0] == 0: |
|
|
|
loss_dice = mask_preds.sum() |
|
loss_mask = mask_preds.sum() |
|
return loss_cls, loss_mask, loss_dice |
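        # Upsample the matched mask predictions to the GT mask resolution
        # before computing the dice loss.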
|
|
|
|
|
|
|
mask_preds = F.interpolate( |
|
mask_preds.unsqueeze(1), |
|
target_shape, |
|
mode='bilinear', |
|
align_corners=False).squeeze(1) |
|
|
|
|
|
loss_dice = self.loss_dice( |
|
mask_preds, mask_targets, avg_factor=num_total_masks) |
|
|
|
|
|
|
|
h, w = mask_preds.shape[-2:] |
|
|
|
mask_preds = mask_preds.reshape(-1, 1) |
|
|
|
mask_targets = mask_targets.reshape(-1) |
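        # Targets are flipped to (1 - mask_targets) because the binary focal
        # loss here treats label 0 as the positive (foreground) class.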
|
|
|
loss_mask = self.loss_mask( |
|
mask_preds, 1 - mask_targets, avg_factor=num_total_masks * h * w) |
|
|
|
return loss_cls, loss_mask, loss_dice |
|
|
|
def forward(self, x: Tuple[Tensor], |
|
batch_data_samples: SampleList) -> Tuple[Tensor]: |
|
"""Forward function. |
|
|
|
Args: |
|
x (tuple[Tensor]): Features from the upstream network, each |
|
is a 4D-tensor. |
|
batch_data_samples (List[:obj:`DetDataSample`]): The Data |
|
Samples. It usually includes information such as |
|
`gt_instance`, `gt_panoptic_seg` and `gt_sem_seg`. |
|
|
|
Returns: |
|
tuple[Tensor]: a tuple contains two elements. |
|
|
|
            - all_cls_scores (Tensor): Classification scores for all\
                decoder layers, a 4D-tensor with shape\
                (num_decoder, batch_size, num_queries, cls_out_channels).\
                Note `cls_out_channels` should include background.
            - all_mask_preds (Tensor): Mask logits for all decoder\
                layers, a 5D-tensor with shape (num_decoder, batch_size,\
                num_queries, h, w).
|
""" |
|
batch_img_metas = [ |
|
data_sample.metainfo for data_sample in batch_data_samples |
|
] |
|
batch_size = len(batch_img_metas) |
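        # Build a padding mask at the padded input resolution: 0 inside each
        # image's valid region, 1 on padded pixels; it is then downsampled
        # to the last feature map for positional encoding and attention.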
|
input_img_h, input_img_w = batch_img_metas[0]['batch_input_shape'] |
|
padding_mask = x[-1].new_ones((batch_size, input_img_h, input_img_w), |
|
dtype=torch.float32) |
|
for i in range(batch_size): |
|
img_h, img_w = batch_img_metas[i]['img_shape'] |
|
padding_mask[i, :img_h, :img_w] = 0 |
|
padding_mask = F.interpolate( |
|
padding_mask.unsqueeze(1), size=x[-1].shape[-2:], |
|
mode='nearest').to(torch.bool).squeeze(1) |
|
|
|
|
|
mask_features, memory = self.pixel_decoder(x, batch_img_metas) |
|
pos_embed = self.decoder_pe(padding_mask) |
|
memory = self.decoder_input_proj(memory) |
|
|
|
memory = memory.flatten(2).permute(0, 2, 1) |
|
pos_embed = pos_embed.flatten(2).permute(0, 2, 1) |
|
|
|
padding_mask = padding_mask.flatten(1) |
|
|
|
query_embed = self.query_embed.weight |
|
|
|
query_embed = query_embed.unsqueeze(0).repeat(batch_size, 1, 1) |
|
target = torch.zeros_like(query_embed) |
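        # Queries start from zeros and cross-attend to the flattened memory;
        # out_dec has shape (num_decoder_layers, batch_size, num_queries, c).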
|
|
|
out_dec = self.transformer_decoder( |
|
query=target, |
|
key=memory, |
|
value=memory, |
|
query_pos=query_embed, |
|
key_pos=pos_embed, |
|
key_padding_mask=padding_mask) |
|
|
|
|
|
all_cls_scores = self.cls_embed(out_dec) |
|
|
|
|
|
mask_embed = self.mask_embed(out_dec) |
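        # Per-query mask logits: dot product between each query's mask
        # embedding (l, b, q, c) and the per-pixel features (b, c, h, w),
        # giving (l, b, q, h, w).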
|
all_mask_preds = torch.einsum('lbqc,bchw->lbqhw', mask_embed, |
|
mask_features) |
|
|
|
return all_cls_scores, all_mask_preds |
|
|
|
def loss( |
|
self, |
|
x: Tuple[Tensor], |
|
batch_data_samples: SampleList, |
|
) -> Dict[str, Tensor]: |
|
"""Perform forward propagation and loss calculation of the panoptic |
|
head on the features of the upstream network. |
|
|
|
Args: |
|
x (tuple[Tensor]): Multi-level features from the upstream |
|
network, each is a 4D-tensor. |
|
batch_data_samples (List[:obj:`DetDataSample`]): The Data |
|
Samples. It usually includes information such as |
|
`gt_instance`, `gt_panoptic_seg` and `gt_sem_seg`. |
|
|
|
Returns: |
|
dict[str, Tensor]: a dictionary of loss components |
|
""" |
|
batch_img_metas = [] |
|
batch_gt_instances = [] |
|
batch_gt_semantic_segs = [] |
|
for data_sample in batch_data_samples: |
|
batch_img_metas.append(data_sample.metainfo) |
|
batch_gt_instances.append(data_sample.gt_instances) |
|
if 'gt_sem_seg' in data_sample: |
|
batch_gt_semantic_segs.append(data_sample.gt_sem_seg) |
|
else: |
|
batch_gt_semantic_segs.append(None) |
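        # Merge instance and semantic annotations into per-image targets
        # (class labels plus binary masks) covering both things and stuff.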
|
|
|
|
|
all_cls_scores, all_mask_preds = self(x, batch_data_samples) |
|
|
|
|
|
batch_gt_instances = self.preprocess_gt(batch_gt_instances, |
|
batch_gt_semantic_segs) |
|
|
|
|
|
losses = self.loss_by_feat(all_cls_scores, all_mask_preds, |
|
batch_gt_instances, batch_img_metas) |
|
|
|
return losses |
|
|
|
def predict(self, x: Tuple[Tensor], |
|
batch_data_samples: SampleList) -> Tuple[Tensor]: |
|
"""Test without augmentaton. |
|
|
|
Args: |
|
x (tuple[Tensor]): Multi-level features from the |
|
upstream network, each is a 4D-tensor. |
|
batch_data_samples (List[:obj:`DetDataSample`]): The Data |
|
Samples. It usually includes information such as |
|
`gt_instance`, `gt_panoptic_seg` and `gt_sem_seg`. |
|
|
|
Returns: |
|
tuple[Tensor]: A tuple contains two tensors. |
|
|
|
            - mask_cls_results (Tensor): Mask classification logits,\
                shape (batch_size, num_queries, cls_out_channels).
                Note `cls_out_channels` should include background.
|
- mask_pred_results (Tensor): Mask logits, shape \ |
|
(batch_size, num_queries, h, w). |
|
""" |
|
batch_img_metas = [ |
|
data_sample.metainfo for data_sample in batch_data_samples |
|
] |
|
all_cls_scores, all_mask_preds = self(x, batch_data_samples) |
|
mask_cls_results = all_cls_scores[-1] |
|
mask_pred_results = all_mask_preds[-1] |
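        # Upsample the last decoder layer's mask logits to the padded input
        # resolution for downstream post-processing.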
|
|
|
|
|
img_shape = batch_img_metas[0]['batch_input_shape'] |
|
mask_pred_results = F.interpolate( |
|
mask_pred_results, |
|
size=(img_shape[0], img_shape[1]), |
|
mode='bilinear', |
|
align_corners=False) |
|
|
|
return mask_cls_results, mask_pred_results |
|
|