File size: 1,701 Bytes
c19ca42
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
from __future__ import annotations
import gc
import numpy as np
from ..upscale.auto_split import Tiler, auto_split
from .np_tensor_utils import np2nptensor, nptensor2np


def onnx_auto_split(
    img: np.ndarray,
    session,
    change_shape: bool,
    tiler: Tiler,
) -> np.ndarray:
    """Run an ONNX model over ``img`` tile by tile using ``auto_split``.

    Each tile handed out by ``auto_split`` is converted with ``np2nptensor``,
    fed through ``session``, and converted back with ``nptensor2np`` as
    float32.

    Args:
        img: The input image to process.
        session: The inference session to run. Presumably an
            ``onnxruntime.InferenceSession`` (TODO confirm with callers).
            Only its first input and first output are used.
        change_shape: When true, each tile is transposed from BCHW to BHWC
            before inference, and the model output is transposed from BHWC
            back to BCHW.
        tiler: Tiling strategy forwarded unchanged to ``auto_split``.

    Returns:
        Whatever ``auto_split`` returns for the processed tiles.

    Raises:
        RuntimeError: When an exception raised by ``session.run`` has a
            message containing "ONNXRuntimeError" together with one of the
            memory-allocation markers below. The original exception is
            chained as ``__cause__``. Every other exception propagates
            unchanged.
    """
    model_input = session.get_inputs()[0]
    input_name = model_input.name
    output_name = session.get_outputs()[0].name

    # Tiles are cast to float16 when the model declares a float16 input.
    is_fp16_model = model_input.type == "tensor(float16)"

    # Substrings that, alongside "ONNXRuntimeError", mark an allocation failure.
    oom_markers = ("allocate memory", "out of memory", "cudaMalloc")

    def upscale(img: np.ndarray, _):
        lr_img = np2nptensor(img, change_range=False)
        if is_fp16_model:
            lr_img = lr_img.astype(np.float16)

        if change_shape:
            # Transpose from BCHW to BHWC
            lr_img = np.transpose(lr_img, (0, 2, 3, 1))

        # Only the inference call is guarded: it is the step that surfaces
        # ONNX Runtime allocation failures.
        try:
            output: np.ndarray = session.run([output_name], {input_name: lr_img})[0]
        except Exception as e:
            message = str(e)
            if "ONNXRuntimeError" in message and any(
                marker in message for marker in oom_markers
            ):
                # Chain the ORT error so its details remain in the traceback.
                raise RuntimeError(
                    "A VRAM out-of-memory error has occurred. Please try using a more extreme tiling mode."
                ) from e
            # Not an OOM error: re-raise as-is.
            raise

        if change_shape:
            # Transpose back to BCHW
            output = np.transpose(output, (0, 3, 1, 2))

        return nptensor2np(output, change_range=False, imtype=np.float32)

    try:
        return auto_split(img, upscale, tiler)
    finally:
        # Force a collection after tiling, even on failure.
        gc.collect()