|
import random |
|
from typing import Optional, Tuple, Union |
|
|
|
import torch |
|
import numpy as np |
|
import mmcv |
|
from mmcv.transforms import to_tensor |
|
from mmcv.transforms.base import BaseTransform |
|
from mmengine.structures import InstanceData, PixelData |
|
from mmdet.registry import TRANSFORMS |
|
from mmdet.structures import DetDataSample |
|
from mmdet.structures.bbox import BaseBoxes |
|
import mmengine.fileio as fileio |
|
import cv2 |
|
import matplotlib.pyplot as plt |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@TRANSFORMS.register_module()
class LoadSIRSTImageFromFiles(BaseTransform):
    """Load one image from ``results['img_path']`` and min-max normalize it.

    The image is read with :func:`matplotlib.pyplot.imread`, forced to three
    channels, min-max normalized over the whole array (all channels jointly)
    to roughly ``[0, 1]``, and optionally rescaled by a random factor as an
    intensity augmentation.

    Required Keys:

    - img_path

    Added Keys:

    - img
    - img_shape
    - ori_shape

    Args:
        to_float32 (bool): Whether to cast the loaded image to float32 before
            normalization. Note that the min-max normalization always yields
            a floating-point array, so the output is never uint8 even when
            this is False. Defaults to False.
        normalized_basis: Accepted for config compatibility; it is stored but
            not used by this transform. Defaults to None.
        range_change (bool): If True, multiply the normalized image by a
            random factor drawn uniformly from ``[0.7, 1.0)``.
            Defaults to False.
    """

    def __init__(
        self,
        to_float32: bool = False,
        normalized_basis=None,
        range_change: bool = False,
    ) -> None:
        self.to_float32 = to_float32
        # NOTE(review): stored but never read by this class -- confirm no
        # external code relies on it before removing the argument.
        self.normalized_basis = normalized_basis
        self.range_change = range_change

    def transform(self, results: dict) -> dict:
        """Load the image at ``results['img_path']`` and add it to ``results``.

        Args:
            results (dict): Result dict from :obj:`mmdet.CustomDataset`; must
                contain ``'img_path'``.

        Returns:
            dict: ``results`` with ``'img'`` (an ``(H, W, 3)`` floating-point
            array), ``'img_shape'`` and ``'ori_shape'`` (both ``(H, W)``)
            added.
        """
        img = plt.imread(results['img_path'])

        # Grayscale input, either (H, W) or (H, W, 1): replicate to 3 channels
        # so downstream 3-channel pipelines receive a consistent layout.
        if img.ndim == 2:
            img = img[..., np.newaxis]
        if img.shape[-1] == 1:
            img = np.repeat(img, 3, axis=-1)
        elif img.shape[-1] != 3:
            # Drop extra channels, e.g. the alpha channel of an RGBA image.
            img = img[:, :, :3]

        if self.to_float32:
            img = img.astype(np.float32)

        # Global min-max normalization. The epsilon prevents division by zero
        # on constant images, which then map to all zeros.
        img_min = np.min(img)
        img = (img - img_min) / (np.max(img) - img_min + 1e-8)

        if self.range_change:
            # Random intensity scaling by a factor in [0.7, 1.0).
            img = (0.7 + random.random() * 0.3) * img

        results['img'] = img
        results['img_shape'] = img.shape[:2]
        results['ori_shape'] = img.shape[:2]
        return results

    def __repr__(self) -> str:
        return (f'{self.__class__.__name__}('
                f'to_float32={self.to_float32}, '
                f'normalized_basis={self.normalized_basis}, '
                f'range_change={self.range_change})')
|
|
|
|
|
|
|
|
|
|
|
|