r"""Colour space conversion functions"""
from typing import Union, Dict
import torch


def rgb2lmn(x: torch.Tensor) -> torch.Tensor:
    r"""Convert a batch of RGB images to a batch of LMN images

    Args:
        x: Batch of images with shape (N, 3, H, W). RGB colour space.

    Returns:
        Batch of images with shape (N, 3, H, W). LMN colour space.
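
    Example:
        A minimal shape-check sketch; random values stand in for a real RGB batch:

        >>> x = torch.rand(2, 3, 8, 8)
        >>> rgb2lmn(x).shape
        torch.Size([2, 3, 8, 8])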
"""
weights_rgb_to_lmn = torch.tensor([[0.06, 0.63, 0.27],
[0.30, 0.04, -0.35],
[0.34, -0.6, 0.17]], dtype=x.dtype, device=x.device).t()
x_lmn = torch.matmul(x.permute(0, 2, 3, 1), weights_rgb_to_lmn).permute(0, 3, 1, 2)
return x_lmn


def rgb2xyz(x: torch.Tensor) -> torch.Tensor:
    r"""Convert a batch of RGB images to a batch of XYZ images

    Args:
        x: Batch of images with shape (N, 3, H, W). RGB colour space.

    Returns:
        Batch of images with shape (N, 3, H, W). XYZ colour space.
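
    Example:
        A quick sanity sketch; sRGB white should land on the D65 white point:

        >>> white = torch.ones(1, 3, 1, 1)
        >>> torch.allclose(rgb2xyz(white).flatten(),
        ...                torch.tensor([0.9505, 1.0000, 1.0888]), atol=1e-4)
        True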
"""
mask_below = (x <= 0.04045).type(x.dtype)
mask_above = (x > 0.04045).type(x.dtype)
tmp = x / 12.92 * mask_below + torch.pow((x + 0.055) / 1.055, 2.4) * mask_above
weights_rgb_to_xyz = torch.tensor([[0.4124564, 0.3575761, 0.1804375],
[0.2126729, 0.7151522, 0.0721750],
[0.0193339, 0.1191920, 0.9503041]], dtype=x.dtype, device=x.device)
x_xyz = torch.matmul(tmp.permute(0, 2, 3, 1), weights_rgb_to_xyz.t()).permute(0, 3, 1, 2)
return x_xyz


def xyz2lab(x: torch.Tensor, illuminant: str = 'D50', observer: str = '2') -> torch.Tensor:
    r"""Convert a batch of XYZ images to a batch of LAB images

    Args:
        x: Batch of images with shape (N, 3, H, W). XYZ colour space.
        illuminant: {'A', 'D50', 'D55', 'D65', 'D75', 'E'}, optional. The name of the illuminant.
        observer: {'2', '10'}, optional. The aperture angle of the observer.

    Returns:
        Batch of images with shape (N, 3, H, W). LAB colour space.
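
    Example:
        A minimal sketch; under the equal-energy illuminant 'E', the white point
        (1, 1, 1) maps exactly to L=100, a=0, b=0:

        >>> x = torch.ones(1, 3, 1, 1)
        >>> xyz2lab(x, illuminant='E').flatten().tolist()
        [100.0, 0.0, 0.0]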
"""
epsilon = 0.008856
kappa = 903.3
illuminants: Dict[str, Dict] = \
{"A": {'2': (1.098466069456375, 1, 0.3558228003436005),
'10': (1.111420406956693, 1, 0.3519978321919493)},
"D50": {'2': (0.9642119944211994, 1, 0.8251882845188288),
'10': (0.9672062750333777, 1, 0.8142801513128616)},
"D55": {'2': (0.956797052643698, 1, 0.9214805860173273),
'10': (0.9579665682254781, 1, 0.9092525159847462)},
"D65": {'2': (0.95047, 1., 1.08883), # This was: `lab_ref_white`
'10': (0.94809667673716, 1, 1.0730513595166162)},
"D75": {'2': (0.9497220898840717, 1, 1.226393520724154),
'10': (0.9441713925645873, 1, 1.2064272211720228)},
"E": {'2': (1.0, 1.0, 1.0),
'10': (1.0, 1.0, 1.0)}}
illuminants_to_use = torch.tensor(illuminants[illuminant][observer],
dtype=x.dtype, device=x.device).view(1, 3, 1, 1)
tmp = x / illuminants_to_use
mask_below = (tmp <= epsilon).type(x.dtype)
mask_above = (tmp > epsilon).type(x.dtype)
tmp = torch.pow(tmp, 1. / 3.) * mask_above + (kappa * tmp + 16.) / 116. * mask_below
weights_xyz_to_lab = torch.tensor([[0, 116., 0],
[500., -500., 0],
[0, 200., -200.]], dtype=x.dtype, device=x.device)
bias_xyz_to_lab = torch.tensor([-16., 0., 0.], dtype=x.dtype, device=x.device).view(1, 3, 1, 1)
x_lab = torch.matmul(tmp.permute(0, 2, 3, 1), weights_xyz_to_lab.t()).permute(0, 3, 1, 2) + bias_xyz_to_lab
return x_lab


def rgb2lab(x: torch.Tensor, data_range: Union[int, float] = 255) -> torch.Tensor:
    r"""Convert a batch of RGB images to a batch of LAB images

    Args:
        x: Batch of images with shape (N, 3, H, W). RGB colour space.
        data_range: Dynamic range of the input image, usually 1 or 255.

    Returns:
        Batch of images with shape (N, 3, H, W). LAB colour space.
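
    Example:
        A small sanity sketch; white input should give L close to 100 for any of the
        listed illuminants (a and b are not exactly zero here, because xyz2lab
        defaults to 'D50' while the RGB-to-XYZ matrix is D65-adapted):

        >>> white = torch.ones(1, 3, 1, 1) * 255.
        >>> torch.allclose(rgb2lab(white)[:, 0], torch.tensor(100.), atol=1e-2)
        True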
"""
return xyz2lab(rgb2xyz(x / float(data_range)))


def rgb2yiq(x: torch.Tensor) -> torch.Tensor:
    r"""Convert a batch of RGB images to a batch of YIQ images

    Args:
        x: Batch of images with shape (N, 3, H, W). RGB colour space.

    Returns:
        Batch of images with shape (N, 3, H, W). YIQ colour space.
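
    Example:
        A minimal sketch; the Y (luma) channel of pure white should be 1 and the
        chroma channels I and Q should be near 0:

        >>> white = torch.ones(1, 3, 1, 1)
        >>> torch.allclose(rgb2yiq(white),
        ...                torch.tensor([1., 0., 0.]).view(1, 3, 1, 1), atol=1e-5)
        True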
"""
yiq_weights = torch.tensor([
[0.299, 0.587, 0.114],
[0.5959, -0.2746, -0.3213],
[0.2115, -0.5227, 0.3112]], dtype=x.dtype, device=x.device).t()
x_yiq = torch.matmul(x.permute(0, 2, 3, 1), yiq_weights).permute(0, 3, 1, 2)
return x_yiq


def rgb2lhm(x: torch.Tensor) -> torch.Tensor:
    r"""Convert a batch of RGB images to a batch of LHM images

    Args:
        x: Batch of images with shape (N, 3, H, W). RGB colour space.

    Returns:
        Batch of images with shape (N, 3, H, W). LHM colour space.

    Reference:
        https://arxiv.org/pdf/1608.07433.pdf
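
    Example:
        A minimal shape-check sketch; random values stand in for a real RGB batch:

        >>> x = torch.rand(2, 3, 8, 8)
        >>> rgb2lhm(x).shape
        torch.Size([2, 3, 8, 8])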
"""
lhm_weights = torch.tensor([
[0.2989, 0.587, 0.114],
[0.3, 0.04, -0.35],
[0.34, -0.6, 0.17]], dtype=x.dtype, device=x.device).t()
x_lhm = torch.matmul(x.permute(0, 2, 3, 1), lhm_weights).permute(0, 3, 1, 2)
return x_lhm