Spaces:
Runtime error
Runtime error
from typing import Callable, NewType | |
import numpy as np | |
from nodes.log import logger | |
from ...utils.utils import get_h_w_c | |
from .tiler import MaxTileSize, NoTiling, Tiler | |
def estimate_tile_size( | |
budget: int, | |
model_size: int, | |
img: np.ndarray, | |
img_element_size: int = 4, | |
) -> int: | |
h, w, c = get_h_w_c(img) | |
img_bytes = h * w * c * img_element_size | |
mem_required_estimation = (model_size / (1024 * 52)) * img_bytes | |
tile_pixels = w * h * budget / mem_required_estimation | |
# the largest power-of-2 tile_size such that tile_size**2 < tile_pixels | |
tile_size = 2 ** (int(tile_pixels**0.5).bit_length() - 1) | |
GB_AMT = 1024**3 | |
required_mem = f"{mem_required_estimation/GB_AMT:.2f}" | |
budget_mem = f"{budget/GB_AMT:.2f}" | |
logger.info(f"chaiNNer: estimating memory required: {required_mem} GB, {budget_mem} GB free Estimated tile size: {tile_size}") | |
return tile_size | |
TileSize = NewType("TileSize", int) | |
ESTIMATE = TileSize(0) | |
NO_TILING = TileSize(-1) | |
MAX_TILE_SIZE = TileSize(-2) | |
TILE_SIZE_256 = TileSize(256) | |
def parse_tile_size_input(tile_size: TileSize, estimate: Callable[[], Tiler]) -> Tiler: | |
if tile_size == 0: | |
return estimate() | |
if tile_size == -1: | |
return NoTiling() | |
if tile_size == -2: | |
return MaxTileSize() | |
assert tile_size > 0 | |
return MaxTileSize(tile_size) | |