File size: 3,226 Bytes
c19ca42
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
from typing import Tuple

import numpy as np

from ...utils.utils import get_h_w_c
from ..image_op import ImageOp, clipped
from ..image_utils import as_target_channels


def with_black_and_white_backgrounds(img: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
    """
    Composite a 4-channel image onto a black (0) and a white (1) background.

    Args:
        img: A 4-channel image. Channels 0-2 are the color channels and
            channel 3 is used as the opacity.

    Returns:
        A ``(black, white)`` tuple of new 3-channel arrays:
        ``black = rgb * alpha`` and ``white = (rgb - 1) * alpha + 1``.
    """
    c = get_h_w_c(img)[2]
    assert c == 4, f"Expected a 4-channel image, got {c} channel(s)"

    rgb = img[:, :, :3]
    # Keep a trailing length-1 axis so alpha broadcasts over all 3 color channels.
    alpha = img[:, :, 3:4]

    # Both expressions allocate fresh arrays, so `img` is never modified.
    black = rgb * alpha
    white = (rgb - 1) * alpha + 1

    return black, white


def denoise_and_flatten_alpha(img: np.ndarray) -> np.ndarray:
    """
    Collapse a multi-channel alpha estimate into a single alpha channel.

    For every pixel the result blends the smallest and largest channel value,
    weighted by the channel mean: ``max * mean + min * (1 - mean)``. A high
    mean pulls the result toward the maximum, a low mean toward the minimum.

    Args:
        img: An ``(h, w, c)`` array of per-channel alpha estimates. A 2D
            ``(h, w)`` array is treated as already flattened.

    Returns:
        An ``(h, w)`` alpha array.
    """
    if img.ndim == 2:
        # With a single channel min == max == mean, so the blend below
        # mathematically reduces to the input. Return a copy instead of
        # failing on the missing axis 2.
        return img.copy()

    alpha_min = np.min(img, axis=2)
    alpha_max = np.max(img, axis=2)
    alpha_mean = np.mean(img, axis=2)
    return alpha_max * alpha_mean + alpha_min * (1 - alpha_mean)


def convenient_upscale(
    img: np.ndarray,
    model_in_nc: int,
    model_out_nc: int,
    upscale: ImageOp,
    separate_alpha: bool = False,
) -> np.ndarray:
    """
    Upscales the given image in an intuitive/convenient way.

    This method guarantees that the `upscale` function will be called with an image with
    `model_in_nc` number of channels.

    Additionally, guarantees that the number of channels of the output image will match
    that of the input image in cases where `model_in_nc` == `model_out_nc`, and match
    `model_out_nc` otherwise.

    Args:
        img: The image to upscale.
        model_in_nc: Number of channels the model expects as input.
        model_out_nc: Number of channels the model produces.
        upscale: The upscaling operation; it is wrapped with `clipped` before use.
        separate_alpha: Only consulted for 4-channel input with a non-constant alpha
            channel when the model does not take 4 channels. If True, RGB and alpha
            are upscaled in separate passes. Otherwise alpha is reconstructed from
            the difference between upscaled black- and white-background composites.

    Returns:
        The upscaled image.
    """
    in_img_c = get_h_w_c(img)[2]

    upscale = clipped(upscale)

    def upscale_as_rgb(image: np.ndarray) -> np.ndarray:
        # Adapt `image` to the model's input channel count, upscale it, and
        # convert the result to 3 channels.
        return as_target_channels(
            upscale(as_target_channels(image, model_in_nc, True)), 3, True
        )

    if model_in_nc != model_out_nc:
        return upscale(as_target_channels(img, model_in_nc, True))

    if in_img_c == model_in_nc:
        return upscale(img)

    if in_img_c == 4:
        alpha_in = img[:, :, 3]

        # A constant alpha carries no detail: upscale only the RGB channels and
        # refill alpha with that constant. min/max is a linear scan, whereas
        # np.unique would sort the whole channel. The size guard keeps empty
        # images on the non-constant path, as before.
        if alpha_in.size > 0:
            alpha_value = alpha_in.min()
            if alpha_value == alpha_in.max():
                rgb = upscale_as_rgb(img[:, :, :3])
                constant_alpha = np.full(rgb.shape[:-1], alpha_value, np.float32)
                return np.dstack((rgb, constant_alpha))

        if separate_alpha:
            # Upscale the RGB channels and alpha channel separately.
            rgb = upscale_as_rgb(img[:, :, :3])
            # Normalizing the alpha result to 3 channels (as the branch below does)
            # ensures `denoise_and_flatten_alpha` always gets an (h, w, c) array,
            # even when the model returns a 2D single-channel image.
            alpha = denoise_and_flatten_alpha(upscale_as_rgb(alpha_in))
            return np.dstack((rgb, alpha))

        # Transparency hack: composite onto black and white backgrounds, upscale
        # both, and recover alpha from how much the two results differ.
        black, white = with_black_and_white_backgrounds(img)
        black_up = upscale_as_rgb(black)
        white_up = upscale_as_rgb(white)

        # Each channel yields an alpha estimate; combine them into one channel.
        alpha_candidates = 1 - (white_up - black_up)  # type: ignore
        alpha = denoise_and_flatten_alpha(alpha_candidates)

        return np.dstack((black_up, alpha))

    return as_target_channels(
        upscale(as_target_channels(img, model_in_nc, True)), in_img_c, True
    )