bilegentile's picture
Upload folder using huggingface_hub
c19ca42 verified
raw
history blame contribute delete
7.63 kB
from __future__ import annotations
import math
from typing import Callable, Optional, Union
import numpy as np
from rich.progress import Progress, TextColumn, BarColumn, TaskProgressColumn, TimeRemainingColumn, TimeElapsedColumn
from nodes.log import logger, console
from ...utils.utils import Region, Size, get_h_w_c
from .exact_split import exact_split
from .tiler import Tiler
class Split:
    """
    Marker value an upscale operation (see ``SplitImageOp``) returns in place
    of an image to ask the caller to retry the current tile with a smaller
    tile size.
    """
# Per-tile operation used by the split helpers in this module: it is called with
# the tile's pixels and the tile's Region within the source image, and returns
# either the processed (upscaled) tile or a ``Split`` instance to request that
# the caller retry with smaller tiles.
SplitImageOp = Callable[[np.ndarray, Region], Union[np.ndarray, Split]]
def auto_split(
    img: np.ndarray,
    upscale: SplitImageOp,
    tiler: Tiler,
    overlap: int = 16,
) -> np.ndarray:
    """
    Upscale ``img`` tile by tile, choosing tile sizes with ``tiler``.

    The upscale operation may change only the spatial size of the image: every
    tile handed to it keeps the channel count of ``img``. Alongside each tile it
    receives that tile's ``Region`` within ``img``; the region's size always
    matches the size of the tile passed with it.

    When ``tiler.allow_smaller_tile_size()`` is true, tiles of *at most* the
    tiler's size are used; otherwise tiles of *exactly* that size are used.

    ## Padding
    If the tiler allows smaller tile sizes, no padding is ever added.
    Otherwise, the absence of padding is only guaranteed when the starting
    tile size does not exceed the size of the image.
    """
    height, width, channels = get_h_w_c(img)

    if tiler.allow_smaller_tile_size():
        strategy = _max_split
    else:
        strategy = _exact_split

    initial_tile_size = tiler.starting_tile_size(width, height, channels)
    return strategy(
        img,
        upscale=upscale,
        starting_tile_size=initial_tile_size,
        split_tile_size=tiler.split,
        overlap=overlap,
    )
class _SplitEx(Exception):
    """Internal signal: the upscale operation returned ``Split`` during an exact-size pass."""
def _exact_split(
    img: np.ndarray,
    upscale: SplitImageOp,
    starting_tile_size: Size,
    split_tile_size: Callable[[Size], Size],
    overlap: int,
) -> np.ndarray:
    """
    Upscale ``img`` using tiles of exactly ``starting_tile_size`` via ``exact_split``.

    If ``upscale`` returns ``Split`` for any tile, the whole pass is abandoned and
    the full image is processed again with the tile size produced by
    ``split_tile_size``. The overlap between tiles is capped at a quarter of the
    smaller tile dimension.

    Raises:
        ValueError: if no pass succeeds within ``MAX_ITER`` (20) attempts.
    """
    h, w, c = get_h_w_c(img)
    logger.info(f"chaiNNer: exact size split image ({w}x{h}px @ {c}) with exact tile size {starting_tile_size[0]}x{starting_tile_size[1]}px.")

    def upscale_or_abort(tile_img: np.ndarray, tile_region: Region) -> np.ndarray:
        upscaled = upscale(tile_img, tile_region)
        if isinstance(upscaled, Split):
            # Unwind out of exact_split; handled by the retry loop below.
            raise _SplitEx
        return upscaled

    MAX_ITER = 20
    tile_size = starting_tile_size
    for _attempt in range(MAX_ITER):
        # Never let the overlap exceed a quarter of the smaller tile side.
        clamped_overlap = min(min(*tile_size) // 4, overlap)
        try:
            return exact_split(
                img=img,
                exact_size=tile_size,
                upscale=upscale_or_abort,
                overlap=clamped_overlap,
            )
        except _SplitEx:
            tile_size = split_tile_size(tile_size)

    raise ValueError(f"Aborting after {MAX_ITER} splits. Unable to upscale image.")
def _compute_tile_grid(w: int, h: int, max_tile_size: Size) -> tuple[int, int, int, int]:
    """
    Lay out an evenly sized tile grid over a ``w`` x ``h`` image.

    ``max_tile_size`` only decides how many tiles are needed along each axis;
    the actual tile size is then the smallest size that covers the image with
    that many tiles. This avoids very uneven tiles: e.g. ``w=1200`` with a max
    tile width of 1024 gives 2 tiles of 600px instead of 1024px + 176px.

    Returns ``(tile_count_x, tile_count_y, tile_size_x, tile_size_y)``.
    """
    tile_count_x = math.ceil(w / max_tile_size[0])
    tile_count_y = math.ceil(h / max_tile_size[1])
    tile_size_x = math.ceil(w / tile_count_x)
    tile_size_y = math.ceil(h / tile_count_y)
    return tile_count_x, tile_count_y, tile_size_x, tile_size_y


def _max_split(
    img: np.ndarray,
    upscale: SplitImageOp,
    starting_tile_size: Size,
    split_tile_size: Callable[[Size], Size],
    overlap: int,
) -> np.ndarray:
    """
    Splits the image into tiles with at most the given tile size.

    If the upscale method requests a split (by returning ``Split``), the maximum
    tile size is lowered via ``split_tile_size`` and tiling resumes on the new,
    finer grid near the tile that failed; tiles already written into the result
    are not processed again.

    Each tile is read with up to ``overlap`` pixels of extra context on every
    side that lies inside the image; that context is cropped off again (after
    scaling) before the tile is written into the result.

    Returns:
        If the whole image fits in ``starting_tile_size`` and is upscaled in one
        call, ``upscale``'s result as-is. Otherwise a new float32 array of shape
        ``(h * scale, w * scale, c)``, where ``scale`` is the integer factor
        inferred from the first successfully upscaled tile.
    """
    h, w, c = get_h_w_c(img)
    img_region = Region(0, 0, w, h)
    max_tile_size = starting_tile_size

    if w <= max_tile_size[0] and h <= max_tile_size[1]:
        # The image might be small enough that we don't have to split at all.
        upscale_result = upscale(img, img_region)
        if not isinstance(upscale_result, Split):
            return upscale_result
        # The image was too large.
        max_tile_size = split_tile_size(max_tile_size)
        logger.info(f"chaiNNer: unable to upscale the whole image at once. Reduced tile size to {max_tile_size}")

    # The upscale method may request a split at any time. When that happens we
    # "restart" with a smaller tile size; these two variables (tile indices on
    # the current grid) let us resume without redoing tiles already in `result`.
    start_x = 0
    start_y = 0

    # The result image can only be allocated once the upscale factor is known,
    # i.e. after the first successful upscale.
    result: Optional[np.ndarray] = None
    scale: int = 0

    restart = True
    while restart:
        restart = False
        tile_count_x, tile_count_y, tile_size_x, tile_size_y = _compute_tile_grid(w, h, max_tile_size)
        # Tiles before (start_x, start_y) in row-major order are skipped on this
        # pass; count them as done so the progress bar can reach 100%.
        skipped_tiles = start_y * tile_count_x + start_x

        with Progress(TextColumn('[cyan]{task.description}'), BarColumn(), TaskProgressColumn(), TimeRemainingColumn(), TimeElapsedColumn(), console=console) as progress:
            task = progress.add_task(description="Upscaling", total=tile_count_y * tile_count_x, completed=skipped_tiles)
            for y in range(start_y, tile_count_y):
                first_x = start_x if y == start_y else 0
                for x in range(first_x, tile_count_x):
                    tile = Region(x * tile_size_x, y * tile_size_y, tile_size_x, tile_size_y).intersect(img_region)
                    pad = img_region.child_padding(tile).min(overlap)
                    padded_tile = tile.add_padding(pad)
                    upscale_result = upscale(padded_tile.read_from(img), padded_tile)

                    if isinstance(upscale_result, Split):
                        max_tile_size = split_tile_size(max_tile_size)
                        _, _, new_tile_size_x, new_tile_size_y = _compute_tile_grid(w, h, max_tile_size)
                        # Resume at the new-grid tile containing the failed tile's
                        # top-left pixel. Each axis must be mapped with its OWN old
                        # tile size; mixing them up can skip unprocessed rows.
                        # Assuming split_tile_size never enlarges tiles, every tile
                        # skipped on the new grid lies in already-written area.
                        start_x = (x * tile_size_x) // new_tile_size_x
                        start_y = (y * tile_size_y) // new_tile_size_y
                        logger.debug(f"chaiNNer: Split occurred. New tile size is {max_tile_size}. Starting at {start_x},{start_y}.")
                        restart = True
                        break

                    # Figure out by how much the tile was upscaled.
                    up_h, up_w, _ = get_h_w_c(upscale_result)
                    current_scale = up_h // padded_tile.height
                    assert current_scale > 0
                    assert padded_tile.height * current_scale == up_h
                    assert padded_tile.width * current_scale == up_w

                    if result is None:
                        # Allocate the result image on the first successful tile.
                        scale = current_scale
                        result = np.zeros((h * scale, w * scale, c), dtype=np.float32)

                    assert current_scale == scale

                    # Remove the (scaled) overlap padding, then copy into the result.
                    upscale_result = pad.scale(scale).remove_from(upscale_result)
                    tile.scale(scale).write_into(result, upscale_result)
                    progress.update(task, advance=1, description="Upscaling")

                if restart:
                    break

    assert result is not None
    return result