Spaces:
Runtime error
Runtime error
#!/usr/bin/env python | |
import os | |
import sys | |
import time | |
import warnings | |
import torch | |
import numpy as np | |
from PIL import Image | |
from nodes.impl.upscale.tiler import MaxTileSize, NoTiling, Tiler | |
from nodes.impl.pytorch.auto_split import pytorch_auto_split | |
from nodes.impl.pytorch.types import PyTorchSRModel | |
from nodes.load_model import load_model | |
from nodes.log import logger | |
# Many backends report large numbers of UserWarnings; silence them for now.
warnings.filterwarnings('ignore', category=UserWarning)

# Half precision stays disabled: HAT does not support fp16.
fp16 = False

# Prefer CUDA, but fall back to CPU so the script can still run on hosts without a GPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
def parse_tile_size_input(tile_size: int) -> Tiler:
    """Translate an integer tile-size setting into a ``Tiler``.

    Args:
        tile_size: ``-1`` disables tiling; ``0`` and ``-2`` both select
            ``MaxTileSize()`` with its default limit; any positive value caps
            tiles at that size.

    Returns:
        The ``Tiler`` to hand to ``pytorch_auto_split``.

    Raises:
        ValueError: if ``tile_size`` is negative but not one of the sentinels
            ``-1`` / ``-2``.
    """
    if tile_size == -1:
        return NoTiling()
    if tile_size in (0, -2):
        # MaxTileSize(0) would cap tiles at zero pixels; 0 instead means
        # "let the tiler pick", i.e. the constructor's default limit.
        return MaxTileSize()
    if tile_size < 0:
        raise ValueError(f"ChaiNNer invalid tile size: {tile_size}")
    return MaxTileSize(tile_size)
def upscale(image: Image.Image, model: PyTorchSRModel, tile: int = 256) -> Image.Image:
    """Upscale a PIL image with a loaded PyTorch super-resolution model.

    Args:
        image: Input image (callers in this script pass an RGB image).
        model: Super-resolution model returned by ``load_model``.
        tile: Tile-size setting, interpreted by ``parse_tile_size_input``.

    Returns:
        The upscaled image as an 8-bit PIL image.
    """
    # Feed float32 in [0, 1] so input and output use the same value range;
    # the conversion back to 8-bit below assumes that range.
    img = np.asarray(image, dtype=np.float32) / 255.0
    # NOTE(review): chaiNNer conventionally works on BGR arrays while PIL
    # yields RGB — confirm the channel order pytorch_auto_split expects.
    with torch.no_grad():
        upscaled = pytorch_auto_split(img, model=model, device=device, use_fp16=fp16, tiler=parse_tile_size_input(tile))
    # Clip before scaling by 255 (not 256): 256 * 1.0 == 256 overflows uint8
    # and would wrap full-white pixels to black.
    scaled = np.clip(upscaled, 0.0, 1.0) * 255.0
    return Image.fromarray(np.rint(scaled).astype(np.uint8))
def _release_model_memory() -> None:
    """Return cached CUDA memory after a model has been dropped.

    Guarded so that CPU-only hosts never touch the CUDA runtime.
    """
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()


def _upscale_file(imagename: str, modelname: str, srmodel: PyTorchSRModel) -> None:
    """Upscale one image file and save it beside the input as ``<base>-<modelname><ext>``."""
    # Close the source file promptly; convert() returns an independent, loaded copy.
    with Image.open(imagename) as src:
        inputimage = src.convert('RGB')
    t0 = time.perf_counter()
    outputimage = upscale(image=inputimage, model=srmodel, tile=256)
    t1 = time.perf_counter()
    base, ext = os.path.splitext(imagename)
    outputname = f'{base}-{modelname}{ext}'
    outputimage.save(outputname)
    logger.info(f'input="{imagename}" {inputimage.size} output="{outputname}" {outputimage.size} time={t1-t0:.2f}s')


def main() -> None:
    """Upscale every image named on the command line with every file in ./models.

    Exits with status 1 when no images are given, none of them exist, or the
    ``models`` directory is missing. Per-model and per-image failures are
    logged and skipped so one bad file does not stop the batch.
    """
    args = sys.argv[1:]
    if not args:
        logger.error('chainner: no files specified')
        sys.exit(1)

    # Check inputs once up front instead of once per model.
    imagenames = []
    for imagename in args:
        if os.path.isfile(imagename):
            imagenames.append(imagename)
        else:
            logger.error(f'image={imagename} not found')
    if not imagenames:
        sys.exit(1)

    try:
        modelfiles = sorted(os.listdir('models'))  # sorted for a reproducible run order
    except FileNotFoundError:
        logger.error('chainner: models directory not found')
        sys.exit(1)

    for modelfile in modelfiles:
        modelname = os.path.splitext(modelfile)[0]
        srmodel = None
        try:
            srmodel = load_model(os.path.join("models", modelfile), device=device, fp16=fp16)
            logger.info(f'model="{modelname}" arch="{srmodel.__class__.__name__}" scale={srmodel.scale}')
            for imagename in imagenames:
                # A failure on one image should not abort the rest for this model.
                try:
                    _upscale_file(imagename, modelname, srmodel)
                except Exception as e:
                    logger.error(f'Error: fn={modelfile} image={imagename} {e}')
        except Exception as e:
            logger.error(f'Error: fn={modelfile} {e}')
        finally:
            # Drop the model reference before freeing cached memory, even on error.
            srmodel = None
            _release_model_memory()


if __name__ == "__main__":
    main()