bilegentile's picture
Upload folder using huggingface_hub
c19ca42 verified
raw
history blame contribute delete
4.96 kB
from __future__ import annotations
from typing import Type
import numpy as np
import torch
from torch import Tensor
from ..image_utils import as_3d
from ..onnx.np_tensor_utils import MAX_VALUES_BY_DTYPE, np_denorm
def bgr_to_rgb(image: Tensor) -> Tensor:
    """Return ``image`` with the order of its channels reversed (BGR -> RGB).

    The third-from-last dimension is flipped, i.e. the channel axis of a
    (C, H, W) or (B, C, H, W) tensor. A new tensor is returned; the input is
    left untouched.

    Background on flipping in PyTorch:
    https://github.com/pytorch/pytorch/issues/229
    """
    # An indexing alternative such as image[[2, 1, 0], :, :] may be faster,
    # but only handles exactly three channels on the first axis.
    return torch.flip(image, dims=(-3,))
def rgb_to_bgr(image: Tensor) -> Tensor:
    """Return ``image`` with the order of its channels reversed (RGB -> BGR).

    Reversing the channel order is its own inverse, so this simply delegates
    to bgr_to_rgb().
    """
    swapped: Tensor = bgr_to_rgb(image)
    return swapped
def bgra_to_rgba(image: Tensor) -> Tensor:
    """Swap the first and third channels of a 4-channel image (BGRA -> RGBA).

    The alpha channel (index 3) stays in place. The permutation is applied to
    the third-from-last dimension, so both (4, H, W) and batched
    (B, 4, H, W) tensors are handled, matching the axis bgr_to_rgb() flips.

    Parameters:
        image (tensor): tensor whose third-from-last dimension holds the
            four channels.
    Returns:
        A new tensor of the same shape with channels ordered [2, 1, 0, 3].
    """
    # The ellipsis anchors the index list on dim -3; indexing the first axis
    # directly would permute the batch axis of a 4D input instead.
    out: Tensor = image[..., [2, 1, 0, 3], :, :]
    return out
def rgba_to_bgra(image: Tensor) -> Tensor:
    """Swap the first and third channels of a 4-channel image (RGBA -> BGRA).

    The swap is its own inverse, so this simply delegates to bgra_to_rgba().
    """
    swapped: Tensor = bgra_to_rgba(image)
    return swapped
def norm(x: Tensor):
    """Linearly rescale values from the [0,1] range to [-1,1].

    Computes ``(x - 0.5) * 2`` and clamps the result to [-1, 1], so inputs
    outside [0, 1] saturate at the nearest bound.
    """
    centered = x - 0.5
    scaled = centered * 2.0
    return torch.clamp(scaled, min=-1, max=1)
def np2tensor(
    img: np.ndarray,
    bgr2rgb=True,
    data_range=1.0,  # pylint: disable=unused-argument
    normalize=False,
    change_range=True,
    add_batch=True,
) -> Tensor:
    """Convert a numpy image array into a float tensor.

    The image is passed through ``as_3d``, transposed from HWC to CHW and
    wrapped in a float tensor.

    Parameters:
        img (numpy array): the input image numpy array
        bgr2rgb (bool): reorder channels with bgr_to_rgb() when the channel
            count is a multiple of 3, or with bgra_to_rgba() when it is 4;
            other channel counts are left as they are
        data_range: not used by this function
        normalize (bool): apply norm() to the tensor before returning it
        change_range (bool): cast to float32 and divide by the dtype's entry
            in ``MAX_VALUES_BY_DTYPE`` (1.0 when the dtype has no entry)
        add_batch (bool): choose if new tensor needs batch dimension added
    Returns:
        The converted tensor, with a leading dimension of size 1 when
        ``add_batch`` is set.
    """
    if change_range:
        # Divide by the maximum value of the source dtype, ie: uint8 = /255
        scale = MAX_VALUES_BY_DTYPE.get(img.dtype.name, 1.0)
        img = img.astype(np.dtype("float32")) / scale

    # "HWC to CHW", then "numpy to tensor"
    chw = np.transpose(as_3d(img), (2, 0, 1))
    tensor = torch.from_numpy(np.ascontiguousarray(chw)).float()

    if bgr2rgb:
        # BGR to RGB -> in tensor, if using OpenCV, else not needed.
        # Only if image has colors.
        channels = tensor.shape[0]
        if channels % 3 == 0:
            # RGB or MultixRGB (3xRGB, 5xRGB, etc. For video tensors.)
            # NOTE(review): bgr_to_rgb() reverses the whole channel axis, so
            # with more than 3 channels the order of the 3-channel groups is
            # reversed as well - confirm this is intended.
            tensor = bgr_to_rgb(tensor)
        elif channels == 4:
            # RGBA
            tensor = bgra_to_rgba(tensor)

    if add_batch:
        # Add fake batch dimension = 1
        tensor = tensor.unsqueeze(0)

    return norm(tensor) if normalize else tensor
def tensor2np(
    img: Tensor,
    rgb2bgr=True,
    remove_batch=True,
    data_range=255,
    denormalize=False,
    change_range=True,
    imtype: Type = np.uint8,
) -> np.ndarray:
    """Converts a Tensor array into a numpy image array.

    Parameters:
        img (tensor): the input image tensor array
            4D(1,C,H,W), 3D(C,H,W), or 2D(H,W), any range, RGB channel order.
            Tensors that require grad are accepted; they are detached first.
        rgb2bgr (bool): reorder channels with rgb_to_bgr() for 3-channel
            images or rgba_to_bgra() for 4-channel images
        remove_batch (bool): choose if tensor of shape BCHW needs to be
            squeezed; only a batch dimension of size 1 can be removed
        data_range: multiplier and upper clipping bound applied when
            ``change_range`` is set
        denormalize (bool): Used to denormalize from [-1,1] range back to
            [0,1] (via ``np_denorm``)
        change_range (bool): multiply by ``data_range``, clip to
            [0, data_range] and round before the final cast
        imtype (type): the desired type of the converted numpy array
            (np.uint8 default)
    Output:
        img (np array): 3D(H,W,C) or 2D(H,W), [0,255], np.uint8 (default)
    Raises:
        TypeError: if the tensor is not 2D, 3D or 4D.
        ValueError: if a 4D tensor is still 4D after the optional batch
            removal, since only a single image can be converted.
    """
    n_dim = img.dim()
    # detach() drops autograd tracking; without it .numpy() raises on
    # tensors that require grad.
    img = img.detach().float().cpu()
    img_np: np.ndarray
    if n_dim in (4, 3):
        # if n_dim == 4, has to convert to 3 dimensions
        if n_dim == 4 and remove_batch:
            # remove a fake batch dimension (no-op when the batch size is > 1)
            img = img.squeeze(dim=0)
        if img.dim() != 3:
            # The CHW -> HWC transpose below needs exactly three axes; fail
            # with an explicit message instead of numpy's axes mismatch error.
            raise ValueError(
                "Expected a single image: 4D tensors need remove_batch=True "
                f"and a batch size of 1. But received shape: {tuple(img.shape)}"
            )
        if img.shape[0] == 3 and rgb2bgr:  # RGB
            # RGB to BGR -> in tensor, if using OpenCV, else not needed.
            img_np = rgb_to_bgr(img).numpy()
        elif img.shape[0] == 4 and rgb2bgr:  # RGBA
            # RGBA to BGRA -> in tensor, if using OpenCV, else not needed.
            img_np = rgba_to_bgra(img).numpy()
        else:
            img_np = img.numpy()
        img_np = np.transpose(img_np, (1, 2, 0))  # CHW to HWC
    elif n_dim == 2:
        img_np = img.numpy()
    else:
        raise TypeError(
            f"Only support 4D, 3D and 2D tensor. But received with dimension: {n_dim:d}"
        )

    if denormalize:
        img_np = np_denorm(img_np)  # denormalize if needed
    if change_range:
        # Values must be inside [0, data_range] before the cast to an integer
        # type such as np.uint8.
        img_np = np.clip(data_range * img_np, 0, data_range).round()
    return img_np.astype(imtype)
def safe_cuda_cache_empty():
    """
    Empties the CUDA cache if CUDA is available, ignoring any error raised.

    This is a best-effort cleanup: exceptions raised while querying or
    emptying the cache are swallowed so the caller is never interrupted by
    it. KeyboardInterrupt and SystemExit are not swallowed.
    """
    try:
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
    except Exception:  # pylint: disable=broad-except
        # Deliberately ignored; a bare ``except:`` would also trap
        # KeyboardInterrupt/SystemExit, which must keep propagating.
        pass