Spaces:
Runtime error
Runtime error
File size: 14,673 Bytes
3b96cb1 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 |
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Optional, Tuple
import torch
import torch.nn as nn
from mmcv.cnn import ConvModule
from mmengine.model import BaseModule
from torch import Tensor
from torch.nn.modules.utils import _pair
from mmdet.models.task_modules import SamplingResult
from mmdet.registry import MODELS
from ..task_modules.tracking import embed_similarity
@MODELS.register_module()
class QuasiDenseEmbedHead(BaseModule):
    """The quasi-dense roi embed head.

    Maps RoI features to embedding vectors via an optional stack of 3x3
    convs, an optional average pool, flattening, an optional stack of
    fc + ReLU layers, and a final linear projection (``fc_embed``).

    Args:
        num_convs (int): Number of 3x3 conv layers applied to the RoI
            features. Defaults to 0.
        num_fcs (int): Number of fully connected layers (each followed by
            ReLU) applied after flattening. Defaults to 0.
        roi_feat_size (int | tuple[int]): Spatial size of the input RoI
            features. Defaults to 7.
        in_channels (int): Number of channels of the input RoI features.
            Defaults to 256.
        conv_out_channels (int): Output channels of each conv layer.
            Defaults to 256.
        with_avg_pool (bool): Whether to apply ``nn.AvgPool2d`` with a
            kernel of ``roi_feat_size`` before flattening (i.e. pool to 1x1
            for inputs of that size). Defaults to False.
        fc_out_channels (int): Output channels of each fc layer.
            Defaults to 1024.
        conv_cfg (dict, optional): Conv config passed to ``ConvModule``.
            Defaults to None.
        norm_cfg (dict, optional): Norm config passed to ``ConvModule``.
            Defaults to None.
        embed_channels (int): Number of channels of the output embedding
            produced by ``fc_embed``. Defaults to 256.
        softmax_temp (int): Temperature passed to the dot-product
            similarity in :meth:`match`. Defaults to -1.
        loss_track (dict, optional): The loss function for tracking.
            Defaults to None, which builds ``MultiPosCrossEntropyLoss``
            with ``loss_weight=0.25``.
        loss_track_aux (dict, optional): The auxiliary loss function for
            tracking. Pass None to disable it. Defaults to MarginL2Loss.
        init_cfg (:obj:`ConfigDict` or dict or list[:obj:`ConfigDict` or \
            dict]): Initialization config dict.
    """

    def __init__(self,
                 num_convs: int = 0,
                 num_fcs: int = 0,
                 roi_feat_size: int = 7,
                 in_channels: int = 256,
                 conv_out_channels: int = 256,
                 with_avg_pool: bool = False,
                 fc_out_channels: int = 1024,
                 conv_cfg: Optional[dict] = None,
                 norm_cfg: Optional[dict] = None,
                 embed_channels: int = 256,
                 softmax_temp: int = -1,
                 loss_track: Optional[dict] = None,
                 loss_track_aux: dict = dict(
                     type='MarginL2Loss',
                     sample_ratio=3,
                     margin=0.3,
                     loss_weight=1.0,
                     hard_mining=True),
                 init_cfg: dict = dict(
                     type='Xavier',
                     layer='Linear',
                     distribution='uniform',
                     bias=0,
                     override=dict(
                         type='Normal',
                         name='fc_embed',
                         mean=0,
                         std=0.01,
                         bias=0))):
        super().__init__(init_cfg=init_cfg)
        self.num_convs = num_convs
        self.num_fcs = num_fcs
        self.roi_feat_size = _pair(roi_feat_size)
        self.roi_feat_area = self.roi_feat_size[0] * self.roi_feat_size[1]
        self.in_channels = in_channels
        self.conv_out_channels = conv_out_channels
        self.with_avg_pool = with_avg_pool
        self.fc_out_channels = fc_out_channels
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        if self.with_avg_pool:
            self.avg_pool = nn.AvgPool2d(self.roi_feat_size)
        # add convs and fcs
        self.convs, self.fcs, self.last_layer_dim = self._add_conv_fc_branch(
            self.num_convs, self.num_fcs, self.in_channels)
        # Without any fc layer, ``fc_embed`` consumes the flattened feature
        # map directly, so its input dim must include the spatial area
        # unless the features are average-pooled first.
        if self.num_fcs == 0 and not self.with_avg_pool:
            self.last_layer_dim *= self.roi_feat_area
        self.relu = nn.ReLU(inplace=True)
        if loss_track is None:
            loss_track = dict(
                type='MultiPosCrossEntropyLoss', loss_weight=0.25)
        self.fc_embed = nn.Linear(self.last_layer_dim, embed_channels)
        self.softmax_temp = softmax_temp
        self.loss_track = MODELS.build(loss_track)
        if loss_track_aux is not None:
            self.loss_track_aux = MODELS.build(loss_track_aux)
        else:
            self.loss_track_aux = None

    def _add_conv_fc_branch(
            self, num_branch_convs: int, num_branch_fcs: int,
            in_channels: int) -> Tuple[nn.ModuleList, nn.ModuleList, int]:
        """Add shared or separable branch. convs -> avg pool (optional) -> fcs.

        Args:
            num_branch_convs (int): The number of convolutional layers.
            num_branch_fcs (int): The number of fully connected layers.
            in_channels (int): The input channel of roi features.

        Returns:
            Tuple[nn.ModuleList, nn.ModuleList, int]: The convs, fcs and the
            last layer dimension.
        """
        last_layer_dim = in_channels
        # add branch specific conv layers
        branch_convs = nn.ModuleList()
        if num_branch_convs > 0:
            for i in range(num_branch_convs):
                conv_in_channels = (
                    last_layer_dim if i == 0 else self.conv_out_channels)
                branch_convs.append(
                    ConvModule(
                        conv_in_channels,
                        self.conv_out_channels,
                        3,
                        padding=1,
                        conv_cfg=self.conv_cfg,
                        norm_cfg=self.norm_cfg))
            last_layer_dim = self.conv_out_channels
        # add branch specific fc layers
        branch_fcs = nn.ModuleList()
        if num_branch_fcs > 0:
            # The first fc sees the flattened map; pooling removes the
            # spatial area from its input dimension.
            if not self.with_avg_pool:
                last_layer_dim *= self.roi_feat_area
            for i in range(num_branch_fcs):
                fc_in_channels = (
                    last_layer_dim if i == 0 else self.fc_out_channels)
                branch_fcs.append(
                    nn.Linear(fc_in_channels, self.fc_out_channels))
            last_layer_dim = self.fc_out_channels
        return branch_convs, branch_fcs, last_layer_dim

    def forward(self, x: Tensor) -> Tensor:
        """Forward function.

        Args:
            x (Tensor): The input features from ROI head; expected to have
                spatial size ``roi_feat_size`` (assumed (N, C, H, W) —
                confirm against the RoI extractor).

        Returns:
            Tensor: The embedding features of shape (N, embed_channels).
        """
        if self.num_convs > 0:
            for conv in self.convs:
                x = conv(x)
        # The layer dims were computed assuming pooling when
        # ``with_avg_pool`` is set, so the pool must actually be applied.
        if self.with_avg_pool:
            x = self.avg_pool(x)
        x = x.flatten(1)
        if self.num_fcs > 0:
            for fc in self.fcs:
                x = self.relu(fc(x))
        x = self.fc_embed(x)
        return x

    def get_targets(
            self, gt_match_indices: List[Tensor],
            key_sampling_results: List[SamplingResult],
            ref_sampling_results: List[SamplingResult]) -> Tuple[List, List]:
        """Calculate the track targets and track weights for all samples in a
        batch according to the sampling_results.

        Args:
            gt_match_indices (list(Tensor)): Mapping from gt_instance_ids to
                ref_gt_instance_ids of the same tracklet in a pair of images.
            key_sampling_results (List[obj:SamplingResult]): Assign results of
                all images in a batch after sampling.
            ref_sampling_results (List[obj:SamplingResult]): Assign results of
                all reference images in a batch after sampling.

        Returns:
            Tuple[list[Tensor]]: Association results.
            Containing the following list of Tensors:

                - track_targets (list[Tensor]): The mapping instance ids from
                  all positive proposals in the key image to all proposals
                  in the reference image, each tensor in list has
                  shape (len(key_pos_bboxes), len(ref_bboxes)).
                - track_weights (list[Tensor]): Loss weights for all positive
                  proposals in a batch, each tensor in list has
                  shape (len(key_pos_bboxes),).
        """
        track_targets = []
        track_weights = []
        for _gt_match_indices, key_res, ref_res in zip(gt_match_indices,
                                                       key_sampling_results,
                                                       ref_sampling_results):
            targets = _gt_match_indices.new_zeros(
                (key_res.pos_bboxes.size(0), ref_res.bboxes.size(0)),
                dtype=torch.int)
            # Map each key positive's gt index to its ref-image gt index.
            _match_indices = _gt_match_indices[key_res.pos_assigned_gt_inds]
            # 1 where a key positive and a ref positive share the same gt.
            pos2pos = (_match_indices.view(
                -1, 1) == ref_res.pos_assigned_gt_inds.view(1, -1)).int()
            # NOTE(review): assumes positives occupy the leading columns of
            # ref_res.bboxes (as in mmdet's SamplingResult) — verify.
            targets[:, :pos2pos.size(1)] = pos2pos
            # Only key positives with at least one matching ref get weight.
            weights = (targets.sum(dim=1) > 0).float()
            track_targets.append(targets)
            track_weights.append(weights)
        return track_targets, track_weights

    def match(
        self, key_embeds: Tensor, ref_embeds: Tensor,
        key_sampling_results: List[SamplingResult],
        ref_sampling_results: List[SamplingResult]
    ) -> Tuple[List[Tensor], List[Tensor]]:
        """Calculate the dist matrixes for loss measurement.

        Args:
            key_embeds (Tensor): Embeds of positive bboxes in sampling results
                of key image.
            ref_embeds (Tensor): Embeds of all bboxes in sampling results
                of the reference image.
            key_sampling_results (List[obj:SamplingResults]): Assign results of
                all images in a batch after sampling.
            ref_sampling_results (List[obj:SamplingResults]): Assign results of
                all reference images in a batch after sampling.

        Returns:
            Tuple[list[Tensor]]: Calculation results.
            Containing the following list of Tensors:

                - dists (list[Tensor]): Dot-product dists between
                  key_embeds and ref_embeds, each tensor in list has
                  shape (len(key_pos_bboxes), len(ref_bboxes)).
                - cos_dists (list[Tensor]): Cosine dists between
                  key_embeds and ref_embeds, each tensor in list has
                  shape (len(key_pos_bboxes), len(ref_bboxes)). Entries are
                  None when the auxiliary loss is disabled.
        """
        # Split the batched embeddings back into per-image chunks.
        num_key_rois = [res.pos_bboxes.size(0) for res in key_sampling_results]
        key_embeds = torch.split(key_embeds, num_key_rois)
        num_ref_rois = [res.bboxes.size(0) for res in ref_sampling_results]
        ref_embeds = torch.split(ref_embeds, num_ref_rois)

        dists, cos_dists = [], []
        for key_embed, ref_embed in zip(key_embeds, ref_embeds):
            dist = embed_similarity(
                key_embed,
                ref_embed,
                method='dot_product',
                temperature=self.softmax_temp)
            dists.append(dist)
            if self.loss_track_aux is not None:
                cos_dist = embed_similarity(
                    key_embed, ref_embed, method='cosine')
                cos_dists.append(cos_dist)
            else:
                cos_dists.append(None)
        return dists, cos_dists

    def loss(self, key_roi_feats: Tensor, ref_roi_feats: Tensor,
             key_sampling_results: List[SamplingResult],
             ref_sampling_results: List[SamplingResult],
             gt_match_indices_list: List[Tensor]) -> dict:
        """Calculate the track loss and the auxiliary track loss.

        Args:
            key_roi_feats (Tensor): RoI features of positive bboxes in
                sampling results of key image.
            ref_roi_feats (Tensor): RoI features of all bboxes in sampling
                results of the reference image.
            key_sampling_results (List[obj:SamplingResults]): Assign results of
                all images in a batch after sampling.
            ref_sampling_results (List[obj:SamplingResults]): Assign results of
                all reference images in a batch after sampling.
            gt_match_indices_list (list(Tensor)): Mapping from gt_instances_ids
                to ref_gt_instances_ids of the same tracklet in a pair of
                images.

        Returns:
            Dict [str: Tensor]: Calculation results.
            Containing the following list of Tensors:

                - loss_track (Tensor): Results of loss_track function.
                - loss_track_aux (Tensor): Results of loss_track_aux function.
        """
        key_track_feats = self(key_roi_feats)
        ref_track_feats = self(ref_roi_feats)

        losses = self.loss_by_feat(key_track_feats, ref_track_feats,
                                   key_sampling_results, ref_sampling_results,
                                   gt_match_indices_list)
        return losses

    def loss_by_feat(self, key_track_feats: Tensor, ref_track_feats: Tensor,
                     key_sampling_results: List[SamplingResult],
                     ref_sampling_results: List[SamplingResult],
                     gt_match_indices_list: List[Tensor]) -> dict:
        """Calculate the track loss and the auxiliary track loss.

        Args:
            key_track_feats (Tensor): Embeds of positive bboxes in sampling
                results of key image.
            ref_track_feats (Tensor): Embeds of all bboxes in sampling results
                of the reference image.
            key_sampling_results (List[obj:SamplingResults]): Assign results of
                all images in a batch after sampling.
            ref_sampling_results (List[obj:SamplingResults]): Assign results of
                all reference images in a batch after sampling.
            gt_match_indices_list (list(Tensor)): Mapping from instances_ids
                from key image to reference image of the same tracklet in a
                pair of images.

        Returns:
            Dict [str: Tensor]: Calculation results.
            Containing the following list of Tensors:

                - loss_track (Tensor): Results of loss_track function,
                  averaged over images.
                - loss_track_aux (Tensor): Results of loss_track_aux function,
                  averaged over images (only present when enabled).
        """
        dists, cos_dists = self.match(key_track_feats, ref_track_feats,
                                      key_sampling_results,
                                      ref_sampling_results)
        targets, weights = self.get_targets(gt_match_indices_list,
                                            key_sampling_results,
                                            ref_sampling_results)
        losses = dict()

        loss_track = 0.
        loss_track_aux = 0.
        for _dists, _cos_dists, _targets, _weights in zip(
                dists, cos_dists, targets, weights):
            loss_track += self.loss_track(
                _dists, _targets, _weights, avg_factor=_weights.sum())
            if self.loss_track_aux is not None:
                loss_track_aux += self.loss_track_aux(_cos_dists, _targets)
        losses['loss_track'] = loss_track / len(dists)

        if self.loss_track_aux is not None:
            losses['loss_track_aux'] = loss_track_aux / len(dists)

        return losses

    def predict(self, bbox_feats: Tensor) -> Tensor:
        """Perform forward propagation of the tracking head and predict
        tracking results on the features of the upstream network.

        Args:
            bbox_feats: The extracted roi features.

        Returns:
            Tensor: The extracted track features.
        """
        track_feats = self(bbox_feats)
        return track_feats
|