|
|
|
from typing import Optional, Tuple

import torch
from torch import Tensor
|
|
|
|
|
def preprocess_panoptic_gt(gt_labels: Tensor, gt_masks: Tensor,
                           gt_semantic_seg: Optional[Tensor], num_things: int,
                           num_stuff: int) -> Tuple[Tensor, Tensor]:
    """Preprocess the ground truth for an image.

    Args:
        gt_labels (Tensor): Ground truth labels of each bbox,
            with shape (num_gts, ).
        gt_masks (BitmapMasks | Tensor): Ground truth masks of each
            instance of an image, shape (num_gts, h, w). Either a mask
            container exposing ``to_tensor(dtype=..., device=...)`` or a
            plain tensor whose non-zero entries mark foreground pixels.
        gt_semantic_seg (Tensor | None): Ground truth of semantic
            segmentation with the shape (1, h, w).
            [0, num_things - 1] means things,
            [num_things, num_things + num_stuff - 1] means stuff,
            255 means VOID. It's None when training instance segmentation.
        num_things (int): Number of thing classes.
        num_stuff (int): Number of stuff classes.

    Returns:
        tuple[Tensor, Tensor]: a tuple containing the following targets.

            - labels (Tensor): Ground truth class indices for an
              image, with shape (n, ), n is the sum of number
              of stuff type and number of instance in an image.
            - masks (Tensor): Ground truth mask for an image, with
              shape (n, h, w) and integer (long) dtype. Contains stuff
              and things when training panoptic segmentation, and
              things only when training instance segmentation.
    """
    num_classes = num_things + num_stuff

    # The signature advertises a Tensor, while the historical callers pass a
    # mask container with ``to_tensor``; support both instead of crashing on
    # a plain tensor.
    if isinstance(gt_masks, Tensor):
        things_masks = gt_masks.to(
            dtype=torch.bool, device=gt_labels.device)
    else:
        things_masks = gt_masks.to_tensor(
            dtype=torch.bool, device=gt_labels.device)

    # Instance segmentation: there is no semantic map, so only things exist.
    if gt_semantic_seg is None:
        return gt_labels, things_masks.long()

    # Drop the leading channel dimension: (1, h, w) -> (h, w).
    gt_semantic_seg = gt_semantic_seg.squeeze(0)

    semantic_labels = torch.unique(
        gt_semantic_seg,
        sorted=False,
        return_inverse=False,
        return_counts=False)

    # Keep only stuff classes; thing ids are already covered by the instance
    # masks and anything >= num_classes (e.g. the VOID id) is ignored.
    is_stuff = (semantic_labels >= num_things) & (
        semantic_labels < num_classes)
    stuff_labels = semantic_labels[is_stuff]

    if stuff_labels.numel() == 0:
        return gt_labels, things_masks.long()

    # Broadcast (1, h, w) against (k, 1, 1) to build one binary mask per
    # stuff label, in the same order as ``stuff_labels``.
    stuff_masks = gt_semantic_seg.unsqueeze(0) == stuff_labels.view(-1, 1, 1)

    labels = torch.cat([gt_labels, stuff_labels], dim=0)
    masks = torch.cat([things_masks, stuff_masks], dim=0)
    return labels, masks.long()
|
|