from typing import Dict, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from mmengine.model import ModuleList
from torch import Tensor

from mmdet.registry import MODELS
from mmdet.structures import SampleList
from mmdet.utils import ConfigType, OptConfigType, OptMultiConfig
from ..layers import ConvUpsample
from ..utils import interpolate_as
from .base_semantic_head import BaseSemanticHead


@MODELS.register_module()
class PanopticFPNHead(BaseSemanticHead):
    """PanopticFPNHead used in Panoptic FPN.

    In this head, the number of output channels is ``num_stuff_classes
    + 1``, including all stuff classes and one thing class. The stuff
    classes will be remapped to the range ``0`` to ``num_stuff_classes
    - 1``, and the thing classes will be merged into the
    ``num_stuff_classes``-th channel.

    Args:
        num_things_classes (int): Number of thing classes. Default: 80.
        num_stuff_classes (int): Number of stuff classes. Default: 53.
        in_channels (int): Number of channels in the input feature
            map.
        inner_channels (int): Number of channels in inner features.
        start_level (int): The start level of the input features
            used in PanopticFPN.
        end_level (int): The end level of the used features; the
            ``end_level``-th layer will not be used.
        conv_cfg (Optional[Union[ConfigDict, dict]]): Dictionary to construct
            and config conv layer.
        norm_cfg (Union[ConfigDict, dict]): Dictionary to construct and config
            norm layer. Use ``GN`` by default.
        init_cfg (Optional[Union[ConfigDict, dict]]): Initialization config
            dict.
        loss_seg (Union[ConfigDict, dict]): The loss of the semantic head.
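
    Examples:
        >>> # A minimal shape sketch with illustrative (non-default) sizes;
        >>> # assumes mmdet is installed so the loss can be built from the
        >>> # registry.
        >>> import torch
        >>> head = PanopticFPNHead(
        ...     num_things_classes=2,
        ...     num_stuff_classes=2,
        ...     in_channels=32,
        ...     inner_channels=32)
        >>> feats = [torch.rand(1, 32, 64 // 2**i, 64 // 2**i)
        ...          for i in range(4)]
        >>> head(feats)['seg_preds'].shape  # (N, num_stuff_classes + 1, H, W)
        torch.Size([1, 3, 64, 64])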
    """

    def __init__(self,
                 num_things_classes: int = 80,
                 num_stuff_classes: int = 53,
                 in_channels: int = 256,
                 inner_channels: int = 128,
                 start_level: int = 0,
                 end_level: int = 4,
                 conv_cfg: OptConfigType = None,
                 norm_cfg: ConfigType = dict(
                     type='GN', num_groups=32, requires_grad=True),
                 loss_seg: ConfigType = dict(
                     type='CrossEntropyLoss', ignore_index=-1,
                     loss_weight=1.0),
                 init_cfg: OptMultiConfig = None) -> None:
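        # The GT semantic map is rescaled to the stride of the first used
        # FPN level, e.g. start_level=0 corresponds to P2 (stride 4), so
        # seg_rescale_factor = 1 / 4.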
        seg_rescale_factor = 1 / 2**(start_level + 2)
        super().__init__(
            num_classes=num_stuff_classes + 1,
            seg_rescale_factor=seg_rescale_factor,
            loss_seg=loss_seg,
            init_cfg=init_cfg)
        self.num_things_classes = num_things_classes
        self.num_stuff_classes = num_stuff_classes
        # Feature maps in [start_level, end_level) are used.
        self.start_level = start_level
        self.end_level = end_level
        self.num_stages = end_level - start_level
        self.inner_channels = inner_channels

        self.conv_upsample_layers = ModuleList()
        for i in range(start_level, end_level):
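            # Level i gets i conv blocks and i 2x upsamplings (a single conv
            # and no upsampling for i = 0), so every branch comes out at a
            # common resolution before the branches are summed in forward().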
|
self.conv_upsample_layers.append( |
|
ConvUpsample( |
|
in_channels, |
|
inner_channels, |
|
num_layers=i if i > 0 else 1, |
|
num_upsample=i if i > 0 else 0, |
|
conv_cfg=conv_cfg, |
|
norm_cfg=norm_cfg, |
|
)) |
|
self.conv_logits = nn.Conv2d(inner_channels, self.num_classes, 1) |
|
|
|
def _set_things_to_void(self, gt_semantic_seg: Tensor) -> Tensor: |
        """Merge thing classes to one class.

        In PanopticFPN, the background (stuff) labels are remapped to the
        range `0` to `self.num_stuff_classes - 1`, and the foreground
        (thing) labels are all merged into the single label
        `self.num_stuff_classes`.
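
        Examples:
            >>> # Illustrative only: with num_things_classes=2 and
            >>> # num_stuff_classes=3, thing labels {0, 1} collapse to 3
            >>> # and stuff labels {2, 3, 4} shift to {0, 1, 2}; other
            >>> # labels (e.g. 255) are left untouched.
            >>> import torch
            >>> head = PanopticFPNHead(num_things_classes=2,
            ...                        num_stuff_classes=3)
            >>> seg = torch.tensor([0, 1, 2, 3, 4, 255])
            >>> head._set_things_to_void(seg)  # -> [3, 3, 0, 1, 2, 255]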
        """
        gt_semantic_seg = gt_semantic_seg.int()
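        # Thing classes occupy labels [0, num_things_classes); stuff classes
        # occupy [num_things_classes, num_things_classes + num_stuff_classes).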
        fg_mask = gt_semantic_seg < self.num_things_classes
        bg_mask = (gt_semantic_seg >= self.num_things_classes) * (
            gt_semantic_seg < self.num_things_classes + self.num_stuff_classes)

        new_gt_seg = torch.clone(gt_semantic_seg)
        new_gt_seg = torch.where(bg_mask,
                                 gt_semantic_seg - self.num_things_classes,
                                 new_gt_seg)
        new_gt_seg = torch.where(fg_mask,
                                 fg_mask.int() * self.num_stuff_classes,
                                 new_gt_seg)
        return new_gt_seg

    def loss(self, x: Union[Tensor, Tuple[Tensor]],
             batch_data_samples: SampleList) -> Dict[str, Tensor]:
        """
        Args:
            x (Union[Tensor, Tuple[Tensor]]): Feature maps.
            batch_data_samples (list[:obj:`DetDataSample`]): The batch
                data samples. It usually includes information such
                as `gt_instance` or `gt_panoptic_seg` or `gt_sem_seg`.

        Returns:
            Dict[str, Tensor]: The loss of the semantic head.
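
        Examples:
            >>> # A hedged end-to-end sketch with illustrative sizes;
            >>> # assumes mmdet is installed.
            >>> import torch
            >>> from mmengine.structures import PixelData
            >>> from mmdet.structures import DetDataSample
            >>> head = PanopticFPNHead(num_things_classes=2,
            ...                        num_stuff_classes=2,
            ...                        in_channels=32, inner_channels=32)
            >>> sample = DetDataSample()
            >>> sample.gt_sem_seg = PixelData(
            ...     sem_seg=torch.randint(0, 4, (1, 64, 64)))
            >>> feats = [torch.rand(1, 32, 64 // 2**i, 64 // 2**i)
            ...          for i in range(4)]
            >>> head.loss(feats, [sample])  # {'loss_seg': tensor(...)}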
        """
        seg_preds = self(x)['seg_preds']
        gt_semantic_segs = [
            data_sample.gt_sem_seg.sem_seg
            for data_sample in batch_data_samples
        ]

        gt_semantic_segs = torch.stack(gt_semantic_segs)
        if self.seg_rescale_factor != 1.0:
            gt_semantic_segs = F.interpolate(
                gt_semantic_segs.float(),
                scale_factor=self.seg_rescale_factor,
                mode='nearest').squeeze(1)

        # Thing classes will be merged to one class in PanopticFPN.
        gt_semantic_segs = self._set_things_to_void(gt_semantic_segs)

        if seg_preds.shape[-2:] != gt_semantic_segs.shape[-2:]:
            seg_preds = interpolate_as(seg_preds, gt_semantic_segs)
        seg_preds = seg_preds.permute((0, 2, 3, 1))

        loss_seg = self.loss_seg(
            seg_preds.reshape(-1, self.num_classes),  # => [NxHxW, C]
            gt_semantic_segs.reshape(-1).long())

        return dict(loss_seg=loss_seg)

    def init_weights(self) -> None:
        """Initialize weights."""
        super().init_weights()
        nn.init.normal_(self.conv_logits.weight.data, 0, 0.01)
        self.conv_logits.bias.data.zero_()

    def forward(self, x: Tuple[Tensor]) -> Dict[str, Tensor]:
        """Forward.

        Args:
            x (Tuple[Tensor]): Multi-scale feature maps.

        Returns:
            dict[str, Tensor]: Semantic segmentation predictions and
                feature maps.
        """
        assert self.num_stages <= len(x)
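
        # Process each used level with its ConvUpsample branch, then sum the
        # resolution-aligned per-level features into one semantic feature map.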
        feats = []
        for i, layer in enumerate(self.conv_upsample_layers):
            f = layer(x[self.start_level + i])
            feats.append(f)

        seg_feats = torch.sum(torch.stack(feats, dim=0), dim=0)
        seg_preds = self.conv_logits(seg_feats)
        out = dict(seg_preds=seg_preds, seg_feats=seg_feats)
        return out