Spaces:
Runtime error
Runtime error
File size: 12,757 Bytes
3b96cb1 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 |
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
from mmengine.structures import BaseDataElement
from mmdet.models.utils import multi_apply
from mmdet.registry import MODELS, TASK_UTILS
from mmdet.utils import reduce_mean
class DDQAuxLoss(nn.Module):
    """DDQ auxiliary branches loss for dense queries.

    Assigns ground-truth targets to the dense-query predictions with a
    top-k Hungarian assigner and computes a quality-focal classification
    loss plus a GIoU regression loss as auxiliary supervision.

    Args:
        loss_cls (dict):
            Configuration of classification loss function.
        loss_bbox (dict):
            Configuration of bbox regression loss function.
        train_cfg (dict):
            Configuration of gt targets assigner for each predicted bbox.
    """

    def __init__(
        self,
        loss_cls=dict(
            type='QualityFocalLoss',
            use_sigmoid=True,
            activated=True,  # use probability instead of logit as input
            beta=2.0,
            loss_weight=1.0),
        loss_bbox=dict(type='GIoULoss', loss_weight=2.0),
        train_cfg=dict(
            assigner=dict(type='TopkHungarianAssigner', topk=8),
            alpha=1,
            beta=6),
    ):
        super().__init__()
        self.train_cfg = train_cfg
        self.loss_cls = MODELS.build(loss_cls)
        self.loss_bbox = MODELS.build(loss_bbox)
        self.assigner = TASK_UTILS.build(self.train_cfg['assigner'])

        # The assigner already produces the final matching, so a pseudo
        # sampler that merely splits pos/neg indices is sufficient here.
        sampler_cfg = dict(type='PseudoSampler')
        self.sampler = TASK_UTILS.build(sampler_cfg)

    def loss_single(self, cls_score, bbox_pred, labels, label_weights,
                    bbox_targets, alignment_metrics):
        """Calculate auxiliary branches loss for dense queries for one image.

        Args:
            cls_score (Tensor): Predicted normalized classification
                scores for one image, has shape (num_dense_queries,
                cls_out_channels).
            bbox_pred (Tensor): Predicted unnormalized bbox coordinates
                for one image, has shape (num_dense_queries, 4) with the
                last dimension arranged as (x1, y1, x2, y2).
            labels (Tensor): Labels for one image.
            label_weights (Tensor): Label weights for one image.
            bbox_targets (Tensor): Bbox targets for one image.
            alignment_metrics (Tensor): Normalized alignment metrics for one
                image.

        Returns:
            tuple: A tuple of loss components and loss weights.
        """
        bbox_targets = bbox_targets.reshape(-1, 4)
        labels = labels.reshape(-1)
        alignment_metrics = alignment_metrics.reshape(-1)
        label_weights = label_weights.reshape(-1)

        # QualityFocalLoss expects a (labels, quality scores) tuple target.
        targets = (labels, alignment_metrics)
        # avg_factor=1.0: normalization across the batch happens in
        # `loss()` after summing the per-image avg factors.
        loss_cls = self.loss_cls(
            cls_score, targets, label_weights, avg_factor=1.0)

        # FG cat_id: [0, num_classes - 1], BG cat_id: num_classes
        bg_class_ind = cls_score.size(-1)
        pos_inds = ((labels >= 0)
                    & (labels < bg_class_ind)).nonzero().squeeze(1)

        if len(pos_inds) > 0:
            # Predictions and targets are both absolute (x1, y1, x2, y2)
            # boxes, so no decoding is required before the IoU-based loss.
            pos_bbox_targets = bbox_targets[pos_inds]
            pos_bbox_pred = bbox_pred[pos_inds]

            # regression loss, weighted by the per-query alignment metric
            pos_bbox_weight = alignment_metrics[pos_inds]
            loss_bbox = self.loss_bbox(
                pos_bbox_pred,
                pos_bbox_targets,
                weight=pos_bbox_weight,
                avg_factor=1.0)
        else:
            # No positives: emit a zero loss that keeps the graph connected.
            loss_bbox = bbox_pred.sum() * 0
            pos_bbox_weight = bbox_targets.new_tensor(0.)
        return loss_cls, loss_bbox, alignment_metrics.sum(
        ), pos_bbox_weight.sum()

    def loss(self, cls_scores, bbox_preds, gt_bboxes, gt_labels, img_metas,
             **kwargs):
        """Calculate auxiliary branches loss for dense queries.

        Args:
            cls_scores (Tensor): Predicted normalized classification
                scores, has shape (bs, num_dense_queries,
                cls_out_channels).
            bbox_preds (Tensor): Predicted unnormalized bbox coordinates,
                has shape (bs, num_dense_queries, 4) with the last
                dimension arranged as (x1, y1, x2, y2).
            gt_bboxes (list[Tensor]): List of unnormalized ground truth
                bboxes for each image, each has shape (num_gt, 4) with the
                last dimension arranged as (x1, y1, x2, y2).
                NOTE: num_gt is dynamic for each image.
            gt_labels (list[Tensor]): List of ground truth classification
                index for each image, each has shape (num_gt,).
                NOTE: num_gt is dynamic for each image.
            img_metas (list[dict]): Meta information for one image,
                e.g., image size, scaling factor, etc.

        Returns:
            dict: A dictionary of loss components.
        """
        flatten_cls_scores = cls_scores
        flatten_bbox_preds = bbox_preds

        cls_reg_targets = self.get_targets(
            flatten_cls_scores,
            flatten_bbox_preds,
            gt_bboxes,
            img_metas,
            gt_labels_list=gt_labels,
        )
        (labels_list, label_weights_list, bbox_targets_list,
         alignment_metrics_list) = cls_reg_targets

        losses_cls, losses_bbox, \
            cls_avg_factors, bbox_avg_factors = multi_apply(
                self.loss_single,
                flatten_cls_scores,
                flatten_bbox_preds,
                labels_list,
                label_weights_list,
                bbox_targets_list,
                alignment_metrics_list,
            )

        # Normalize per-image losses by the distributed mean of the summed
        # avg factors; clamp to 1 to avoid division by zero on empty batches.
        cls_avg_factor = reduce_mean(sum(cls_avg_factors)).clamp_(min=1).item()
        losses_cls = list(map(lambda x: x / cls_avg_factor, losses_cls))

        bbox_avg_factor = reduce_mean(
            sum(bbox_avg_factors)).clamp_(min=1).item()
        losses_bbox = list(map(lambda x: x / bbox_avg_factor, losses_bbox))
        return dict(aux_loss_cls=losses_cls, aux_loss_bbox=losses_bbox)

    def get_targets(self,
                    cls_scores,
                    bbox_preds,
                    gt_bboxes_list,
                    img_metas,
                    gt_labels_list=None,
                    **kwargs):
        """Compute regression and classification targets for a batch images.

        Args:
            cls_scores (Tensor): Predicted normalized classification
                scores, has shape (bs, num_dense_queries,
                cls_out_channels).
            bbox_preds (Tensor): Predicted unnormalized bbox coordinates,
                has shape (bs, num_dense_queries, 4) with the last
                dimension arranged as (x1, y1, x2, y2).
            gt_bboxes_list (List[Tensor]): List of unnormalized ground truth
                bboxes for each image, each has shape (num_gt, 4) with the
                last dimension arranged as (x1, y1, x2, y2).
                NOTE: num_gt is dynamic for each image.
            img_metas (list[dict]): Meta information for one image,
                e.g., image size, scaling factor, etc.
            gt_labels_list (list[Tensor]): List of ground truth classification
                index for each image, each has shape (num_gt,).
                NOTE: num_gt is dynamic for each image.
                Default: None.

        Returns:
            tuple: a tuple containing the following targets.

            - all_labels (list[Tensor]): Labels for all images.
            - all_label_weights (list[Tensor]): Label weights for all images.
            - all_bbox_targets (list[Tensor]): Bbox targets for all images.
            - all_assign_metrics (list[Tensor]): Normalized alignment metrics
                for all images.
        """
        (all_labels, all_label_weights, all_bbox_targets,
         all_assign_metrics) = multi_apply(self._get_target_single, cls_scores,
                                           bbox_preds, gt_bboxes_list,
                                           gt_labels_list, img_metas)

        return (all_labels, all_label_weights, all_bbox_targets,
                all_assign_metrics)

    def _get_target_single(self, cls_scores, bbox_preds, gt_bboxes, gt_labels,
                           img_meta, **kwargs):
        """Compute regression and classification targets for one image.

        Args:
            cls_scores (Tensor): Predicted normalized classification
                scores for one image, has shape (num_dense_queries,
                cls_out_channels).
            bbox_preds (Tensor): Predicted unnormalized bbox coordinates
                for one image, has shape (num_dense_queries, 4) with the
                last dimension arranged as (x1, y1, x2, y2).
            gt_bboxes (Tensor): Unnormalized ground truth
                bboxes for one image, has shape (num_gt, 4) with the
                last dimension arranged as (x1, y1, x2, y2).
                NOTE: num_gt is dynamic for each image.
            gt_labels (Tensor): Ground truth classification
                index for the image, has shape (num_gt,).
                NOTE: num_gt is dynamic for each image.
            img_meta (dict): Meta information for one image.

        Returns:
            tuple[Tensor]: a tuple containing the following for one image.

            - labels (Tensor): Labels for one image.
            - label_weights (Tensor): Label weights for one image.
            - bbox_targets (Tensor): Bbox targets for one image.
            - norm_alignment_metrics (Tensor): Normalized alignment
                metrics for one image.
        """
        if len(gt_labels) == 0:
            # No gt: every query is background (index == num classes) with
            # zero weights and zero metrics.
            num_valid_anchors = len(cls_scores)
            bbox_targets = torch.zeros_like(bbox_preds)
            labels = bbox_preds.new_full((num_valid_anchors, ),
                                         cls_scores.size(-1),
                                         dtype=torch.long)
            label_weights = bbox_preds.new_zeros(
                num_valid_anchors, dtype=torch.float)
            norm_alignment_metrics = bbox_preds.new_zeros(
                num_valid_anchors, dtype=torch.float)
            return (labels, label_weights, bbox_targets,
                    norm_alignment_metrics)

        assign_result = self.assigner.assign(cls_scores, bbox_preds, gt_bboxes,
                                             gt_labels, img_meta)
        assign_ious = assign_result.max_overlaps
        assign_metrics = assign_result.assign_metrics

        pred_instances = BaseDataElement()
        gt_instances = BaseDataElement()

        pred_instances.bboxes = bbox_preds
        gt_instances.bboxes = gt_bboxes

        pred_instances.priors = cls_scores
        gt_instances.labels = gt_labels

        sampling_result = self.sampler.sample(assign_result, pred_instances,
                                              gt_instances)

        num_valid_anchors = len(cls_scores)
        bbox_targets = torch.zeros_like(bbox_preds)
        labels = bbox_preds.new_full((num_valid_anchors, ),
                                     cls_scores.size(-1),
                                     dtype=torch.long)
        label_weights = bbox_preds.new_zeros(
            num_valid_anchors, dtype=torch.float)
        norm_alignment_metrics = bbox_preds.new_zeros(
            num_valid_anchors, dtype=torch.float)

        pos_inds = sampling_result.pos_inds
        neg_inds = sampling_result.neg_inds
        if len(pos_inds) > 0:
            # point-based
            pos_bbox_targets = sampling_result.pos_gt_bboxes
            bbox_targets[pos_inds, :] = pos_bbox_targets

            if gt_labels is None:
                # Only dense_heads gives gt_labels as None
                # Foreground is the first class since v2.5.0
                labels[pos_inds] = 0
            else:
                labels[pos_inds] = gt_labels[
                    sampling_result.pos_assigned_gt_inds]

            label_weights[pos_inds] = 1.0

        if len(neg_inds) > 0:
            label_weights[neg_inds] = 1.0

        # Per gt, rescale the assign metrics of its matched queries so the
        # best one equals the max IoU among them.
        class_assigned_gt_inds = torch.unique(
            sampling_result.pos_assigned_gt_inds)
        for gt_inds in class_assigned_gt_inds:
            gt_class_inds = sampling_result.pos_assigned_gt_inds == gt_inds
            pos_alignment_metrics = assign_metrics[gt_class_inds]
            pos_ious = assign_ious[gt_class_inds]
            # 10e-8 (i.e. 1e-7) guards against division by zero.
            pos_norm_alignment_metrics = pos_alignment_metrics / (
                pos_alignment_metrics.max() + 10e-8) * pos_ious.max()
            norm_alignment_metrics[
                pos_inds[gt_class_inds]] = pos_norm_alignment_metrics

        return (labels, label_weights, bbox_targets, norm_alignment_metrics)
|