bilegentile's picture
Upload folder using huggingface_hub
c19ca42 verified
raw
history blame contribute delete
5.87 kB
from __future__ import annotations
import math
from dataclasses import dataclass
from typing import Callable, List, Optional, Tuple
import numpy as np
from nodes.log import logger
from ...utils.utils import Padding, Region, Size, get_h_w_c
from ..image_utils import BorderType, create_border
def _pad_image(img: np.ndarray, min_size: Size):
h, w, _ = get_h_w_c(img)
min_w, min_h = min_size
x = max(0, min_w - w) / 2
y = max(0, min_h - h) / 2
padding = Padding(math.floor(y), math.floor(x), math.ceil(y), math.ceil(x))
return create_border(img, BorderType.REFLECT_MIRROR, padding), padding
@dataclass
class _Segment:
start: int
end: int
startPadding: int
endPadding: int
@property
def length(self) -> int:
return self.end - self.start
@property
def padded_length(self) -> int:
return self.end + self.endPadding - (self.start - self.startPadding)
def _exact_split_into_segments(length: int, exact: int, overlap: int) -> List[_Segment]:
"""
Splits the given length into segments of `exact` (padded) length.
Segments will overlap into each other with at least the given overlap.
"""
if length == exact:
# trivial
return [_Segment(0, exact, 0, 0)]
assert length > exact
assert exact > overlap * 2
result: List[_Segment] = []
def add(s: _Segment):
assert s.padded_length == exact
result.append(s)
# The current strategy is to go from left to right and to align segments
# such that we use the least overlap possible. The last segment will then
# be the smallest with potentially a lot of overlap.
# While this is easy to implement, it's actually not ideal. Ideally, we
# would want for the overlap to be distributed evenly between segments.
# However, this is complex to implement and the current method also works.
# we know that the first segment looks like this
add(_Segment(0, exact - overlap, 0, overlap))
while result[-1].end < length:
startPadding = overlap
start = result[-1].end
end = start + exact - overlap * 2
endPadding = overlap
if end + endPadding >= length:
# last segment
endPadding = 0
end = length
startPadding = exact - (end - start)
add(_Segment(start, end, startPadding, endPadding))
return result
def _exact_split_into_regions(
w: int,
h: int,
exact_w: int,
exact_h: int,
overlap: int,
) -> List[Tuple[Region, Padding]]:
"""
Returns a list of disjoint regions along with padding.
Each region plus its padding is guaranteed to have the given exact size.
The padding (if not zero) is guaranteed to be at least the given overlap value.
"""
# we can split x and y independently from each other and then combine the results
x_segments = _exact_split_into_segments(w, exact_w, overlap)
y_segments = _exact_split_into_segments(h, exact_h, overlap)
logger.info(f"chaiNNer: image is split into {len(x_segments)}x{len(y_segments)} tiles each exactly {exact_w}x{exact_h}px")
result: List[Tuple[Region, Padding]] = []
for y in y_segments:
for x in x_segments:
result.append(
(
Region(x.start, y.start, x.length, y.length),
Padding(y.startPadding, x.endPadding, y.endPadding, x.startPadding),
)
)
return result
def _exact_split_without_padding(
img: np.ndarray,
exact_size: Size,
upscale: Callable[[np.ndarray, Region], np.ndarray],
overlap: int,
) -> np.ndarray:
h, w, c = get_h_w_c(img)
exact_w, exact_h = exact_size
assert w >= exact_w and h >= exact_h
if (w, h) == exact_size:
return upscale(img, Region(0, 0, w, h))
# To allocate the result image, we need to know the upscale factor first,
# and we only get to know this factor after the first successful upscale.
result: Optional[np.ndarray] = None
scale: int = 0
regions = _exact_split_into_regions(w, h, exact_w, exact_h, overlap)
for tile, pad in regions:
padded_tile = tile.add_padding(pad)
upscale_result = upscale(padded_tile.read_from(img), padded_tile)
# figure out by how much the image was upscaled by
up_h, up_w, _ = get_h_w_c(upscale_result)
current_scale = up_h // padded_tile.height
assert current_scale > 0
assert padded_tile.height * current_scale == up_h
assert padded_tile.width * current_scale == up_w
if result is None:
# allocate the result image
scale = current_scale
result = np.zeros((h * scale, w * scale, c), dtype=np.float32)
assert current_scale == scale
# remove overlap padding
upscale_result = pad.scale(scale).remove_from(upscale_result)
# copy into result image
tile.scale(scale).write_into(result, upscale_result)
assert result is not None
# remove initially added padding
return result
def exact_split(
img: np.ndarray,
exact_size: Size,
upscale: Callable[[np.ndarray, Region], np.ndarray],
overlap: int = 16,
) -> np.ndarray:
"""
Splits the image into tiles with exactly the given tile size.
If the image is smaller than the given size, then it will be padded.
"""
# ensure that the image is at least as large as the given size
img, base_padding = _pad_image(img, exact_size)
h, w, _ = get_h_w_c(img)
result = _exact_split_without_padding(img, exact_size, upscale, overlap)
scale = get_h_w_c(result)[0] // h
if base_padding.empty:
return result
# remove initially added padding
return (
Region(0, 0, w, h).remove_padding(base_padding).scale(scale).read_from(result)
)