bilegentile's picture
Upload folder using huggingface_hub
c19ca42 verified
raw
history blame contribute delete
5.4 kB
import warnings
import numpy as np
def sanity_check_image(image):
    """Warn if an image does not look like a normalized RGB image.

    Three expectations are checked: the shape is `(?, ?, 3)`, all values lie
    in the range [0, 1], and the `dtype` is `np.float32` or `np.float64`.
    Each violated expectation is reported with `warnings.warn`; the image
    itself is not modified and nothing is returned.

    Parameters
    ----------
    image: numpy.ndarray
        Image with shape :math:`h \\times w \\times 3`

    Example
    -------
    >>> import numpy as np
    >>> from pymatting.util.util import sanity_check_image
    >>> image = (np.random.randn(64, 64, 2) * 255).astype(np.int32)
    >>> sanity_check_image(image)
    __main__:1: UserWarning: Expected RGB image of shape (?, ?, 3), but image.shape is (64, 64, 2).
    __main__:1: UserWarning: Image values should be in [0, 1], but image.min() is -933.
    __main__:1: UserWarning: Image values should be in [0, 1], but image.max() is 999.
    __main__:1: UserWarning: Unexpected image.dtype int32. Are you sure that you do not want to use np.float32 or np.float64 instead?
    """
    # NOTE: every warning uses stacklevel=3, i.e. it is attributed to the
    # caller of the function that called this check, not to this function.
    shape = image.shape
    looks_like_rgb = len(shape) == 3 and shape[2] == 3
    if not looks_like_rgb:
        message = (
            "Expected RGB image of shape (?, ?, 3), but image.shape is %s."
            % str(shape)
        )
        warnings.warn(message, stacklevel=3)

    lowest = image.min()
    highest = image.max()

    if lowest < 0.0:
        message = "Image values should be in [0, 1], but image.min() is %s." % lowest
        warnings.warn(message, stacklevel=3)

    if highest > 1.0:
        message = "Image values should be in [0, 1], but image.max() is %s." % highest
        warnings.warn(message, stacklevel=3)

    accepted_dtypes = (np.float32, np.float64)
    if image.dtype not in accepted_dtypes:
        message = (
            "Unexpected image.dtype %s. Are you sure that you do not want to use np.float32 or np.float64 instead?"
            % image.dtype
        )
        warnings.warn(message, stacklevel=3)
def stack_images(*images):
    """Concatenate images along the third axis.

    Inputs whose shape already has three dimensions are used as they are;
    any other input gets a new trailing axis via ``image[:, :, np.newaxis]``
    before concatenation. This is useful for combining e.g. rgb color
    channels or color and alpha channels.

    Parameters
    ----------
    *images: numpy.ndarray
        Images to be stacked.

    Returns
    -------
    image: numpy.ndarray
        Stacked images as numpy.ndarray

    Example
    -------
    >>> from pymatting.util.util import stack_images
    >>> import numpy as np
    >>> I = stack_images(np.random.rand(4,5,3), np.random.rand(4,5,3))
    >>> I.shape
    (4, 5, 6)
    """
    layers = []
    for image in images:
        if len(image.shape) != 3:
            # Give single-channel inputs an explicit channel axis.
            image = image[:, :, np.newaxis]
        layers.append(image)

    return np.concatenate(layers, axis=2)
def trimap_split(trimap, flatten=True, bg_threshold=0.1, fg_threshold=0.9):
    """This function splits the trimap into foreground pixels, background pixels, and unknown pixels.

    Foreground pixels are pixels where the trimap has values larger than or equal to `fg_threshold` (default: 0.9).
    Background pixels are pixels where the trimap has values smaller than or equal to `bg_threshold` (default: 0.1).
    Pixels with other values are assumed to be unknown.

    Parameters
    ----------
    trimap: numpy.ndarray
        Trimap with shape :math:`h \\times w`
    flatten: bool
        If true, `flatten()` is called on the trimap before splitting
    bg_threshold: float
        Pixels with trimap values smaller than or equal to this value will be considered background.
    fg_threshold: float
        Pixels with trimap values larger than or equal to this value will be considered foreground.

    Returns
    -------
    is_fg: numpy.ndarray
        Boolean array indicating which pixel belongs to the foreground
    is_bg: numpy.ndarray
        Boolean array indicating which pixel belongs to the background
    is_known: numpy.ndarray
        Boolean array indicating which pixel is known
    is_unknown: numpy.ndarray
        Boolean array indicating which pixel is unknown

    Raises
    ------
    ValueError
        If the trimap contains no background values or no foreground values.

    Warns
    -----
    UserWarning
        If trimap values fall outside [0, 1] or the dtype is neither
        `np.float32` nor `np.float64`.

    Example
    -------
    >>> import numpy as np
    >>> from pymatting.util.util import trimap_split
    >>> trimap = np.array([[1,0],[0.5,0.2]])
    >>> is_fg, is_bg, is_known, is_unknown = trimap_split(trimap)
    >>> is_fg
    array([ True, False, False, False])
    >>> is_bg
    array([False,  True, False, False])
    >>> is_known
    array([ True,  True, False, False])
    >>> is_unknown
    array([False, False,  True,  True])
    """
    if flatten:
        trimap = trimap.flatten()

    min_value = trimap.min()
    max_value = trimap.max()

    # The warnings use stacklevel=3 so they are attributed to the caller of
    # the function that invoked trimap_split rather than to this module.
    if min_value < 0.0:
        warnings.warn(
            "Trimap values should be in [0, 1], but trimap.min() is %s." % min_value,
            stacklevel=3,
        )

    if max_value > 1.0:
        # Report the offending maximum (previously the minimum was printed).
        warnings.warn(
            "Trimap values should be in [0, 1], but trimap.max() is %s." % max_value,
            stacklevel=3,
        )

    if trimap.dtype not in [np.float32, np.float64]:
        warnings.warn(
            "Unexpected trimap.dtype %s. Are you sure that you do not want to use np.float32 or np.float64 instead?"
            % trimap.dtype,
            stacklevel=3,
        )

    is_fg = trimap >= fg_threshold
    is_bg = trimap <= bg_threshold

    if not is_bg.any():
        raise ValueError(
            "Trimap did not contain background values (values <= %f)" % bg_threshold
        )

    if not is_fg.any():
        raise ValueError(
            "Trimap did not contain foreground values (values >= %f)" % fg_threshold
        )

    is_known = is_fg | is_bg
    is_unknown = ~is_known

    return is_fg, is_bg, is_known, is_unknown