File size: 7,632 Bytes
c19ca42
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
from __future__ import annotations
import math
from typing import Callable, Optional, Union
import numpy as np
from rich.progress import Progress, TextColumn, BarColumn, TaskProgressColumn, TimeRemainingColumn, TimeElapsedColumn
from nodes.log import logger, console
from ...utils.utils import Region, Size, get_h_w_c
from .exact_split import exact_split
from .tiler import Tiler


class Split:
    """
    Marker value an upscale operation returns instead of an image.

    Returning an instance of this class signals that the given tile could not
    be processed at its current size and that a smaller tile size should be tried.
    """


# Signature of the per-tile operation handed to the splitting functions.
# It is called with the pixel data of a tile and the Region describing that tile,
# and returns either the processed tile or a `Split` instance to request that
# the tile size be reduced and the work retried.
SplitImageOp = Callable[[np.ndarray, Region], Union[np.ndarray, Split]]


def auto_split(
    img: np.ndarray,
    upscale: SplitImageOp,
    tiler: Tiler,
    overlap: int = 16,
) -> np.ndarray:
    """
    Upscales the image tile by tile, with the tiling controlled by the given tiler.

    Only the spatial size of the image is changed by tiling; every tile passed into
    the upscale function has the same number of channels as the given image.

    The region passed into the upscale function is the region of the current tile,
    and its size is guaranteed to equal the size of the tile passed alongside it.

    ## Padding

    If the given tiler allows smaller tile sizes, then it is guaranteed that no padding will be added.
    Otherwise, no padding is only guaranteed if the starting tile size is not larger than the size of the given image.
    """

    height, width, channels = get_h_w_c(img)

    # Tilers that tolerate smaller tiles get the adaptive strategy,
    # all others need tiles of one exact size.
    if tiler.allow_smaller_tile_size():
        strategy = _max_split
    else:
        strategy = _exact_split

    initial_tile_size = tiler.starting_tile_size(width, height, channels)

    return strategy(
        img,
        upscale=upscale,
        starting_tile_size=initial_tile_size,
        split_tile_size=tiler.split,
        overlap=overlap,
    )


class _SplitEx(Exception):
    """Internal signal raised when an upscale operation asks for a split."""


def _exact_split(
    img: np.ndarray,
    upscale: SplitImageOp,
    starting_tile_size: Size,
    split_tile_size: Callable[[Size], Size],
    overlap: int,
) -> np.ndarray:
    """
    Upscales the image using tiles of one exact size.

    The whole image is attempted with `starting_tile_size`. If the upscale
    function requests a split for any tile, the attempt is abandoned, the tile
    size is replaced by `split_tile_size(tile_size)`, and the whole image is
    attempted again. At most 20 attempts are made.

    The overlap handed to `exact_split` is capped at a quarter of the smaller
    tile dimension.

    Raises a `ValueError` if every attempt ended in a split request.
    """
    h, w, c = get_h_w_c(img)
    logger.info(f"chaiNNer: exact size split image ({w}x{h}px @ {c}) with exact tile size {starting_tile_size[0]}x{starting_tile_size[1]}px.")

    def upscale_or_raise(tile_img: np.ndarray, tile_region: Region) -> np.ndarray:
        # Turn a split request into an exception so that it unwinds out of exact_split.
        out = upscale(tile_img, tile_region)
        if not isinstance(out, Split):
            return out
        raise _SplitEx

    MAX_ITER = 20
    attempts = 0
    while attempts < MAX_ITER:
        attempts += 1

        overlap_cap = min(*starting_tile_size) // 4
        effective_overlap = min(overlap_cap, overlap)

        try:
            return exact_split(
                img=img,
                exact_size=starting_tile_size,
                upscale=upscale_or_raise,
                overlap=effective_overlap,
            )
        except _SplitEx:
            # a tile was too large, so retry the whole image with the next tile size
            starting_tile_size = split_tile_size(starting_tile_size)

    raise ValueError(f"Aborting after {MAX_ITER} splits. Unable to upscale image.")


def _max_split(
    img: np.ndarray,
    upscale: SplitImageOp,
    starting_tile_size: Size,
    split_tile_size: Callable[[Size], Size],
    overlap: int,
) -> np.ndarray:
    """
    Splits the image into tiles with at most the given tile size.

    If the upscale method requests a split, then the tile size will be lowered.

    Parameters:
    - `img`: the image to upscale.
    - `upscale`: called with each (padded) tile and its region. It returns the
      upscaled tile, or a `Split` instance to request a smaller tile size.
    - `starting_tile_size`: the initial maximum tile size, indexed as
      `[0]` = width and `[1]` = height.
    - `split_tile_size`: maps the current maximum tile size to the next one
      to try after a split request.
    - `overlap`: upper bound for the padding added around each tile.

    Returns the upscaled image. When the image had to be tiled, this is a newly
    allocated float32 array of size `(h * scale, w * scale, c)`.
    """

    h, w, c = get_h_w_c(img)

    img_region = Region(0, 0, w, h)

    max_tile_size = starting_tile_size
    # logger.debug(f"chaiNNer: auto split image ({w}x{h}px @ {c}) with initial tile size {max_tile_size}")

    if w <= max_tile_size[0] and h <= max_tile_size[1]:
        # the image might be small enough so that we don't have to split at all
        upscale_result = upscale(img, img_region)
        if not isinstance(upscale_result, Split):
            return upscale_result
        # the image was too large
        max_tile_size = split_tile_size(max_tile_size)
        logger.info(f"chaiNNer: unable to upscale the whole image at once. Reduced tile size to {max_tile_size}")

    # The upscale method is allowed to request splits at any time.
    # When a split occurs, we have to "restart" the loop and
    # these 2 variables allow us to skip the already processed tiles.
    # They are tile indices in the *current* tile grid.
    start_x = 0
    start_y = 0

    # To allocate the result image, we need to know the upscale factor first,
    # and we only get to know this factor after the first successful upscale.
    result: Optional[np.ndarray] = None
    scale: int = 0

    restart = True
    while restart:
        restart = False

        # This is a bit complex.
        # We don't actually use the current tile size to partition the image.
        # If we did, then tile_size=1024 and w=1200 would result in very uneven tiles.
        # Instead, we use tile_size to calculate how many tiles we get in the x and y direction
        # and then calculate the optimal tile size for the x and y direction using the counts.
        # This yields optimal tile sizes which should prevent unnecessary splitting.
        tile_count_x = math.ceil(w / max_tile_size[0])
        tile_count_y = math.ceil(h / max_tile_size[1])
        tile_size_x = math.ceil(w / tile_count_x)
        tile_size_y = math.ceil(h / tile_count_y)
        # logger.debug(f"chaiNNer: Currently {tile_count_x}x{tile_count_y} tiles each {tile_size_x}x{tile_size_y}px.")
        with Progress(TextColumn('[cyan]{task.description}'), BarColumn(), TaskProgressColumn(), TimeRemainingColumn(), TimeElapsedColumn(), console=console) as progress:
            # Tiles skipped after a restart count as already completed,
            # so that the bar can still reach 100%.
            task = progress.add_task(
                description="Upscaling",
                total=tile_count_y * tile_count_x,
                completed=start_y * tile_count_x + start_x,
            )
            for y in range(start_y, tile_count_y):
                # only the first row after a restart is resumed part-way through
                first_x = start_x if y == start_y else 0
                for x in range(first_x, tile_count_x):
                    tile = Region(x * tile_size_x, y * tile_size_y, tile_size_x, tile_size_y).intersect(img_region)
                    pad = img_region.child_padding(tile).min(overlap)
                    padded_tile = tile.add_padding(pad)
                    upscale_result = upscale(padded_tile.read_from(img), padded_tile)
                    if isinstance(upscale_result, Split):
                        max_tile_size = split_tile_size(max_tile_size)
                        new_tile_count_x = math.ceil(w / max_tile_size[0])
                        new_tile_count_y = math.ceil(h / max_tile_size[1])
                        new_tile_size_x = math.ceil(w / new_tile_count_x)
                        new_tile_size_y = math.ceil(h / new_tile_count_y)
                        # Translate the pixel position of the failed tile into indices of the new grid.
                        # Rounding down means that we may redo some work, but never skip unprocessed pixels.
                        # Note that the row index has to be derived from the tile *height*.
                        start_x = (x * tile_size_x) // new_tile_size_x
                        start_y = (y * tile_size_y) // new_tile_size_y
                        logger.debug(f"chaiNNer: Split occurred. New tile size is {max_tile_size}. Starting at {start_x},{start_y}.")
                        restart = True
                        break
                    # figure out by how much the image was upscaled by
                    up_h, up_w, _ = get_h_w_c(upscale_result)
                    current_scale = up_h // padded_tile.height
                    assert current_scale > 0
                    assert padded_tile.height * current_scale == up_h
                    assert padded_tile.width * current_scale == up_w
                    if result is None:
                        # allocate the result image
                        scale = current_scale
                        result = np.zeros((h * scale, w * scale, c), dtype=np.float32)
                    assert current_scale == scale
                    # remove overlap padding
                    upscale_result = pad.scale(scale).remove_from(upscale_result)
                    # copy into result image
                    tile.scale(scale).write_into(result, upscale_result)
                    progress.update(task, advance=1, description="Upscaling")
                if restart:
                    # a split was requested, so leave the row loop and start over with the new grid
                    break

    assert result is not None
    return result