File size: 1,359 Bytes
c19ca42
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
from typing import Callable, NewType

import numpy as np
from nodes.log import logger

from ...utils.utils import get_h_w_c
from .tiler import MaxTileSize, NoTiling, Tiler


def estimate_tile_size(
    budget: int,
    model_size: int,
    img: np.ndarray,
    img_element_size: int = 4,
) -> int:
    """Estimate a power-of-2 tile size expected to fit in a memory budget.

    The memory required to process the whole image is approximated as
    ``model_size / (1024 * 52)`` times the image's size in bytes. From that,
    the number of pixels that fit in ``budget`` is derived. The result is the
    largest power of 2 whose square does not exceed that pixel count.

    Args:
        budget: Free memory in bytes (it is divided by 1024**3 and logged as
            "GB free").
        model_size: Size of the model. Presumably in bytes; TODO confirm with
            callers. It only enters the estimate through the linear
            heuristic below.
        img: The image to be processed. Its height, width and channel count
            (from ``get_h_w_c``) determine the estimated requirement.
        img_element_size: Bytes per image element. The default is 4
            (presumably float32; verify against callers).

    Returns:
        The estimated tile size, always a positive ``int`` (at least 1).

    Raises:
        ZeroDivisionError: If the estimated memory requirement is zero, e.g.
            for an image with no pixels or a ``model_size`` of 0.
    """
    h, w, c = get_h_w_c(img)
    img_bytes = h * w * c * img_element_size
    # Heuristic: the required memory scales linearly with both the model size
    # and the image size. The 1024 * 52 divisor is an empirical constant whose
    # origin is not shown here.
    mem_required_estimation = (model_size / (1024 * 52)) * img_bytes

    tile_pixels = w * h * budget / mem_required_estimation
    # The largest power-of-2 tile_size such that tile_size**2 <= tile_pixels.
    # If fewer than one pixel fits, bit_length() is 0. The bare exponent would
    # then be -1, and 2 ** -1 is the float 0.5, so clamp the exponent at 0.
    # This keeps the result an int >= 1.
    exponent = max(int(tile_pixels**0.5).bit_length() - 1, 0)
    tile_size = 2**exponent

    GB_AMT = 1024**3
    required_mem = f"{mem_required_estimation/GB_AMT:.2f}"
    budget_mem = f"{budget/GB_AMT:.2f}"
    logger.info(
        f"chaiNNer: estimating memory required: {required_mem} GB, "
        f"{budget_mem} GB free Estimated tile size: {tile_size}"
    )
    return tile_size


# An int tagged as a tile-size input. Positive values are explicit tile sizes;
# the non-positive sentinels below select a tiling strategy.
TileSize = NewType("TileSize", int)

# Explicit size.
TILE_SIZE_256 = TileSize(256)

# Strategy sentinels, interpreted by parse_tile_size_input.
MAX_TILE_SIZE = TileSize(-2)
NO_TILING = TileSize(-1)
ESTIMATE = TileSize(0)


def parse_tile_size_input(tile_size: TileSize, estimate: Callable[[], Tiler]) -> Tiler:
    """Convert a ``TileSize`` input value into a ``Tiler``.

    Args:
        tile_size: How to tile. The accepted values are:

            - ``ESTIMATE`` (0): defer to ``estimate``.
            - ``NO_TILING`` (-1): return ``NoTiling()``.
            - ``MAX_TILE_SIZE`` (-2): return ``MaxTileSize()`` with its
              default argument.
            - Any positive value: an explicit ``MaxTileSize(tile_size)``.
        estimate: Called with no arguments, and only when ``tile_size`` is
            ``ESTIMATE``. Its result is returned unchanged.

    Returns:
        The ``Tiler`` selected by ``tile_size``.

    Raises:
        ValueError: If ``tile_size`` is negative but not one of the
            sentinels above.
    """
    # Compare against the module's named sentinels rather than repeating
    # their raw values here.
    if tile_size == ESTIMATE:
        return estimate()
    if tile_size == NO_TILING:
        return NoTiling()
    if tile_size == MAX_TILE_SIZE:
        return MaxTileSize()

    # Use an explicit raise instead of `assert`, which is stripped under
    # `python -O`, so the check stays active.
    if tile_size <= 0:
        raise ValueError(f"Invalid tile size: {tile_size}")
    return MaxTileSize(tile_size)