from __future__ import annotations

from typing import Type

import numpy as np
import torch
from torch import Tensor

from ..image_utils import as_3d
from ..onnx.np_tensor_utils import MAX_VALUES_BY_DTYPE, np_denorm


def bgr_to_rgb(image: Tensor) -> Tensor:
    """Reverse the channel order of a (..., C, H, W) tensor.

    The channel axis is taken to be the third-from-last dimension, so both
    unbatched (C, H, W) and batched (B, C, H, W) tensors are handled.

    Parameters:
        image (tensor): tensor with at least 3 dimensions

    Returns:
        A new tensor with dimension -3 reversed (BGR -> RGB for 3 channels).
    """
    # Negative strides are not supported by torch, hence flip() rather than
    # slicing: https://github.com/pytorch/pytorch/issues/229
    # An index-based alternative for exactly 3 channels would be
    # image[..., [2, 1, 0], :, :].
    out: Tensor = image.flip(-3)
    return out


def rgb_to_bgr(image: Tensor) -> Tensor:
    """Reverse the channel order; the same operation as bgr_to_rgb()."""
    return bgr_to_rgb(image)


def bgra_to_rgba(image: Tensor) -> Tensor:
    """Swap the first and third channel of a 4-channel tensor, keeping alpha.

    Parameters:
        image (tensor): tensor of shape (..., 4, H, W)

    Returns:
        A new tensor whose channels are reordered as [2, 1, 0, 3].
    """
    # Index the channel axis relative to the end (like bgr_to_rgb does with
    # flip(-3)) so that a leading batch dimension is left untouched.
    out: Tensor = image[..., [2, 1, 0, 3], :, :]
    return out


def rgba_to_bgra(image: Tensor) -> Tensor:
    """Swap the first and third channel; same operation as bgra_to_rgba()."""
    return bgra_to_rgba(image)


def norm(x: Tensor):
    """Normalize (z-norm) from [0,1] range to [-1,1]"""
    out = (x - 0.5) * 2.0
    return out.clamp(-1, 1)


def np2tensor(
    img: np.ndarray,
    bgr2rgb=True,
    data_range=1.0,  # pylint: disable=unused-argument
    normalize=False,
    change_range=True,
    add_batch=True,
) -> Tensor:
    """Converts a numpy image array into a Tensor array.

    Parameters:
        img (numpy array): the input image numpy array
        bgr2rgb (bool): reorder channels when the channel count is a multiple
            of 3 (full channel flip) or exactly 4 (alpha kept in place)
        data_range: unused, kept for interface compatibility
        normalize (bool): map the result through norm(), i.e. [0,1] -> [-1,1]
        change_range (bool): divide by the maximum value registered for the
            input dtype in MAX_VALUES_BY_DTYPE (1.0 if the dtype is unknown)
        add_batch (bool): choose if new tensor needs batch dimension added

    Returns:
        A float32 tensor in CHW layout, or 1CHW when add_batch is set.
    """
    if change_range:
        dtype = img.dtype
        maxval = MAX_VALUES_BY_DTYPE.get(dtype.name, 1.0)
        t_dtype = np.dtype("float32")
        img = img.astype(t_dtype) / maxval  # ie: uint8 = /255

    # "HWC to CHW" and "numpy to tensor"
    tensor = torch.from_numpy(
        np.ascontiguousarray(np.transpose(as_3d(img), (2, 0, 1)))
    ).float()

    if bgr2rgb:
        # BGR to RGB is done on the tensor rather than on the numpy array.
        # Only needed when the image comes from OpenCV and has colors.
        if tensor.shape[0] % 3 == 0:
            # RGB or MultixRGB (3xRGB, 5xRGB, etc. For video tensors.)
            # NOTE(review): for more than 3 channels the flip reverses the
            # whole channel axis, which also reverses the order of the
            # stacked frames - confirm that this is intended.
            tensor = bgr_to_rgb(tensor)
        elif tensor.shape[0] == 4:
            # RGBA
            tensor = bgra_to_rgba(tensor)

    if add_batch:
        # Add fake batch dimension = 1 . squeeze() will remove the dimensions of size 1
        tensor.unsqueeze_(0)

    if normalize:
        tensor = norm(tensor)

    return tensor


def tensor2np(
    img: Tensor,
    rgb2bgr=True,
    remove_batch=True,
    data_range=255,
    denormalize=False,
    change_range=True,
    imtype: Type = np.uint8,
) -> np.ndarray:
    """Converts a Tensor array into a numpy image array.

    Parameters:
        img (tensor): the input image tensor array
            4D(B,(3/1),H,W), 3D(C,H,W), or 2D(H,W), any range, RGB channel order
        rgb2bgr (bool): reorder channels for 3 (RGB) or 4 (RGBA) channel images
        remove_batch (bool): choose if tensor of shape BCHW needs to be squeezed
        data_range: value the image is scaled to (and clipped at) when
            change_range is set
        denormalize (bool): Used to denormalize from [-1,1] range back to [0,1]
        change_range (bool): scale by data_range, clip to [0, data_range] and
            round before the final cast
        imtype (type): the desired type of the converted numpy array (np.uint8
            default)

    Output:
        img (np array): 3D(H,W,C) or 2D(H,W), [0,255], np.uint8 (default).
            A 4D input whose batch dimension is kept (remove_batch=False or
            batch size > 1) is returned as 4D(B,H,W,C).

    Raises:
        TypeError: if the tensor is not 2D, 3D or 4D.
    """
    n_dim = img.dim()

    # detach() first: numpy() refuses tensors that are part of an autograd graph
    img = img.detach().float().cpu()

    img_np: np.ndarray

    if n_dim in (4, 3):
        if n_dim == 4 and remove_batch:
            # remove a fake batch dimension (no-op unless the batch size is 1)
            img = img.squeeze(dim=0)

        # Look the channel count up relative to the end so that a batch
        # dimension that is still present is not mistaken for the channels.
        channels = img.shape[-3]

        # RGB(A) to BGR(A) -> in tensor, if using OpenCV, else not needed.
        # Only if image has colors.
        if channels == 3 and rgb2bgr:
            img_np = rgb_to_bgr(img).numpy()
        elif channels == 4 and rgb2bgr:
            img_np = rgba_to_bgra(img).numpy()
        else:
            img_np = img.numpy()

        # CHW to HWC (or BCHW to BHWC when the batch dimension was kept)
        img_np = np.moveaxis(img_np, -3, -1)
    elif n_dim == 2:
        img_np = img.numpy()
    else:
        raise TypeError(
            f"Only support 4D, 3D and 2D tensor. But received with dimension: {n_dim:d}"
        )

    if denormalize:
        img_np = np_denorm(img_np)  # denormalize if needed

    if change_range:
        # has to be in range (0,255) before changing to np.uint8, else np.float32
        img_np = np.clip(data_range * img_np, 0, data_range).round()

    return img_np.astype(imtype)


def safe_cuda_cache_empty():
    """
    Empties the CUDA cache if CUDA is available.

    Hopefully without causing any errors.
    """
    try:
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
    except Exception:  # pylint: disable=broad-except
        # Deliberately best-effort: a failed cache flush must never take the
        # caller down. Unlike a bare ``except``, this still lets
        # KeyboardInterrupt and SystemExit propagate.
        pass