# Source: TTP/mmdet/models/mot/strongsort.py
# (uploaded by KyanChen: "Upload 1861 files", commit 3b96cb1)
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional
import numpy as np
from mmengine.structures import InstanceData
from torch import Tensor
from mmdet.registry import MODELS, TASK_UTILS
from mmdet.structures import TrackSampleList
from mmdet.utils import OptConfigType
from .deep_sort import DeepSORT
@MODELS.register_module()
class StrongSORT(DeepSORT):
    """StrongSORT: Make DeepSORT Great Again.

    Details can be found at `StrongSORT<https://arxiv.org/abs/2202.13514>`_.

    Args:
        detector (dict): Configuration of detector. Defaults to None.
        reid (dict): Configuration of reid. Defaults to None.
        cmc (dict): Configuration of camera model compensation.
            Defaults to None.
        tracker (dict): Configuration of tracker. Defaults to None.
        postprocess_model (dict): Configuration of the video-level
            post-processing model. It is built with ``TASK_UTILS`` and
            applied in :meth:`predict` to the tracklet rows of the whole
            video at once. Defaults to None, in which case the raw tracklet
            rows are returned without post-processing.
        data_preprocessor (dict or ConfigDict, optional): The pre-process
            config of :class:`TrackDataPreprocessor`. It usually includes
            ``pad_size_divisor``, ``pad_value``, ``mean`` and ``std``.
        init_cfg (dict or list[dict]): Configuration of initialization.
            Defaults to None.
    """

    def __init__(self,
                 detector: Optional[dict] = None,
                 reid: Optional[dict] = None,
                 cmc: Optional[dict] = None,
                 tracker: Optional[dict] = None,
                 postprocess_model: Optional[dict] = None,
                 data_preprocessor: OptConfigType = None,
                 init_cfg: OptConfigType = None):
        super().__init__(detector, reid, tracker, data_preprocessor, init_cfg)

        # Optional components are only attached when configured; use
        # ``with_cmc`` / ``with_postprocess_model`` to test for them.
        if cmc is not None:
            self.cmc = TASK_UTILS.build(cmc)

        if postprocess_model is not None:
            self.postprocess_model = TASK_UTILS.build(postprocess_model)

    @property
    def with_cmc(self) -> bool:
        """bool: whether the framework has a camera model compensation
        model."""
        return hasattr(self, 'cmc') and self.cmc is not None

    @property
    def with_postprocess_model(self) -> bool:
        """bool: whether the framework has a video-level post-processing
        model."""
        return getattr(self, 'postprocess_model', None) is not None

    @staticmethod
    def _to_mot_rows(pred_track_instances: InstanceData,
                     frame_index: int) -> np.ndarray:
        """Convert the tracks of one frame into MOT-style rows.

        Each tensor is moved to host memory once, instead of once per
        element, to avoid a device sync for every scalar.

        Args:
            pred_track_instances (InstanceData): Tracker output of one frame.
                It must provide ``instances_id``, ``bboxes`` and ``scores``.
                ``bboxes`` columns 2/3 are treated as right/bottom corners
                (width/height are computed as ``x2 - x1`` / ``y2 - y1``).
            frame_index (int): Frame index written to column 0 (1-based in
                :meth:`predict`).

        Returns:
            np.ndarray: Array of shape (num_tracks, 7) whose rows are
            ``(frame_index, track_id, x1, y1, w, h, score)``.
        """
        ids = pred_track_instances.instances_id.detach().cpu().numpy()
        if len(ids) == 0:
            return np.zeros((0, 7))
        bboxes = pred_track_instances.bboxes.detach().cpu().numpy()
        scores = pred_track_instances.scores.detach().cpu().numpy()
        return np.column_stack([
            np.full(len(ids), frame_index),
            ids,
            bboxes[:, 0],
            bboxes[:, 1],
            bboxes[:, 2] - bboxes[:, 0],  # width
            bboxes[:, 3] - bboxes[:, 1],  # height
            scores,
        ])

    def predict(self,
                inputs: Tensor,
                data_samples: TrackSampleList,
                rescale: bool = True,
                **kwargs) -> TrackSampleList:
        """Predict results from a video and data samples with post-processing.

        Every frame is detected and tracked online. The tracklet rows of the
        whole video are then passed through the post-processing model, if one
        is configured, and split back into per-frame results.

        Args:
            inputs (Tensor): of shape (N, T, C, H, W) encoding
                input images. The N denotes batch size.
                The T denotes the number of key frames
                and reference frames.
            data_samples (list[:obj:`TrackDataSample`]): The batch
                data samples. It usually includes information such
                as `gt_instance`.
            rescale (bool, Optional): If False, then returned bboxes and masks
                will fit the scale of img, otherwise, returned bboxes and masks
                will fit the scale of original image shape. Defaults to True.

        Returns:
            TrackSampleList: List[TrackDataSample]
            Tracking results of the input videos. Each frame of the
            returned TrackDataSample gets a ``pred_track_instances`` whose
            ``bboxes`` field holds the full 7-column rows
            ``(frame_id, track_id, x1, y1, w, h, score)`` of that frame.
        """
        assert inputs.dim() == 5, 'The img must be 5D Tensor (N, T, C, H, W).'
        assert inputs.size(0) == 1, \
            'SORT/DeepSORT inference only support ' \
            '1 batch size per gpu for now.'
        assert len(data_samples) == 1, \
            'SORT/DeepSORT inference only support ' \
            '1 batch size per gpu for now.'

        track_data_sample = data_samples[0]
        video_len = len(track_data_sample)

        # Track frame by frame, collecting the rows of the whole video so
        # the post-processing model can operate on complete tracklets.
        per_frame_rows = []
        for frame_id in range(video_len):
            img_data_sample = track_data_sample[frame_id]
            single_img = inputs[:, frame_id].contiguous()
            # det_results List[DetDataSample]
            det_results = self.detector.predict(single_img, [img_data_sample])
            assert len(det_results) == 1, 'Batch inference is not supported.'

            pred_track_instances = self.tracker.track(
                model=self,
                img=single_img,
                data_sample=det_results[0],
                data_preprocessor=self.preprocess_cfg,
                rescale=rescale,
                **kwargs)
            # Frame ids are written 1-based into column 0.
            per_frame_rows.append(
                self._to_mot_rows(pred_track_instances, frame_id + 1))

        if per_frame_rows:
            video_track_instances = np.concatenate(
                per_frame_rows, axis=0).reshape(-1, 7)
        else:
            video_track_instances = np.zeros((0, 7))

        # Post-processing is optional; without it the raw rows are used.
        if self.with_postprocess_model:
            video_track_instances = self.postprocess_model.forward(
                video_track_instances)

        # Split the video-level rows back into frames by column 0.
        # NOTE: ``bboxes`` keeps the full 7-column rows (not plain boxes);
        # downstream consumers presumably rely on this layout.
        for frame_id in range(video_len):
            track_data_sample[frame_id].pred_track_instances = \
                InstanceData(bboxes=video_track_instances[
                    video_track_instances[:, 0] == frame_id + 1, :])

        return [track_data_sample]