File size: 5,396 Bytes
c19ca42
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
import warnings

import numpy as np


def sanity_check_image(image):
    """Emit warnings for images that do not look like normalized RGB data.

    Four independent checks are performed, in this order, and each failing
    check emits its own warning via ``warnings.warn``:

    1. the shape is not three-dimensional with a last axis of length 3,
    2. the smallest value is below 0,
    3. the largest value is above 1,
    4. the ``dtype`` is neither ``np.float32`` nor ``np.float64``.

    The function only warns; it returns ``None`` and does not modify `image`.
    All warnings are issued with ``stacklevel=3``, i.e. they are attributed
    to the caller of the function that called ``sanity_check_image``.

    Parameters
    ----------
    image: numpy.ndarray
        Image with shape :math:`h \\times w \\times 3`

    Example
    -------
    >>> import warnings
    >>> import numpy as np
    >>> image = np.array([[[-2, 5]]], dtype=np.int32)
    >>> with warnings.catch_warnings(record=True) as caught:
    ...     warnings.simplefilter("always")
    ...     sanity_check_image(image)
    >>> for warning in caught:
    ...     print(warning.message)
    Expected RGB image of shape (?, ?, 3), but image.shape is (1, 1, 2).
    Image values should be in [0, 1], but image.min() is -2.
    Image values should be in [0, 1], but image.max() is 5.
    Unexpected image.dtype int32. Are you sure that you do not want to use np.float32 or np.float64 instead?
    """
    shape = image.shape
    looks_like_rgb = len(shape) == 3 and shape[2] == 3

    if not looks_like_rgb:
        message = "Expected RGB image of shape (?, ?, 3), but image.shape is %s." % str(
            shape
        )
        warnings.warn(message, stacklevel=3)

    # Both extrema are computed up front, before any range warning is issued.
    lowest = image.min()
    highest = image.max()

    if lowest < 0.0:
        message = "Image values should be in [0, 1], but image.min() is %s." % lowest
        warnings.warn(message, stacklevel=3)

    if highest > 1.0:
        message = "Image values should be in [0, 1], but image.max() is %s." % highest
        warnings.warn(message, stacklevel=3)

    accepted_dtypes = (np.float32, np.float64)
    if image.dtype not in accepted_dtypes:
        message = (
            "Unexpected image.dtype %s. Are you sure that you do not want to use np.float32 or np.float64 instead?"
            % image.dtype
        )
        warnings.warn(message, stacklevel=3)


def stack_images(*images):
    """Concatenate images along the third axis (the channel axis).

    Any input that is not three-dimensional is indexed with
    ``[:, :, np.newaxis]`` first, so a two-dimensional :math:`h \\times w`
    image contributes exactly one channel. This is handy for combining
    e.g. separate color channels, or a color image with an alpha channel.

    Parameters
    ----------
    *images: numpy.ndarray
        Images to be stacked.

    Returns
    -------
    image: numpy.ndarray
        Stacked images as numpy.ndarray

    Example
    -------
    >>> from pymatting.util.util import stack_images
    >>> import numpy as np
    >>> I = stack_images(np.random.rand(4,5,3), np.random.rand(4,5,3))
    >>> I.shape
    (4, 5, 6)
    """
    with_channel_axis = []
    for layer in images:
        if len(layer.shape) != 3:
            # Append a trailing channel axis so the layer can be concatenated.
            layer = layer[:, :, np.newaxis]
        with_channel_axis.append(layer)

    return np.concatenate(with_channel_axis, axis=2)


def trimap_split(trimap, flatten=True, bg_threshold=0.1, fg_threshold=0.9):
    """Split a trimap into foreground, background, known and unknown masks.

    Foreground pixels are pixels where the trimap has values larger than or
    equal to `fg_threshold` (default: 0.9). Background pixels are pixels where
    the trimap has values smaller than or equal to `bg_threshold`
    (default: 0.1). Pixels with other values are assumed to be unknown.

    Values outside of [0, 1] and dtypes other than ``np.float32`` /
    ``np.float64`` only trigger warnings (issued with ``stacklevel=3``);
    the split is still performed.

    Parameters
    ----------
    trimap: numpy.ndarray
        Trimap with shape :math:`h \\times w`
    flatten: bool
        If true, ``trimap.flatten()`` is called first, so all returned masks
        are one-dimensional. Otherwise the masks keep the shape of `trimap`.
    bg_threshold: float
        Pixels with trimap values smaller than or equal to this value are
        considered background.
    fg_threshold: float
        Pixels with trimap values larger than or equal to this value are
        considered foreground.

    Returns
    -------
    is_fg: numpy.ndarray
        Boolean array indicating which pixel belongs to the foreground
    is_bg: numpy.ndarray
        Boolean array indicating which pixel belongs to the background
    is_known: numpy.ndarray
        Boolean array indicating which pixel is known
    is_unknown: numpy.ndarray
        Boolean array indicating which pixel is unknown

    Raises
    ------
    ValueError
        If the trimap contains no background pixels or no foreground pixels.

    Example
    -------
    >>> import numpy as np
    >>> from pymatting import trimap_split
    >>> trimap = np.array([[1,0],[0.5,0.2]])
    >>> is_fg, is_bg, is_known, is_unknown = trimap_split(trimap)
    >>> is_fg
    array([ True, False, False, False])
    >>> is_bg
    array([False,  True, False, False])
    >>> is_known
    array([ True,  True, False, False])
    >>> is_unknown
    array([False, False,  True,  True])
    """
    if flatten:
        trimap = trimap.flatten()

    min_value = trimap.min()
    max_value = trimap.max()

    if min_value < 0.0:
        warnings.warn(
            "Trimap values should be in [0, 1], but trimap.min() is %s." % min_value,
            stacklevel=3,
        )

    if max_value > 1.0:
        # Report the offending maximum (previously the minimum was printed here).
        warnings.warn(
            "Trimap values should be in [0, 1], but trimap.max() is %s." % max_value,
            stacklevel=3,
        )

    if trimap.dtype not in [np.float32, np.float64]:
        warnings.warn(
            "Unexpected trimap.dtype %s. Are you sure that you do not want to use np.float32 or np.float64 instead?"
            % trimap.dtype,
            stacklevel=3,
        )

    is_fg = trimap >= fg_threshold
    is_bg = trimap <= bg_threshold

    # Both regions must be present; check background first, then foreground.
    if not is_bg.any():
        raise ValueError(
            "Trimap did not contain background values (values <= %f)" % bg_threshold
        )

    if not is_fg.any():
        raise ValueError(
            "Trimap did not contain foreground values (values >= %f)" % fg_threshold
        )

    is_known = is_fg | is_bg
    is_unknown = ~is_known

    return is_fg, is_bg, is_known, is_unknown