from __future__ import annotations import gc import numpy as np from ..upscale.auto_split import Tiler, auto_split from .np_tensor_utils import np2nptensor, nptensor2np def onnx_auto_split( img: np.ndarray, session, change_shape: bool, tiler: Tiler, ) -> np.ndarray: input_name = session.get_inputs()[0].name output_name = session.get_outputs()[0].name is_fp16_model = session.get_inputs()[0].type == "tensor(float16)" def upscale(img: np.ndarray, _): try: lr_img = np2nptensor(img, change_range=False) lr_img = lr_img.astype(np.float16) if is_fp16_model else lr_img if change_shape: # Transpose from BCHW to BHWC lr_img = np.transpose(lr_img, (0, 2, 3, 1)) output: np.ndarray = session.run([output_name], {input_name: lr_img})[0] if change_shape: # Transpose back to BCHW output = np.transpose(output, (0, 3, 1, 2)) return nptensor2np(output, change_range=False, imtype=np.float32) except Exception as e: if "ONNXRuntimeError" in str(e) and ( "allocate memory" in str(e) or "out of memory" in str(e) or "cudaMalloc" in str(e) ): # pylint: disable=raise-missing-from raise RuntimeError( "A VRAM out-of-memory error has occurred. Please try using a more extreme tiling mode." ) else: # Re-raise the exception if not an OOM error raise try: return auto_split(img, upscale, tiler) finally: gc.collect()