File size: 2,576 Bytes
c19ca42
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
#!/usr/bin/env python

import os
import sys
import time
import warnings
import torch
import numpy as np
from PIL import Image
from nodes.impl.upscale.tiler import MaxTileSize, NoTiling, Tiler
from nodes.impl.pytorch.auto_split import pytorch_auto_split
from nodes.impl.pytorch.types import PyTorchSRModel
from nodes.load_model import load_model
from nodes.log import logger


# Many model backends emit large numbers of UserWarnings; silence them for now.
warnings.filterwarnings(action='ignore', category=UserWarning)
# HAT does not support fp16, so every model is loaded and run in full precision.
fp16 = False
# All inference runs on the default CUDA device.
device = torch.device('cuda')


def parse_tile_size_input(tile_size: int) -> Tiler:
    """Translate an integer tile-size setting into a chaiNNer ``Tiler``.

    Args:
        tile_size: A positive value is the tile size in pixels passed to
            ``MaxTileSize``. ``0`` and ``-2`` both select ``MaxTileSize()``
            with its default size. ``-1`` selects ``NoTiling()``.

    Returns:
        The ``Tiler`` matching ``tile_size``.

    Raises:
        ValueError: If ``tile_size`` is any other negative value.
    """
    if tile_size > 0:
        return MaxTileSize(tile_size)
    if tile_size == -1:
        return NoTiling()
    # 0 previously produced MaxTileSize(0), a zero-pixel tile size that cannot
    # describe a usable tile. It is now treated like -2 and uses the Tiler's
    # default size.
    # NOTE(review): confirm 0 is meant as "auto" for callers of this script.
    if tile_size in (0, -2):
        return MaxTileSize()
    raise ValueError(f"ChaiNNer invalid tile size: {tile_size}")


def upscale(image: Image.Image, model: PyTorchSRModel, tile: int = 256) -> Image.Image:
    """Upscale a PIL image with a loaded super-resolution model.

    Args:
        image: Source image. The caller in this file passes an RGB image.
        model: Model object returned by ``load_model``.
        tile: Tile-size setting forwarded to ``parse_tile_size_input``.

    Returns:
        The upscaled result as an 8-bit PIL image.
    """
    img = np.array(image)
    with torch.no_grad():
        upscaled = pytorch_auto_split(img, model=model, device=device, use_fp16=fp16, tiler=parse_tile_size_input(tile))
    # NOTE(review): assumes pytorch_auto_split returns floats normalized to [0, 1].
    # The old `np.uint8(256 * upscaled)` turned 1.0 into 256, which does not fit
    # in uint8, so white pixels wrapped to black. Clip to [0, 1], scale by 255
    # and round so the full range maps onto 0..255 without overflow.
    as_uint8 = np.rint(np.clip(upscaled, 0.0, 1.0) * 255.0).astype(np.uint8)
    return Image.fromarray(as_uint8)


if __name__ == "__main__":
    # Usage: <script> IMAGE [IMAGE ...]
    # Applies every entry in ./models to every image argument and saves each
    # result next to its input as "<base>-<modelname><ext>".
    imagenames = sys.argv[1:]
    if not imagenames:
        logger.error('chainner: no files specified')
        sys.exit(1)
    # os.listdir order is arbitrary; sort so runs and logs are reproducible.
    for modelfile in sorted(os.listdir('models')):
        srmodel = None
        try:
            modelname = os.path.splitext(modelfile)[0]
            srmodel = load_model(os.path.join("models", modelfile), device=device, fp16=fp16)
            logger.info(f'model="{modelname}" arch="{srmodel.__class__.__name__}" scale={srmodel.scale}')
            for imagename in imagenames:
                if not os.path.isfile(imagename):
                    logger.error(f'image={imagename} not found')
                    continue
                # Close the source file handle promptly; convert() returns an
                # independent in-memory image.
                with Image.open(imagename) as source:
                    inputimage = source.convert('RGB')
                t0 = time.perf_counter()
                outputimage = upscale(image=inputimage, model=srmodel, tile=256)
                t1 = time.perf_counter()
                base, ext = os.path.splitext(imagename)
                outputname = f'{base}-{modelname}{ext}'
                outputimage.save(outputname)
                logger.info(f'input="{imagename}" {inputimage.size} output="{outputname}" {outputimage.size} time={t1-t0:.2f}s')
        except Exception as e:  # script boundary: report and continue with the next model
            logger.error(f'Error: fn={modelfile} {e}')
        finally:
            # Drop the model and release cached GPU memory even when loading or
            # upscaling failed, so a failed model is not still resident while
            # the next one loads. Guarded because ipc_collect() needs CUDA.
            srmodel = None
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
                torch.cuda.ipc_collect()