File size: 2,146 Bytes
33ac1eb
 
 
 
 
 
 
 
 
 
 
 
840ed11
c373d06
 
33ac1eb
 
 
c373d06
 
 
 
 
 
 
a58728b
 
 
 
33ac1eb
 
a58728b
64c514e
33ac1eb
 
64c514e
 
 
33ac1eb
a58728b
 
33ac1eb
840ed11
 
 
 
 
33ac1eb
840ed11
 
 
64c514e
840ed11
 
 
64c514e
33ac1eb
840ed11
33ac1eb
64c514e
a58728b
 
 
 
 
 
840ed11
64c514e
 
33ac1eb
64c514e
 
 
 
 
 
 
 
 
 
 
 
33ac1eb
64c514e
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
import numpy as np
import gradio as gr
import torch
import torch.nn as nn
import torch.nn.functional as F
from model import SRCNNModel, pred_SRCNN
from PIL import Image


# Text passed to the Gradio interface: title, Markdown/HTML description,
# HTML footer ("article") and example inputs.
title = "Super Resolution with CNN"
description = """

Your low resolution image will be reconstructed to high resolution with a scale of 2 with a convolutional neural network!<br>

Detailed training and dataset can be found on my [github repo](https://github.com/susuhu/super-resolution).<br>

"""

# HTML footer listing the paper and dataset sources. The paragraph tags are
# properly closed and the emoji are real Unicode characters (they were
# previously mis-encoded byte sequences).
article = """
<div style='margin:20px auto;'>
<p>Sources:</p>
<p>📜 <a href="https://arxiv.org/abs/1501.00092">Image Super-Resolution Using Deep Convolutional Networks</a></p>
<p>📦 Dataset <a href="https://github.com/eugenesiow/super-image-data">this GitHub repo</a></p>
</div>
"""

# Example images offered in the UI: one inner list per example, with one
# entry per input component (the interface has a single image input).
examples = [
    ["LR_image.png"],
    ["barbara.png"],
]

# Build the SRCNN network once, at import time, so every request reuses it.
# Use the GPU when one is visible to PyTorch, otherwise fall back to the CPU,
# and map the saved weights onto that same device while loading them.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = SRCNNModel().to(device)
model.load_state_dict(torch.load("SRCNNmodel_trained.pt", map_location=device))
# Switch to evaluation mode for inference.
model.eval()


# def image_grid(imgs, rows, cols):
#     '''
#     imgs:list of PILImage
#     '''
#     assert len(imgs) == rows*cols

#     w, h = imgs[0].size
#     grid = Image.new('RGB', size=(cols*w, rows*h))
#     grid_w, grid_h = grid.size

#     for i, img in enumerate(imgs):
#         grid.paste(img, box=(i%cols*w, i//cols*h))
#     return grid


def sepia(image):
    """Super-resolve an uploaded image with SRCNN and return it beside a bicubic baseline.

    The name ``sepia`` does not describe what the function does (presumably a
    leftover from a template); it is kept because the ``gr.Interface`` in this
    file passes it as ``fn``.

    Args:
        image: The uploaded image as a NumPy array (Gradio hands images over as
            arrays), or ``None`` for an empty submission.

    Returns:
        A ``(cnn_output, bicubic_output)`` pair as produced by ``pred_SRCNN``,
        or ``(None, None)`` when no image was supplied.
    """
    # Guard against an empty submission instead of crashing inside PIL.
    if image is None:
        return None, None

    # Let Pillow infer the mode from the array and then normalise to RGB.
    # This avoids the deprecated ``mode=`` argument of ``Image.fromarray``,
    # and it also copes with grayscale or RGBA arrays, which a fixed
    # mode="RGB" cannot represent correctly.
    # For the usual H x W x 3 uint8 input the result is the same RGB image.
    pil_image = Image.fromarray(image).convert("RGB")

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        # pred_SRCNN returns three values; the third is not used here.
        out_final, image_bicubic, _ = pred_SRCNN(
            model=model, image=pil_image, device=device
        )
    return out_final, image_bicubic


# Gradio UI: one uploaded image in, two images out, in the order ``sepia``
# returns them (CNN reconstruction first, bicubic baseline second).
# NOTE(review): ``gr.inputs`` / ``gr.outputs`` are the legacy component
# namespaces (deprecated in Gradio 3.x, removed in 4.x); switch to
# ``gr.Image`` if this app is moved to a newer Gradio release.
demo = gr.Interface(
    fn=sepia,
    inputs=gr.inputs.Image(label="Upload image"),
    outputs=[
        gr.outputs.Image(label="Convolutional neural network"),
        gr.outputs.Image(label="Bicubic interpolation"),
    ],
    title=title,
    description=description,
    article=article,
    examples=examples,
)

demo.launch()