from __future__ import annotations

from typing import TYPE_CHECKING

import numpy as np

from ..upscale.auto_split import Split, Tiler, auto_split
from .utils import np2tensor, safe_cuda_cache_empty, tensor2np

if TYPE_CHECKING:
    from nodes.impl.pytorch.types import PyTorchModel

    import torch

def pytorch_auto_split(
    img: np.ndarray,
    model: PyTorchModel,
    device: torch.device,
    use_fp16: bool,
    tiler: Tiler,
) -> np.ndarray:
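    """Upscale `img` with `model` on `device`.

    The work is delegated to `auto_split`, which runs the `upscale` callback
    below on tiles of the image; when a tile triggers a CUDA out-of-memory
    error, it is split into smaller tiles and retried.
    """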
    model = model.to(device)
    # Match the model's precision to the precision used for the input tensors
    model = model.half() if use_fp16 else model.float()
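
    # Per-tile callback passed to `auto_split`. Returning `Split()` signals
    # that the tile did not fit in VRAM and should be subdivided further.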
    def upscale(img: np.ndarray, _):
        img_tensor = np2tensor(img, change_range=True)
        d_img = None
        try:
            d_img = img_tensor.to(device)
            d_img = d_img.half() if use_fp16 else d_img.float()

            result = model(d_img)
            result = tensor2np(
                result.detach().cpu(),
                change_range=False,
                imtype=np.float32,
            )
            del d_img
            return result
        except RuntimeError as e:
            # Check whether it's actually the CUDA out-of-memory error
            if "allocate" in str(e) or "CUDA" in str(e):
                # Collect garbage (clear VRAM)
                if d_img is not None:
                    try:
                        d_img.detach().cpu()
                    except Exception:
                        pass
                    del d_img
                safe_cuda_cache_empty()
                return Split()
            else:
                # Re-raise the exception if it's not an OOM error
                raise
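
    # Always drop the local references to the model and device and empty the
    # CUDA cache afterwards, even if upscaling fails, so VRAM is released.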
    try:
        return auto_split(img, upscale, tiler)
    finally:
        del model
        del device
        safe_cuda_cache_empty()
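

# Example usage (a minimal sketch, not part of this module).
# `load_state_dict_model` and `SomeTiler` are hypothetical placeholders for
# however the caller obtains a `PyTorchModel` and a `Tiler` implementation:
#
#     import torch
#
#     model = load_state_dict_model("4x_model.pth")
#     result = pytorch_auto_split(
#         img,
#         model,
#         device=torch.device("cuda"),
#         use_fp16=True,
#         tiler=SomeTiler(),
#     )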