File size: 3,306 Bytes
319886d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
from .blur import *
from .brightness import *
from .quantization import *
from .compression import *
from .contrast import *
from .noise import *
from .oversharpen import *
from .pixelate import *
from .saturate import *


def add_distortion(img, severity=1, distortion_name=None):
    """Return a distorted version of the given image.

    @param img (np.ndarray, uint8): Input image, either H x W (grayscale),
        H x W x 1, or H x W x 3 (RGB), values in [0, 255]. Single-channel
        inputs are replicated to three channels before distorting.
    @param severity (int): Severity of the distortion, one of 1, 2, 3, 4, 5.
    @param distortion_name (str): Name of a distortion function resolved from
        this module's global namespace; it is called as
        ``fn(img, severity)``. The registered names are listed in
        ``distortions_dict``.
    @return: Degraded image (np.ndarray, uint8), H x W x 3, RGB, [0, 255].
        The distortion output is clipped to [0, 255] before the uint8 cast,
        so out-of-range values saturate instead of wrapping around.
    @raise AttributeError: If ``img`` is not a uint8 ndarray, has an
        unsupported shape or channel count, is smaller than 32 pixels in
        height or width, or if ``severity`` is not in [1, 5].
    @raise ValueError: If ``distortion_name`` is not given.
    @raise KeyError: If ``distortion_name`` is not defined in this module.
    """
    if not isinstance(img, np.ndarray):
        raise AttributeError('Expecting type(img) to be numpy.ndarray')
    if img.dtype.type is not np.uint8:
        raise AttributeError('Expecting img.dtype.type to be numpy.uint8')

    if img.ndim not in (2, 3):
        raise AttributeError('Expecting img.shape to be either (h x w) or (h x w x c)')
    if img.ndim == 2:
        # Promote grayscale to three identical channels so every distortion
        # receives an H x W x 3 array.
        img = np.stack((img,) * 3, axis=-1)

    h, w, c = img.shape
    if h < 32 or w < 32:
        raise AttributeError('The (w, h) must be at least 32 pixels')
    if c not in (1, 3):
        raise AttributeError('Expecting img to have either 1 or 3 channels')
    if c == 1:
        # H x W x 1 -> H x W x 3 with every channel equal to the input.
        img = np.repeat(img, 3, axis=-1)

    if severity not in [1, 2, 3, 4, 5]:
        raise AttributeError('The severity must be an integer in [1, 5]')

    if not distortion_name:
        raise ValueError("The distortion_name must be passed")

    # Distortion functions are looked up by name in module globals
    # (presumably populated by the star imports at the top of the module).
    distortion_fn = globals()[distortion_name]
    img_lq = distortion_fn(img, severity)

    # Saturate rather than wrap: a bare uint8 cast turns e.g. 256 into 0,
    # and casting out-of-range floats to uint8 is undefined behaviour.
    return np.clip(np.asarray(img_lq), 0, 255).astype(np.uint8)


# Registry of distortion categories. Each key names a category and maps to the
# list of distortion function names in it; ``add_distortion`` resolves each
# name through this module's globals and calls it as ``fn(img, severity)``.
distortions_dict = {
    "blur": [
        "blur_gaussian",
        "blur_motion",
        "blur_glass",
        "blur_lens",
        "blur_zoom",
        "blur_jitter",
    ],
    "noise": [
        "noise_gaussian_RGB",
        "noise_gaussian_YCrCb",
        "noise_speckle",
        "noise_spatially_correlated",
        "noise_poisson",
        "noise_impulse",
    ],
    "compression": [
        "compression_jpeg",
        "compression_jpeg_2000",
    ],
    # NOTE(review): "shfit" is presumably the spelling used by the function
    # definitions themselves; rename both sides together if it is corrected.
    "brighten": [
        "brightness_brighten_shfit_HSV",
        "brightness_brighten_shfit_RGB",
        "brightness_brighten_gamma_HSV",
        "brightness_brighten_gamma_RGB",
    ],
    "darken": [
        "brightness_darken_shfit_HSV",
        "brightness_darken_shfit_RGB",
        "brightness_darken_gamma_HSV",
        "brightness_darken_gamma_RGB",
    ],
    "contrast_strengthen": [
        "contrast_strengthen_scale",
        "contrast_strengthen_stretch",
    ],
    "contrast_weaken": [
        "contrast_weaken_scale",
        "contrast_weaken_stretch",
    ],
    "saturate_strengthen": [
        "saturate_strengthen_HSV",
        "saturate_strengthen_YCrCb",
    ],
    "saturate_weaken": [
        "saturate_weaken_HSV",
        "saturate_weaken_YCrCb",
    ],
    "oversharpen": [
        "oversharpen",
    ],
    "pixelate": [
        "pixelate",
    ],
    "quantization": [
        "quantization_otsu",
        "quantization_median",
        "quantization_hist",
    ],
    # NOTE(review): no ``.spatter`` module is star-imported in this file;
    # confirm ``spatter`` is defined by one of the imported modules.
    "spatter": [
        "spatter",
    ],
}


def get_distortion_names(subset=None, verbose=True):
    """Print and return the registered distortion function names.

    @param subset (str, optional): A category key of ``distortions_dict``
        (e.g. "blur"). When it is None or not a known category, the whole
        registry is used instead.
    @param verbose (bool): When True (the default, matching the previous
        behaviour), print the result.
    @return: The list of function names for a known ``subset``; otherwise a
        dict mapping every category to its list of names. Copies are
        returned so callers cannot mutate the module-level registry.
    """
    if subset in distortions_dict:
        names = list(distortions_dict[subset])
    else:
        names = {
            category: list(members)
            for category, members in distortions_dict.items()
        }
    if verbose:
        print(names)
    return names