bilegentile's picture
Upload folder using huggingface_hub
c19ca42 verified
#!/usr/bin/env python
import os
import sys
import time
import warnings
import torch
import numpy as np
from PIL import Image
from nodes.impl.upscale.tiler import MaxTileSize, NoTiling, Tiler
from nodes.impl.pytorch.auto_split import pytorch_auto_split
from nodes.impl.pytorch.types import PyTorchSRModel
from nodes.load_model import load_model
from nodes.log import logger
# Silence UserWarning for now: many backends emit large numbers of them.
warnings.filterwarnings('ignore', category=UserWarning)
# HAT does not support fp16, so half precision is disabled.
fp16 = False
# Prefer the GPU, but fall back to the CPU so the script still works on hosts without CUDA.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
def parse_tile_size_input(tile_size: int) -> Tiler:
    """Translate a numeric tile-size setting into a ``Tiler`` instance.

    Args:
        tile_size: ``-1`` selects ``NoTiling()``, ``-2`` selects
            ``MaxTileSize()`` constructed without arguments, and any value
            ``>= 0`` is passed to ``MaxTileSize(...)``.

    Returns:
        The ``Tiler`` built for the requested setting.

    Raises:
        ValueError: If ``tile_size`` is negative and is neither ``-1`` nor ``-2``.
    """
    # Negative sentinels are handled first; everything else falls through.
    if tile_size == -1:
        return NoTiling()
    if tile_size == -2:
        return MaxTileSize()
    if tile_size < 0:
        raise ValueError(f"ChaiNNer invalid tile size: {tile_size}")
    # NOTE(review): 0 is forwarded unchanged, exactly like a positive size —
    # confirm that MaxTileSize(0) is a meaningful configuration.
    return MaxTileSize(tile_size)
def upscale(image: Image.Image, model: PyTorchSRModel, tile: int = 256) -> Image.Image:
    """Upscale a PIL image with a loaded super-resolution model.

    Args:
        image: Source image; it is converted to a NumPy array before inference.
        model: Loaded model handed to ``pytorch_auto_split``.
        tile: Tile-size setting interpreted by ``parse_tile_size_input``.

    Returns:
        The upscaled result as an 8-bit PIL image.
    """
    img = np.array(image)
    with torch.no_grad():
        upscaled = pytorch_auto_split(
            img,
            model=model,
            device=device,
            use_fp16=fp16,
            tiler=parse_tile_size_input(tile),
        )
    # Assumes the result is float data nominally in 0..1 (the previous code
    # scaled it by 256 before casting) — TODO confirm against pytorch_auto_split.
    # Scaling by 255 and clipping before the cast keeps full-intensity (or
    # slightly out-of-range) values from overflowing uint8 and wrapping around.
    pixels = np.clip(np.rint(np.asarray(upscaled) * 255.0), 0, 255).astype(np.uint8)
    return Image.fromarray(pixels)
if __name__ == "__main__":
    # Command-line entry point: upscale every image given on the command line
    # with every model found in the local "models" folder. Each result is
    # written next to its input as "<base>-<modelname><ext>".

    # Slice instead of popping so sys.argv itself is left untouched.
    requested = sys.argv[1:]
    if not requested:
        logger.error('chainner: no files specified')
        sys.exit(1)

    # Validate inputs once up front rather than once per model, so a missing
    # file is reported a single time and no model is loaded for nothing.
    images = []
    for imagename in requested:
        if os.path.isfile(imagename):
            images.append(imagename)
        else:
            logger.error(f'image={imagename} not found')
    if not images:
        sys.exit(1)

    try:
        # Sorted so models are processed in a deterministic order.
        modelfiles = sorted(os.listdir('models'))
    except OSError as e:
        logger.error(f'chainner: cannot list models folder: {e}')
        sys.exit(1)

    for modelfile in modelfiles:
        modelname = os.path.splitext(modelfile)[0]
        srmodel = None
        try:
            srmodel = load_model(os.path.join("models", modelfile), device=device, fp16=fp16)
            logger.info(f'model="{modelname}" arch="{srmodel.__class__.__name__}" scale={srmodel.scale}')
            for imagename in images:
                # Close the file handle as soon as the pixel data has been converted.
                with Image.open(imagename) as source:
                    inputimage = source.convert('RGB')
                t0 = time.perf_counter()
                outputimage = upscale(image=inputimage, model=srmodel, tile=256)
                t1 = time.perf_counter()
                base, ext = os.path.splitext(imagename)
                outputname = f'{base}-{modelname}{ext}'
                outputimage.save(outputname)
                logger.info(f'input="{imagename}" {inputimage.size} output="{outputname}" {outputimage.size} time={t1-t0:.2f}s')
        except Exception as e:
            # Best effort: a failing model is reported and the next one is tried.
            logger.error(f'Error: fn={modelfile} {e}')
        finally:
            # Drop the model and release cached GPU memory even when this model
            # failed, so the next model does not start with memory still held.
            srmodel = None
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
                torch.cuda.ipc_collect()