# Copyright (c) OpenMMLab. All rights reserved. from typing import Optional import numpy as np from mmengine.structures import InstanceData from torch import Tensor from mmdet.registry import MODELS, TASK_UTILS from mmdet.structures import TrackSampleList from mmdet.utils import OptConfigType from .deep_sort import DeepSORT @MODELS.register_module() class StrongSORT(DeepSORT): """StrongSORT: Make DeepSORT Great Again. Details can be found at `StrongSORT`_. Args: detector (dict): Configuration of detector. Defaults to None. reid (dict): Configuration of reid. Defaults to None tracker (dict): Configuration of tracker. Defaults to None. kalman (dict): Configuration of Kalman filter. Defaults to None. cmc (dict): Configuration of camera model compensation. Defaults to None. data_preprocessor (dict or ConfigDict, optional): The pre-process config of :class:`TrackDataPreprocessor`. it usually includes, ``pad_size_divisor``, ``pad_value``, ``mean`` and ``std``. init_cfg (dict or list[dict]): Configuration of initialization. Defaults to None. """ def __init__(self, detector: Optional[dict] = None, reid: Optional[dict] = None, cmc: Optional[dict] = None, tracker: Optional[dict] = None, postprocess_model: Optional[dict] = None, data_preprocessor: OptConfigType = None, init_cfg: OptConfigType = None): super().__init__(detector, reid, tracker, data_preprocessor, init_cfg) if cmc is not None: self.cmc = TASK_UTILS.build(cmc) if postprocess_model is not None: self.postprocess_model = TASK_UTILS.build(postprocess_model) @property def with_cmc(self): """bool: whether the framework has a camera model compensation model. """ return hasattr(self, 'cmc') and self.cmc is not None def predict(self, inputs: Tensor, data_samples: TrackSampleList, rescale: bool = True, **kwargs) -> TrackSampleList: """Predict results from a video and data samples with post- processing. Args: inputs (Tensor): of shape (N, T, C, H, W) encoding input images. The N denotes batch size. The T denotes the number of key frames and reference frames. data_samples (list[:obj:`TrackDataSample`]): The batch data samples. It usually includes information such as `gt_instance`. rescale (bool, Optional): If False, then returned bboxes and masks will fit the scale of img, otherwise, returned bboxes and masks will fit the scale of original image shape. Defaults to True. Returns: TrackSampleList: List[TrackDataSample] Tracking results of the input videos. Each DetDataSample usually contains ``pred_track_instances``. """ assert inputs.dim() == 5, 'The img must be 5D Tensor (N, T, C, H, W).' assert inputs.size(0) == 1, \ 'SORT/DeepSORT inference only support ' \ '1 batch size per gpu for now.' assert len(data_samples) == 1, \ 'SORT/DeepSORT inference only support ' \ '1 batch size per gpu for now.' track_data_sample = data_samples[0] video_len = len(track_data_sample) video_track_instances = [] for frame_id in range(video_len): img_data_sample = track_data_sample[frame_id] single_img = inputs[:, frame_id].contiguous() # det_results List[DetDataSample] det_results = self.detector.predict(single_img, [img_data_sample]) assert len(det_results) == 1, 'Batch inference is not supported.' pred_track_instances = self.tracker.track( model=self, img=single_img, data_sample=det_results[0], data_preprocessor=self.preprocess_cfg, rescale=rescale, **kwargs) for i in range(len(pred_track_instances.instances_id)): video_track_instances.append( np.array([ frame_id + 1, pred_track_instances.instances_id[i].cpu(), pred_track_instances.bboxes[i][0].cpu(), pred_track_instances.bboxes[i][1].cpu(), (pred_track_instances.bboxes[i][2] - pred_track_instances.bboxes[i][0]).cpu(), (pred_track_instances.bboxes[i][3] - pred_track_instances.bboxes[i][1]).cpu(), pred_track_instances.scores[i].cpu() ])) video_track_instances = np.array(video_track_instances).reshape(-1, 7) video_track_instances = self.postprocess_model.forward( video_track_instances) for frame_id in range(video_len): track_data_sample[frame_id].pred_track_instances = \ InstanceData(bboxes=video_track_instances[ video_track_instances[:, 0] == frame_id + 1, :]) return [track_data_sample]