File size: 1,693 Bytes
c19ca42
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
from __future__ import annotations
from typing import TYPE_CHECKING
import numpy as np
from ..upscale.auto_split import Split, Tiler, auto_split
from .utils import np2tensor, safe_cuda_cache_empty, tensor2np
if TYPE_CHECKING:
    from nodes.impl.pytorch.types import PyTorchModel
    import torch


def pytorch_auto_split(
    img: np.ndarray,
    model: PyTorchModel,
    device: torch.device,
    use_fp16: bool,
    tiler: Tiler,
) -> np.ndarray:
    """Upscale `img` with `model`, tiling automatically on CUDA OOM.

    The image is handed to `auto_split`, which calls the inner `upscale`
    on progressively smaller tiles whenever it returns `Split()` — which
    we do when the forward pass raises a CUDA out-of-memory error.

    Args:
        img: Input image as a float numpy array (HWC).
        model: The PyTorch model to run inference with.
        device: Device to run inference on (e.g. cuda:0 or cpu).
        use_fp16: If True, run the model and inputs in half precision.
        tiler: Strategy object controlling how `auto_split` tiles the image.

    Returns:
        The upscaled image as a float32 numpy array.

    Raises:
        RuntimeError: Any runtime error from the model that is not a CUDA
            out-of-memory error is re-raised to the caller.
    """
    # The module-level `import torch` is guarded by TYPE_CHECKING, so bring
    # torch into scope at runtime for `no_grad()` below. This is cheap:
    # torch is already loaded by the time a model/device exists.
    import torch

    model = model.to(device)
    if use_fp16:
        model = model.half()

    def upscale(img: np.ndarray, _):
        d_img = None
        try:
            d_img = np2tensor(img, change_range=True).to(device)
            d_img = d_img.half() if use_fp16 else d_img.float()
            # Inference only: disable autograd so no computation graph is
            # kept alive — this meaningfully reduces peak VRAM usage.
            with torch.no_grad():
                result = model(d_img)
            result = tensor2np(
                result.detach().cpu(), change_range=False, imtype=np.float32
            )
            del d_img
            return result
        except RuntimeError as e:
            # Only treat genuine CUDA out-of-memory errors as a signal
            # to split the tile; anything else is a real failure.
            if "allocate" in str(e) or "CUDA" in str(e):
                if d_img is not None:
                    # Best-effort: pull the tensor off the GPU before
                    # dropping it so the VRAM is actually reclaimable.
                    try:
                        d_img.detach().cpu()
                    except Exception:
                        pass
                    del d_img
                safe_cuda_cache_empty()
                return Split()
            # Re-raise the exception if not an OOM error.
            raise

    try:
        return auto_split(img, upscale, tiler)
    finally:
        # Drop our reference to the (possibly device-resident) model and
        # flush the CUDA allocator cache so VRAM is returned promptly.
        del model
        safe_cuda_cache_empty()