|
|
|
from typing import List |
|
|
|
import torch |
|
import torch.nn.functional as F |
|
from mmengine import MessageHub |
|
from mmengine.structures import InstanceData |
|
from torch import Tensor |
|
|
|
from mmdet.registry import MODELS |
|
from mmdet.utils import InstanceList |
|
from ..utils.misc import unfold_wo_center |
|
from .condinst_head import CondInstBboxHead, CondInstMaskHead |
|
|
|
|
|
@MODELS.register_module()
class BoxInstBboxHead(CondInstBboxHead):
    """Box head for BoxInst (https://arxiv.org/abs/2012.02310).

    This subclass adds no logic of its own here: construction is delegated
    unchanged to ``CondInstBboxHead``. It is registered in ``MODELS`` under
    its own class name so it can be referenced separately from the
    CondInst box head.
    """

    def __init__(self, *args, **kwargs) -> None:
        # Forward every positional and keyword argument untouched to the
        # parent box head.
        super().__init__(*args, **kwargs)
|
|
|
|
|
@MODELS.register_module()
class BoxInstMaskHead(CondInstMaskHead):
    """BoxInst mask head used in https://arxiv.org/abs/2012.02310.

    This head outputs the mask for BoxInst. Its loss (see
    :meth:`loss_by_feat`) is made of two terms: a projection term comparing
    the axis-wise maxima of predictions and targets, and a pairwise term
    built from :meth:`get_pairwise_affinity`.

    Args:
        pairwise_size (int): The size of neighborhood for each pixel.
            Defaults to 3.
        pairwise_dilation (int): The dilation of neighborhood for each pixel.
            Defaults to 2.
        warmup_iters (int): Warmup iterations for pair-wise loss. A
            non-positive value disables the warm-up. Defaults to 10000.
    """

    def __init__(self,
                 *arg,
                 pairwise_size: int = 3,
                 pairwise_dilation: int = 2,
                 warmup_iters: int = 10000,
                 **kwargs) -> None:
        # NOTE(review): the attributes are assigned *before* calling the
        # parent constructor; presumably so they are already available if
        # the parent uses them while initialising - keep this order.
        self.pairwise_size = pairwise_size
        self.pairwise_dilation = pairwise_dilation
        self.warmup_iters = warmup_iters
        super().__init__(*arg, **kwargs)

    def get_pairwise_affinity(self, mask_logits: Tensor) -> Tensor:
        """Compute the pairwise affinity for each pixel.

        For every pixel ``i`` and each neighbour ``j`` returned by
        ``unfold_wo_center`` this evaluates
        ``-log(p_i * p_j + (1 - p_i) * (1 - p_j))`` with
        ``p = sigmoid(mask_logits)``, i.e. the negative log-probability that
        both pixels receive the same foreground/background prediction.

        Args:
            mask_logits (Tensor): Mask logits (pre-sigmoid). Assumed to have
                shape (num_pos, H, W) - TODO confirm against caller.

        Returns:
            Tensor: Negative log "same prediction" probability per pixel and
            neighbour. Assumed shape (num_pos, num_neighborhood, H, W),
            depending on the layout produced by ``unfold_wo_center``.
        """
        # log(p) and log(1 - p); a singleton dim is inserted at index 1.
        log_fg_prob = F.logsigmoid(mask_logits).unsqueeze(1)
        log_bg_prob = F.logsigmoid(-mask_logits).unsqueeze(1)

        # Gather the neighbourhood of every pixel (presumably excluding the
        # centre pixel itself - see ``unfold_wo_center``).
        log_fg_prob_unfold = unfold_wo_center(
            log_fg_prob,
            kernel_size=self.pairwise_size,
            dilation=self.pairwise_dilation)
        log_bg_prob_unfold = unfold_wo_center(
            log_bg_prob,
            kernel_size=self.pairwise_size,
            dilation=self.pairwise_dilation)

        # Work in log space: log(p_i * p_j) and log((1 - p_i) * (1 - p_j)).
        # The extra ``None`` axis broadcasts each pixel over its neighbours.
        log_same_fg_prob = log_fg_prob[:, :, None] + log_fg_prob_unfold
        log_same_bg_prob = log_bg_prob[:, :, None] + log_bg_prob_unfold

        # Numerically stable log(exp(a) + exp(b)): subtract the element-wise
        # maximum before exponentiating, then add it back.
        max_ = torch.max(log_same_fg_prob, log_same_bg_prob)
        log_same_prob = torch.log(
            torch.exp(log_same_fg_prob - max_) +
            torch.exp(log_same_bg_prob - max_)) + max_

        # Drop the singleton dim inserted by ``unsqueeze(1)`` and negate.
        return -log_same_prob[:, 0]

    def loss_by_feat(self, mask_preds: List[Tensor],
                     batch_gt_instances: InstanceList,
                     batch_img_metas: List[dict], positive_infos: InstanceList,
                     **kwargs) -> dict:
        """Calculate the loss based on the features extracted by the mask head.

        Args:
            mask_preds (list[Tensor]): List of predicted masks, one tensor
                per image, each indexed along dim 0 by positive sample,
                i.e. shape (num_pos, H, W).
            batch_gt_instances (list[:obj:`InstanceData`]): Batch of
                gt_instance. It usually includes ``bboxes``, ``masks``,
                and ``labels`` attributes.
            batch_img_metas (list[dict]): Meta information of multiple images.
                Not used by this method.
            positive_infos (List[:obj:``InstanceData``]): Information of
                positive samples of each image that are assigned in detection
                head.

        Returns:
            dict[str, Tensor]: A dictionary of loss components with keys
            ``loss_mask_project`` and ``loss_mask_pairwise``.
        """
        assert positive_infos is not None, \
            'positive_infos should not be None in `BoxInstMaskHead`'
        losses = dict()

        loss_mask_project = 0.
        loss_mask_pairwise = 0.
        num_imgs = len(mask_preds)
        total_pos = 0.
        avg_factor = 0.

        for idx in range(num_imgs):
            (mask_pred, pos_mask_targets, pos_pairwise_masks, num_pos) = \
                self._get_targets_single(
                    mask_preds[idx], batch_gt_instances[idx],
                    positive_infos[idx])

            total_pos += num_pos
            if num_pos == 0 or pos_mask_targets is None:
                # No positive sample in this image: contribute zero-valued
                # tensors so the accumulated losses stay tensors.
                loss_project = mask_pred.new_zeros(1).mean()
                loss_pairwise = mask_pred.new_zeros(1).mean()
            else:
                # Projection term: compare the maxima of prediction and
                # target along dim 1 and along dim 2 respectively.
                loss_project_x = self.loss_mask(
                    mask_pred.max(dim=1, keepdim=True)[0],
                    pos_mask_targets.max(dim=1, keepdim=True)[0],
                    reduction_override='none').sum()
                loss_project_y = self.loss_mask(
                    mask_pred.max(dim=2, keepdim=True)[0],
                    pos_mask_targets.max(dim=2, keepdim=True)[0],
                    reduction_override='none').sum()
                loss_project = loss_project_x + loss_project_y

                # Pairwise term: affinity weighted by the pairwise masks,
                # normalised later by the number of active pairs.
                pairwise_affinity = self.get_pairwise_affinity(mask_pred)
                avg_factor += pos_pairwise_masks.sum().clamp(min=1.0)
                loss_pairwise = (pairwise_affinity * pos_pairwise_masks).sum()

            loss_mask_project += loss_project
            loss_mask_pairwise += loss_pairwise

        # Avoid dividing by zero (NaN) when nothing contributed.
        if total_pos == 0:
            total_pos += 1
        if avg_factor == 0:
            avg_factor += 1
        loss_mask_project = loss_mask_project / total_pos
        loss_mask_pairwise = loss_mask_pairwise / avg_factor

        # Linearly ramp the pairwise term up over ``warmup_iters``
        # iterations, using the current iteration stored in the message hub.
        message_hub = MessageHub.get_current_instance()
        cur_iter = message_hub.get_info('iter')
        if self.warmup_iters > 0:
            warmup_factor = min(cur_iter / float(self.warmup_iters), 1.0)
        else:
            # A non-positive warm-up length means "no warm-up"; this also
            # avoids a ZeroDivisionError for ``warmup_iters == 0``.
            warmup_factor = 1.0
        loss_mask_pairwise *= warmup_factor

        losses.update(
            loss_mask_project=loss_mask_project,
            loss_mask_pairwise=loss_mask_pairwise)
        return losses

    def _get_targets_single(self, mask_preds: Tensor,
                            gt_instances: InstanceData,
                            positive_info: InstanceData) -> tuple:
        """Compute targets for predictions of single image.

        Args:
            mask_preds (Tensor): Predicted masks of the positive samples,
                indexed along dim 0 by positive sample, i.e. shape
                (num_pos, H, W).
            gt_instances (:obj:`InstanceData`): Ground truth of instance
                annotations. It should include ``bboxes``, ``masks``
                and ``pairwise_masks`` attributes.
            positive_info (:obj:`InstanceData`): Information of positive
                samples that are assigned in detection head. The following
                keys are read here.

                - pos_assigned_gt_inds (Tensor): Assigner GT indexes of
                  positive proposals, has shape (num_pos, )
                - scores (Tensor): Classification scores (pre-sigmoid) of
                  the positive samples; only used for top-k selection.
                - centernesses (Tensor): Centerness predictions
                  (pre-sigmoid) of the positive samples; only used for
                  top-k selection.

        Returns:
            tuple: Usually returns a tuple containing learning targets.

            - mask_preds (Tensor): Positive predicted mask with shape
              (num_pos, mask_h, mask_w).
            - pos_mask_targets (Tensor): Positive mask targets with shape
              (num_pos, mask_h, mask_w). ``None`` when there is no ground
              truth mask or no positive sample.
            - pos_pairwise_masks (Tensor): Positive pairwise masks with
              shape: (num_pos, num_neighborhood, mask_h, mask_w). ``None``
              when there is no ground truth mask or no positive sample.
            - num_pos (int): Positive numbers.
        """
        gt_bboxes = gt_instances.bboxes
        device = gt_bboxes.device

        # NOTE(review): for BoxInst these masks are presumably generated
        # from the boxes upstream rather than being real instance masks -
        # confirm against the data preprocessor.
        gt_masks = gt_instances.masks.to_tensor(
            dtype=torch.bool, device=device).float()

        # NOTE(review): ``pairwise_masks`` is expected to be attached to the
        # ground truth upstream - confirm its origin and layout.
        pairwise_masks = gt_instances.pairwise_masks
        pairwise_masks = pairwise_masks.to(device=device)

        pos_assigned_gt_inds = positive_info.get('pos_assigned_gt_inds')
        scores = positive_info.get('scores')
        centernesses = positive_info.get('centernesses')
        num_pos = pos_assigned_gt_inds.size(0)

        if gt_masks.size(0) == 0 or num_pos == 0:
            return mask_preds, None, None, 0

        # Only a subset of the positive masks is kept for the loss: either
        # a random subset of size ``max_masks_to_train`` or, per ground
        # truth, the top-scoring samples up to ``topk_masks_per_img`` total.
        if (self.max_masks_to_train != -1) and \
                (num_pos > self.max_masks_to_train):
            perm = torch.randperm(num_pos)
            select = perm[:self.max_masks_to_train]
            mask_preds = mask_preds[select]
            pos_assigned_gt_inds = pos_assigned_gt_inds[select]
            num_pos = self.max_masks_to_train
        elif self.topk_masks_per_img != -1:
            unique_gt_inds = pos_assigned_gt_inds.unique()
            # Share the budget evenly between ground truths, at least one
            # sample each.
            num_inst_per_gt = max(
                int(self.topk_masks_per_img / len(unique_gt_inds)), 1)

            keep_mask_preds = []
            keep_pos_assigned_gt_inds = []
            for gt_ind in unique_gt_inds:
                per_inst_pos_inds = (pos_assigned_gt_inds == gt_ind)
                mask_preds_per_inst = mask_preds[per_inst_pos_inds]
                gt_inds_per_inst = pos_assigned_gt_inds[per_inst_pos_inds]
                # ``Tensor.sum`` counts the matches in one reduction rather
                # than iterating over the tensor element by element.
                if per_inst_pos_inds.sum() > num_inst_per_gt:
                    # Rank the samples of this ground truth by
                    # (max class score) * centerness and keep the best.
                    per_inst_scores = scores[per_inst_pos_inds].sigmoid().max(
                        dim=1)[0]
                    per_inst_centerness = centernesses[
                        per_inst_pos_inds].sigmoid().reshape(-1, )
                    select = (per_inst_scores * per_inst_centerness).topk(
                        k=num_inst_per_gt, dim=0)[1]
                    mask_preds_per_inst = mask_preds_per_inst[select]
                    gt_inds_per_inst = gt_inds_per_inst[select]
                keep_mask_preds.append(mask_preds_per_inst)
                keep_pos_assigned_gt_inds.append(gt_inds_per_inst)
            mask_preds = torch.cat(keep_mask_preds)
            pos_assigned_gt_inds = torch.cat(keep_pos_assigned_gt_inds)
            num_pos = pos_assigned_gt_inds.size(0)

        # Subsample the ground-truth masks with step ``mask_out_stride``,
        # starting half a stride in, then re-binarise them.
        start = int(self.mask_out_stride // 2)
        gt_masks = gt_masks[:, start::self.mask_out_stride,
                            start::self.mask_out_stride]
        gt_masks = gt_masks.gt(0.5).float()
        pos_mask_targets = gt_masks[pos_assigned_gt_inds]
        pos_pairwise_masks = pairwise_masks[pos_assigned_gt_inds]
        # Keep only the pairwise entries of pixels inside the target mask.
        pos_pairwise_masks = pos_pairwise_masks * pos_mask_targets.unsqueeze(1)

        return (mask_preds, pos_mask_targets, pos_pairwise_masks, num_pos)
|
|