ESRGAN / app.py
DakuSir's picture
last
8bd92cd verified
raw
history blame contribute delete
2.39 kB
import torch
from PIL import Image
from RealESRGAN import RealESRGAN
import gradio as gr
import os
import spaces
# Choose the compute device once at import time: the first CUDA GPU when
# PyTorch reports one, otherwise the CPU.
if not torch.cuda.is_available():
    print("CUDA is not available. Using CPU.")
    device = torch.device("cpu")
else:
    print(f"CUDA is available. GPU: {torch.cuda.get_device_name(0)}")
    device = torch.device("cuda")
class LazyRealESRGAN:
    """Defer building a ``RealESRGAN`` model until it is first needed.

    Constructing the wrapper is cheap. The underlying model is created and
    its weights are loaded on the first ``predict`` (or ``load_model``) call.
    """

    def __init__(self, device, scale):
        """Remember the target *device* and upscale factor *scale*; no model is built yet."""
        self.device = device
        self.scale = scale
        # Stays None until load_model() has built the model AND loaded its weights.
        self.model = None

    def load_model(self):
        """Build the model and load ``weights/RealESRGAN_x{scale}.pth`` unless already loaded.

        ``download=True`` is forwarded to ``load_weights``; presumably it fetches
        the file when it is missing locally (confirm against the RealESRGAN package).
        Exceptions from construction or weight loading propagate to the caller.
        """
        if self.model is not None:
            return
        # Build into a local and publish it only once the weights have loaded.
        # Otherwise a failed load (e.g. a download error) would leave a model
        # without weights cached on self.model and later calls would silently
        # use it instead of retrying.
        model = RealESRGAN(self.device, scale=self.scale)
        model.load_weights(f'weights/RealESRGAN_x{self.scale}.pth', download=True)
        self.model = model

    def predict(self, img):
        """Upscale *img*, loading the model first if needed; returns ``self.model.predict(img)``."""
        self.load_model()
        return self.model.predict(img)
# One lazily-loaded upscaler per supported factor; each builds its model and
# loads weights on its first predict() call.
model2, model4, model8 = (LazyRealESRGAN(device, scale=factor) for factor in (2, 4, 8))
@spaces.GPU
def inference(image, size):
    """Upscale a PIL image with the Real-ESRGAN model selected by *size*.

    Args:
        image: Uploaded PIL image, or ``None`` when nothing was uploaded.
        size: Scale label from the radio control. ``'2x'`` and ``'4x'`` pick
            those models; any other value selects the 8x model.

    Returns:
        Whatever the selected model's ``predict`` returns for the RGB-converted image.

    Raises:
        gr.Error: If no image was uploaded, if the image is too large for the
            8x model, if the GPU runs out of memory, or if inference fails.
    """
    if image is None:
        raise gr.Error("Image not uploaded")
    try:
        if torch.cuda.is_available():
            # Return unused memory held by PyTorch's CUDA caching allocator before a new run.
            torch.cuda.empty_cache()
        if size == '2x':
            model = model2
        elif size == '4x':
            model = model4
        else:
            # Only the 8x path is size-limited; its output has 64x the input's pixel count.
            width, height = image.size
            if width >= 5000 or height >= 5000:
                raise gr.Error("The image is too large.")
            model = model8
        result = model.predict(image.convert('RGB'))
        print(f"Image size ({device}): {size} ... OK")
        return result
    except gr.Error:
        # Already a user-facing message; re-raise it unchanged instead of letting the
        # generic handler below wrap it as "An error occurred: '...'".
        raise
    except torch.cuda.OutOfMemoryError as e:
        raise gr.Error("GPU out of memory. Try a smaller image or lower upscaling factor.") from e
    except Exception as e:
        raise gr.Error(f"An error occurred: {str(e)}") from e
# Page text shown above the demo.
title = "Face Real ESRGAN UpScale: 2x 4x 8x"
description = "This is an unofficial demo for Real-ESRGAN. Scales the resolution of a photo. This model shows better results on faces compared to the original version."

# UI: a PIL image plus the scale choice (defaulting to "2x") go into inference();
# the upscaled PIL image is shown as the output.
iface = gr.Interface(
    fn=inference,
    inputs=[
        gr.Image(type="pil"),
        gr.Radio(["2x", "4x", "8x"], type="value", value="2x", label="Resolution model"),
    ],
    outputs=gr.Image(type="pil", label="Output"),
    title=title,
    description=description,
    flagging_mode="never",
    # NOTE(review): no `examples` are passed, so this flag presumably has no effect — confirm.
    cache_examples=True,
)
# Start the Gradio server only when this file is executed directly, not when imported.
# show_error=True surfaces handler errors in the browser UI.
if __name__ == "__main__":
    iface.launch(show_error=True, debug=True)