Thibaud Cheruy
commited on
Commit
·
92d45d2
1
Parent(s):
56a1b0b
New: Add SRGAN Space
Browse files- .gitattributes +1 -0
- .gitignore +2 -0
- README.md +1 -0
- __pycache__/imgproc.cpython-310.pyc +0 -0
- __pycache__/inference.cpython-310.pyc +0 -0
- __pycache__/model.cpython-310.pyc +0 -0
- __pycache__/utils.cpython-310.pyc +0 -0
- app.py +111 -0
- examples/bird.png +0 -0
- examples/butterfly.png +0 -0
- examples/comic.png +0 -0
- examples/gray.png +0 -0
- examples/man.png +0 -0
- imgproc.py +575 -0
- inference.py +101 -0
- model.py +251 -0
- utils.py +168 -0
- weights/SRGAN_x4-ImageNet-8c4a7569.pth.tar +3 -0
.gitattributes
CHANGED
@@ -32,3 +32,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
32 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
33 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
34 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
32 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
33 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
34 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
35 |
+
*.pth.tar filter=lfs diff=lfs merge=lfs -text
|
.gitignore
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
venv/
|
2 |
+
.idea/
|
README.md
CHANGED
@@ -8,6 +8,7 @@ sdk_version: 3.16.1
|
|
8 |
app_file: app.py
|
9 |
pinned: false
|
10 |
license: apache-2.0
|
|
|
11 |
---
|
12 |
|
13 |
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
|
|
8 |
app_file: app.py
|
9 |
pinned: false
|
10 |
license: apache-2.0
|
11 |
+
python_version: 3.10.3
|
12 |
---
|
13 |
|
14 |
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
__pycache__/imgproc.cpython-310.pyc
ADDED
Binary file (13.5 kB). View file
|
|
__pycache__/inference.cpython-310.pyc
ADDED
Binary file (2.47 kB). View file
|
|
__pycache__/model.cpython-310.pyc
ADDED
Binary file (6.9 kB). View file
|
|
__pycache__/utils.cpython-310.pyc
ADDED
Binary file (5.46 kB). View file
|
|
app.py
ADDED
@@ -0,0 +1,111 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
import torch
|
3 |
+
import cv2
|
4 |
+
|
5 |
+
import imgproc
|
6 |
+
from imgproc import image_to_tensor
|
7 |
+
from inference import choice_device, build_model
|
8 |
+
from utils import load_state_dict
|
9 |
+
|
10 |
+
model = "srresnet_x4"
|
11 |
+
|
12 |
+
device = choice_device("cpu")
|
13 |
+
|
14 |
+
# Initialize the model
|
15 |
+
sr_model = build_model(model, device)
|
16 |
+
print(f"Build {model} model successfully.")
|
17 |
+
|
18 |
+
# Load model weights
|
19 |
+
sr_model = load_state_dict(sr_model, "weights/SRGAN_x4-ImageNet-8c4a7569.pth.tar")
|
20 |
+
print(f"Load `{model}` model weights successfully.")
|
21 |
+
|
22 |
+
# Start the verification mode of the model.
|
23 |
+
sr_model.eval()
|
24 |
+
|
25 |
+
def downscale(image):
|
26 |
+
(width, height, colors) = image.shape
|
27 |
+
|
28 |
+
new_height = int(60 * width / height)
|
29 |
+
|
30 |
+
return cv2.resize(image, (60, new_height), interpolation=cv2.INTER_AREA)
|
31 |
+
|
32 |
+
|
33 |
+
def preprocess(image):
|
34 |
+
image = image / 255.0
|
35 |
+
|
36 |
+
# Convert image data to pytorch format data
|
37 |
+
tensor = image_to_tensor(image, False, False).unsqueeze_(0)
|
38 |
+
|
39 |
+
# Transfer tensor channel image format data to CUDA device
|
40 |
+
tensor = tensor.to(device="cpu", memory_format=torch.channels_last, non_blocking=True)
|
41 |
+
|
42 |
+
return tensor
|
43 |
+
|
44 |
+
def processHighRes(image):
|
45 |
+
if image is None:
|
46 |
+
raise gr.Error("Please enter an image")
|
47 |
+
downscaled = downscale(image)
|
48 |
+
lr_tensor = preprocess(downscaled)
|
49 |
+
|
50 |
+
# Use the model to generate super-resolved images
|
51 |
+
with torch.no_grad():
|
52 |
+
sr_tensor = sr_model(lr_tensor)
|
53 |
+
|
54 |
+
# Save image
|
55 |
+
sr_image = imgproc.tensor_to_image(sr_tensor, False, False)
|
56 |
+
|
57 |
+
return [downscaled, sr_image]
|
58 |
+
|
59 |
+
def processLowRes(image):
|
60 |
+
if image is None:
|
61 |
+
raise gr.Error("Please enter an image")
|
62 |
+
|
63 |
+
(width, height, colors) = image.shape
|
64 |
+
|
65 |
+
if width > 150 or height > 150:
|
66 |
+
raise gr.Error("Image is too big")
|
67 |
+
|
68 |
+
lr_tensor = preprocess(image)
|
69 |
+
|
70 |
+
# Use the model to generate super-resolved images
|
71 |
+
with torch.no_grad():
|
72 |
+
sr_tensor = sr_model(lr_tensor)
|
73 |
+
|
74 |
+
# Save image
|
75 |
+
sr_image = imgproc.tensor_to_image(sr_tensor, False, False)
|
76 |
+
|
77 |
+
return [sr_image]
|
78 |
+
|
79 |
+
description = """<p style='text-align: center'> <a href='https://arxiv.org/abs/1609.04802' target='_blank'>Paper</a> | <a href=https://github.com/Lornatang/SRGAN-PyTorch target='_blank'>GitHub</a></p>"""
|
80 |
+
|
81 |
+
with gr.Blocks() as demo:
|
82 |
+
gr.Markdown("# **<p align='center'>SRGAN: Photo-Realistic Single Image Super-Resolution Using a Generative Adversarial Network</p>**")
|
83 |
+
gr.Markdown(description)
|
84 |
+
|
85 |
+
with gr.Tab("From high res"):
|
86 |
+
high_res_input = gr.Image(label="High-res source image", show_label=True)
|
87 |
+
with gr.Row():
|
88 |
+
low_res_output = gr.Image(label="Low-res image")
|
89 |
+
srgan_output = gr.Image(label="SRGAN Output")
|
90 |
+
high_res_button = gr.Button("Process")
|
91 |
+
|
92 |
+
with gr.Tab("From low res"):
|
93 |
+
low_res_input = gr.Image(label="Low-res source image", show_label=True)
|
94 |
+
srgan_upscale = gr.Image(label="SRGAN Output")
|
95 |
+
low_res_button = gr.Button("Process")
|
96 |
+
|
97 |
+
gr.Examples(
|
98 |
+
examples=["examples/bird.png", "examples/butterfly.png", "examples/comic.png", "examples/gray.png",
|
99 |
+
"examples/man.png"],
|
100 |
+
inputs=[high_res_input],
|
101 |
+
outputs=[low_res_output, srgan_output],
|
102 |
+
fn=processHighRes
|
103 |
+
)
|
104 |
+
|
105 |
+
high_res_button.click(processHighRes, inputs=[high_res_input], outputs=[low_res_output, srgan_output])
|
106 |
+
low_res_button.click(processLowRes, inputs=[low_res_input], outputs=[srgan_upscale])
|
107 |
+
|
108 |
+
gr.Markdown("<p style='text-align: center'>Made for the 2022-2023 Grenoble-INP Phelma Image analysis course, by Thibaud CHERUY, Clément DEBUY & Yassine El Khanoussi.</p>")
|
109 |
+
|
110 |
+
|
111 |
+
demo.launch()
|
examples/bird.png
ADDED
![]() |
examples/butterfly.png
ADDED
![]() |
examples/comic.png
ADDED
![]() |
examples/gray.png
ADDED
![]() |
examples/man.png
ADDED
![]() |
imgproc.py
ADDED
@@ -0,0 +1,575 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2022 Dakewe Biotech Corporation. All Rights Reserved.
|
2 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
3 |
+
# you may not use this file except in compliance with the License.
|
4 |
+
# You may obtain a copy of the License at
|
5 |
+
#
|
6 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
7 |
+
#
|
8 |
+
# Unless required by applicable law or agreed to in writing, software
|
9 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
10 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
11 |
+
# See the License for the specific language governing permissions and
|
12 |
+
# limitations under the License.
|
13 |
+
# ==============================================================================
|
14 |
+
import math
|
15 |
+
import random
|
16 |
+
from typing import Any
|
17 |
+
|
18 |
+
import cv2
|
19 |
+
import numpy as np
|
20 |
+
import torch
|
21 |
+
from numpy import ndarray
|
22 |
+
from torch import Tensor
|
23 |
+
|
24 |
+
__all__ = [
|
25 |
+
"image_to_tensor", "tensor_to_image",
|
26 |
+
"image_resize", "preprocess_one_image",
|
27 |
+
"expand_y", "rgb_to_ycbcr", "bgr_to_ycbcr", "ycbcr_to_bgr", "ycbcr_to_rgb",
|
28 |
+
"rgb_to_ycbcr_torch", "bgr_to_ycbcr_torch",
|
29 |
+
"center_crop", "random_crop", "random_rotate", "random_vertically_flip", "random_horizontally_flip",
|
30 |
+
]
|
31 |
+
|
32 |
+
|
33 |
+
# Code reference `https://github.com/xinntao/BasicSR/blob/master/basicsr/utils/matlab_functions.py`
|
34 |
+
def _cubic(x: Any) -> Any:
|
35 |
+
"""Implementation of `cubic` function in Matlab under Python language.
|
36 |
+
|
37 |
+
Args:
|
38 |
+
x: Element vector.
|
39 |
+
|
40 |
+
Returns:
|
41 |
+
Bicubic interpolation
|
42 |
+
|
43 |
+
"""
|
44 |
+
absx = torch.abs(x)
|
45 |
+
absx2 = absx ** 2
|
46 |
+
absx3 = absx ** 3
|
47 |
+
return (1.5 * absx3 - 2.5 * absx2 + 1) * ((absx <= 1).type_as(absx)) + (
|
48 |
+
-0.5 * absx3 + 2.5 * absx2 - 4 * absx + 2) * (
|
49 |
+
((absx > 1) * (absx <= 2)).type_as(absx))
|
50 |
+
|
51 |
+
|
52 |
+
# Code reference `https://github.com/xinntao/BasicSR/blob/master/basicsr/utils/matlab_functions.py`
|
53 |
+
def _calculate_weights_indices(in_length: int,
|
54 |
+
out_length: int,
|
55 |
+
scale: float,
|
56 |
+
kernel_width: int,
|
57 |
+
antialiasing: bool) -> [np.ndarray, np.ndarray, int, int]:
|
58 |
+
"""Implementation of `calculate_weights_indices` function in Matlab under Python language.
|
59 |
+
|
60 |
+
Args:
|
61 |
+
in_length (int): Input length.
|
62 |
+
out_length (int): Output length.
|
63 |
+
scale (float): Scale factor.
|
64 |
+
kernel_width (int): Kernel width.
|
65 |
+
antialiasing (bool): Whether to apply antialiasing when down-sampling operations.
|
66 |
+
Caution: Bicubic down-sampling in PIL uses antialiasing by default.
|
67 |
+
|
68 |
+
Returns:
|
69 |
+
weights, indices, sym_len_s, sym_len_e
|
70 |
+
|
71 |
+
"""
|
72 |
+
if (scale < 1) and antialiasing:
|
73 |
+
# Use a modified kernel (larger kernel width) to simultaneously
|
74 |
+
# interpolate and antialiasing
|
75 |
+
kernel_width = kernel_width / scale
|
76 |
+
|
77 |
+
# Output-space coordinates
|
78 |
+
x = torch.linspace(1, out_length, out_length)
|
79 |
+
|
80 |
+
# Input-space coordinates. Calculate the inverse mapping such that 0.5
|
81 |
+
# in output space maps to 0.5 in input space, and 0.5 + scale in output
|
82 |
+
# space maps to 1.5 in input space.
|
83 |
+
u = x / scale + 0.5 * (1 - 1 / scale)
|
84 |
+
|
85 |
+
# What is the left-most pixel that can be involved in the computation?
|
86 |
+
left = torch.floor(u - kernel_width / 2)
|
87 |
+
|
88 |
+
# What is the maximum number of pixels that can be involved in the
|
89 |
+
# computation? Note: it's OK to use an extra pixel here; if the
|
90 |
+
# corresponding weights are all zero, it will be eliminated at the end
|
91 |
+
# of this function.
|
92 |
+
p = math.ceil(kernel_width) + 2
|
93 |
+
|
94 |
+
# The indices of the input pixels involved in computing the k-th output
|
95 |
+
# pixel are in row k of the indices matrix.
|
96 |
+
indices = left.view(out_length, 1).expand(out_length, p) + torch.linspace(0, p - 1, p).view(1, p).expand(
|
97 |
+
out_length, p)
|
98 |
+
|
99 |
+
# The weights used to compute the k-th output pixel are in row k of the
|
100 |
+
# weights matrix.
|
101 |
+
distance_to_center = u.view(out_length, 1).expand(out_length, p) - indices
|
102 |
+
|
103 |
+
# apply cubic kernel
|
104 |
+
if (scale < 1) and antialiasing:
|
105 |
+
weights = scale * _cubic(distance_to_center * scale)
|
106 |
+
else:
|
107 |
+
weights = _cubic(distance_to_center)
|
108 |
+
|
109 |
+
# Normalize the weights matrix so that each row sums to 1.
|
110 |
+
weights_sum = torch.sum(weights, 1).view(out_length, 1)
|
111 |
+
weights = weights / weights_sum.expand(out_length, p)
|
112 |
+
|
113 |
+
# If a column in weights is all zero, get rid of it. only consider the
|
114 |
+
# first and last column.
|
115 |
+
weights_zero_tmp = torch.sum((weights == 0), 0)
|
116 |
+
if not math.isclose(weights_zero_tmp[0], 0, rel_tol=1e-6):
|
117 |
+
indices = indices.narrow(1, 1, p - 2)
|
118 |
+
weights = weights.narrow(1, 1, p - 2)
|
119 |
+
if not math.isclose(weights_zero_tmp[-1], 0, rel_tol=1e-6):
|
120 |
+
indices = indices.narrow(1, 0, p - 2)
|
121 |
+
weights = weights.narrow(1, 0, p - 2)
|
122 |
+
weights = weights.contiguous()
|
123 |
+
indices = indices.contiguous()
|
124 |
+
sym_len_s = -indices.min() + 1
|
125 |
+
sym_len_e = indices.max() - in_length
|
126 |
+
indices = indices + sym_len_s - 1
|
127 |
+
return weights, indices, int(sym_len_s), int(sym_len_e)
|
128 |
+
|
129 |
+
|
130 |
+
def image_to_tensor(image: ndarray, range_norm: bool, half: bool) -> Tensor:
|
131 |
+
"""Convert the image data type to the Tensor (NCWH) data type supported by PyTorch
|
132 |
+
|
133 |
+
Args:
|
134 |
+
image (np.ndarray): The image data read by ``OpenCV.imread``, the data range is [0,255] or [0, 1]
|
135 |
+
range_norm (bool): Scale [0, 1] data to between [-1, 1]
|
136 |
+
half (bool): Whether to convert torch.float32 similarly to torch.half type
|
137 |
+
|
138 |
+
Returns:
|
139 |
+
tensor (Tensor): Data types supported by PyTorch
|
140 |
+
|
141 |
+
Examples:
|
142 |
+
>>> example_image = cv2.imread("lr_image.bmp")
|
143 |
+
>>> example_tensor = image_to_tensor(example_image, range_norm=True, half=False)
|
144 |
+
|
145 |
+
"""
|
146 |
+
# Convert image data type to Tensor data type
|
147 |
+
tensor = torch.from_numpy(np.ascontiguousarray(image)).permute(2, 0, 1).float()
|
148 |
+
|
149 |
+
# Scale the image data from [0, 1] to [-1, 1]
|
150 |
+
if range_norm:
|
151 |
+
tensor = tensor.mul(2.0).sub(1.0)
|
152 |
+
|
153 |
+
# Convert torch.float32 image data type to torch.half image data type
|
154 |
+
if half:
|
155 |
+
tensor = tensor.half()
|
156 |
+
|
157 |
+
return tensor
|
158 |
+
|
159 |
+
|
160 |
+
def tensor_to_image(tensor: Tensor, range_norm: bool, half: bool) -> Any:
|
161 |
+
"""Convert the Tensor(NCWH) data type supported by PyTorch to the np.ndarray(WHC) image data type
|
162 |
+
|
163 |
+
Args:
|
164 |
+
tensor (Tensor): Data types supported by PyTorch (NCHW), the data range is [0, 1]
|
165 |
+
range_norm (bool): Scale [-1, 1] data to between [0, 1]
|
166 |
+
half (bool): Whether to convert torch.float32 similarly to torch.half type.
|
167 |
+
|
168 |
+
Returns:
|
169 |
+
image (np.ndarray): Data types supported by PIL or OpenCV
|
170 |
+
|
171 |
+
Examples:
|
172 |
+
>>> example_image = cv2.imread("lr_image.bmp")
|
173 |
+
>>> example_tensor = image_to_tensor(example_image, range_norm=False, half=False)
|
174 |
+
|
175 |
+
"""
|
176 |
+
if range_norm:
|
177 |
+
tensor = tensor.add(1.0).div(2.0)
|
178 |
+
if half:
|
179 |
+
tensor = tensor.half()
|
180 |
+
|
181 |
+
image = tensor.squeeze(0).permute(1, 2, 0).mul(255).clamp(0, 255).cpu().numpy().astype("uint8")
|
182 |
+
|
183 |
+
return image
|
184 |
+
|
185 |
+
|
186 |
+
def preprocess_one_image(image_path: str, device: torch.device) -> Tensor:
|
187 |
+
image = cv2.imread(image_path).astype(np.float32) / 255.0
|
188 |
+
|
189 |
+
# BGR to RGB
|
190 |
+
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
|
191 |
+
|
192 |
+
# Convert image data to pytorch format data
|
193 |
+
tensor = image_to_tensor(image, False, False).unsqueeze_(0)
|
194 |
+
|
195 |
+
# Transfer tensor channel image format data to CUDA device
|
196 |
+
tensor = tensor.to(device=device, memory_format=torch.channels_last, non_blocking=True)
|
197 |
+
|
198 |
+
return tensor
|
199 |
+
|
200 |
+
|
201 |
+
# Code reference `https://github.com/xinntao/BasicSR/blob/master/basicsr/utils/matlab_functions.py`
|
202 |
+
def image_resize(image: Any, scale_factor: float, antialiasing: bool = True) -> Any:
|
203 |
+
"""Implementation of `imresize` function in Matlab under Python language.
|
204 |
+
|
205 |
+
Args:
|
206 |
+
image: The input image.
|
207 |
+
scale_factor (float): Scale factor. The same scale applies for both height and width.
|
208 |
+
antialiasing (bool): Whether to apply antialiasing when down-sampling operations.
|
209 |
+
Caution: Bicubic down-sampling in `PIL` uses antialiasing by default. Default: ``True``.
|
210 |
+
|
211 |
+
Returns:
|
212 |
+
out_2 (np.ndarray): Output image with shape (c, h, w), [0, 1] range, w/o round
|
213 |
+
|
214 |
+
"""
|
215 |
+
squeeze_flag = False
|
216 |
+
if type(image).__module__ == np.__name__: # numpy type
|
217 |
+
numpy_type = True
|
218 |
+
if image.ndim == 2:
|
219 |
+
image = image[:, :, None]
|
220 |
+
squeeze_flag = True
|
221 |
+
image = torch.from_numpy(image.transpose(2, 0, 1)).float()
|
222 |
+
else:
|
223 |
+
numpy_type = False
|
224 |
+
if image.ndim == 2:
|
225 |
+
image = image.unsqueeze(0)
|
226 |
+
squeeze_flag = True
|
227 |
+
|
228 |
+
in_c, in_h, in_w = image.size()
|
229 |
+
out_h, out_w = math.ceil(in_h * scale_factor), math.ceil(in_w * scale_factor)
|
230 |
+
kernel_width = 4
|
231 |
+
|
232 |
+
# get weights and indices
|
233 |
+
weights_h, indices_h, sym_len_hs, sym_len_he = _calculate_weights_indices(in_h, out_h, scale_factor, kernel_width,
|
234 |
+
antialiasing)
|
235 |
+
weights_w, indices_w, sym_len_ws, sym_len_we = _calculate_weights_indices(in_w, out_w, scale_factor, kernel_width,
|
236 |
+
antialiasing)
|
237 |
+
# process H dimension
|
238 |
+
# symmetric copying
|
239 |
+
img_aug = torch.FloatTensor(in_c, in_h + sym_len_hs + sym_len_he, in_w)
|
240 |
+
img_aug.narrow(1, sym_len_hs, in_h).copy_(image)
|
241 |
+
|
242 |
+
sym_patch = image[:, :sym_len_hs, :]
|
243 |
+
inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
|
244 |
+
sym_patch_inv = sym_patch.index_select(1, inv_idx)
|
245 |
+
img_aug.narrow(1, 0, sym_len_hs).copy_(sym_patch_inv)
|
246 |
+
|
247 |
+
sym_patch = image[:, -sym_len_he:, :]
|
248 |
+
inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
|
249 |
+
sym_patch_inv = sym_patch.index_select(1, inv_idx)
|
250 |
+
img_aug.narrow(1, sym_len_hs + in_h, sym_len_he).copy_(sym_patch_inv)
|
251 |
+
|
252 |
+
out_1 = torch.FloatTensor(in_c, out_h, in_w)
|
253 |
+
kernel_width = weights_h.size(1)
|
254 |
+
for i in range(out_h):
|
255 |
+
idx = int(indices_h[i][0])
|
256 |
+
for j in range(in_c):
|
257 |
+
out_1[j, i, :] = img_aug[j, idx:idx + kernel_width, :].transpose(0, 1).mv(weights_h[i])
|
258 |
+
|
259 |
+
# process W dimension
|
260 |
+
# symmetric copying
|
261 |
+
out_1_aug = torch.FloatTensor(in_c, out_h, in_w + sym_len_ws + sym_len_we)
|
262 |
+
out_1_aug.narrow(2, sym_len_ws, in_w).copy_(out_1)
|
263 |
+
|
264 |
+
sym_patch = out_1[:, :, :sym_len_ws]
|
265 |
+
inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long()
|
266 |
+
sym_patch_inv = sym_patch.index_select(2, inv_idx)
|
267 |
+
out_1_aug.narrow(2, 0, sym_len_ws).copy_(sym_patch_inv)
|
268 |
+
|
269 |
+
sym_patch = out_1[:, :, -sym_len_we:]
|
270 |
+
inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long()
|
271 |
+
sym_patch_inv = sym_patch.index_select(2, inv_idx)
|
272 |
+
out_1_aug.narrow(2, sym_len_ws + in_w, sym_len_we).copy_(sym_patch_inv)
|
273 |
+
|
274 |
+
out_2 = torch.FloatTensor(in_c, out_h, out_w)
|
275 |
+
kernel_width = weights_w.size(1)
|
276 |
+
for i in range(out_w):
|
277 |
+
idx = int(indices_w[i][0])
|
278 |
+
for j in range(in_c):
|
279 |
+
out_2[j, :, i] = out_1_aug[j, :, idx:idx + kernel_width].mv(weights_w[i])
|
280 |
+
|
281 |
+
if squeeze_flag:
|
282 |
+
out_2 = out_2.squeeze(0)
|
283 |
+
if numpy_type:
|
284 |
+
out_2 = out_2.numpy()
|
285 |
+
if not squeeze_flag:
|
286 |
+
out_2 = out_2.transpose(1, 2, 0)
|
287 |
+
|
288 |
+
return out_2
|
289 |
+
|
290 |
+
|
291 |
+
def expand_y(image: np.ndarray) -> np.ndarray:
|
292 |
+
"""Convert BGR channel to YCbCr format,
|
293 |
+
and expand Y channel data in YCbCr, from HW to HWC
|
294 |
+
|
295 |
+
Args:
|
296 |
+
image (np.ndarray): Y channel image data
|
297 |
+
|
298 |
+
Returns:
|
299 |
+
y_image (np.ndarray): Y-channel image data in HWC form
|
300 |
+
|
301 |
+
"""
|
302 |
+
# Normalize image data to [0, 1]
|
303 |
+
image = image.astype(np.float32) / 255.
|
304 |
+
|
305 |
+
# Convert BGR to YCbCr, and extract only Y channel
|
306 |
+
y_image = bgr_to_ycbcr(image, only_use_y_channel=True)
|
307 |
+
|
308 |
+
# Expand Y channel
|
309 |
+
y_image = y_image[..., None]
|
310 |
+
|
311 |
+
# Normalize the image data to [0, 255]
|
312 |
+
y_image = y_image.astype(np.float64) * 255.0
|
313 |
+
|
314 |
+
return y_image
|
315 |
+
|
316 |
+
|
317 |
+
def rgb_to_ycbcr(image: np.ndarray, only_use_y_channel: bool) -> np.ndarray:
|
318 |
+
"""Implementation of rgb2ycbcr function in Matlab under Python language
|
319 |
+
|
320 |
+
Args:
|
321 |
+
image (np.ndarray): Image input in RGB format.
|
322 |
+
only_use_y_channel (bool): Extract Y channel separately
|
323 |
+
|
324 |
+
Returns:
|
325 |
+
image (np.ndarray): YCbCr image array data
|
326 |
+
|
327 |
+
"""
|
328 |
+
if only_use_y_channel:
|
329 |
+
image = np.dot(image, [65.481, 128.553, 24.966]) + 16.0
|
330 |
+
else:
|
331 |
+
image = np.matmul(image, [[65.481, -37.797, 112.0], [128.553, -74.203, -93.786], [24.966, 112.0, -18.214]]) + [
|
332 |
+
16, 128, 128]
|
333 |
+
|
334 |
+
image /= 255.
|
335 |
+
image = image.astype(np.float32)
|
336 |
+
|
337 |
+
return image
|
338 |
+
|
339 |
+
|
340 |
+
def bgr_to_ycbcr(image: np.ndarray, only_use_y_channel: bool) -> np.ndarray:
|
341 |
+
"""Implementation of bgr2ycbcr function in Matlab under Python language.
|
342 |
+
|
343 |
+
Args:
|
344 |
+
image (np.ndarray): Image input in BGR format
|
345 |
+
only_use_y_channel (bool): Extract Y channel separately
|
346 |
+
|
347 |
+
Returns:
|
348 |
+
image (np.ndarray): YCbCr image array data
|
349 |
+
|
350 |
+
"""
|
351 |
+
if only_use_y_channel:
|
352 |
+
image = np.dot(image, [24.966, 128.553, 65.481]) + 16.0
|
353 |
+
else:
|
354 |
+
image = np.matmul(image, [[24.966, 112.0, -18.214], [128.553, -74.203, -93.786], [65.481, -37.797, 112.0]]) + [
|
355 |
+
16, 128, 128]
|
356 |
+
|
357 |
+
image /= 255.
|
358 |
+
image = image.astype(np.float32)
|
359 |
+
|
360 |
+
return image
|
361 |
+
|
362 |
+
|
363 |
+
def ycbcr_to_rgb(image: np.ndarray) -> np.ndarray:
|
364 |
+
"""Implementation of ycbcr2rgb function in Matlab under Python language.
|
365 |
+
|
366 |
+
Args:
|
367 |
+
image (np.ndarray): Image input in YCbCr format.
|
368 |
+
|
369 |
+
Returns:
|
370 |
+
image (np.ndarray): RGB image array data
|
371 |
+
|
372 |
+
"""
|
373 |
+
image_dtype = image.dtype
|
374 |
+
image *= 255.
|
375 |
+
|
376 |
+
image = np.matmul(image, [[0.00456621, 0.00456621, 0.00456621],
|
377 |
+
[0, -0.00153632, 0.00791071],
|
378 |
+
[0.00625893, -0.00318811, 0]]) * 255.0 + [-222.921, 135.576, -276.836]
|
379 |
+
|
380 |
+
image /= 255.
|
381 |
+
image = image.astype(image_dtype)
|
382 |
+
|
383 |
+
return image
|
384 |
+
|
385 |
+
|
386 |
+
def ycbcr_to_bgr(image: np.ndarray) -> np.ndarray:
|
387 |
+
"""Implementation of ycbcr2bgr function in Matlab under Python language.
|
388 |
+
|
389 |
+
Args:
|
390 |
+
image (np.ndarray): Image input in YCbCr format.
|
391 |
+
|
392 |
+
Returns:
|
393 |
+
image (np.ndarray): BGR image array data
|
394 |
+
|
395 |
+
"""
|
396 |
+
image_dtype = image.dtype
|
397 |
+
image *= 255.
|
398 |
+
|
399 |
+
image = np.matmul(image, [[0.00456621, 0.00456621, 0.00456621],
|
400 |
+
[0.00791071, -0.00153632, 0],
|
401 |
+
[0, -0.00318811, 0.00625893]]) * 255.0 + [-276.836, 135.576, -222.921]
|
402 |
+
|
403 |
+
image /= 255.
|
404 |
+
image = image.astype(image_dtype)
|
405 |
+
|
406 |
+
return image
|
407 |
+
|
408 |
+
|
409 |
+
def rgb_to_ycbcr_torch(tensor: Tensor, only_use_y_channel: bool) -> Tensor:
|
410 |
+
"""Implementation of rgb2ycbcr function in Matlab under PyTorch
|
411 |
+
|
412 |
+
References from:`https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion`
|
413 |
+
|
414 |
+
Args:
|
415 |
+
tensor (Tensor): Image data in PyTorch format
|
416 |
+
only_use_y_channel (bool): Extract only Y channel
|
417 |
+
|
418 |
+
Returns:
|
419 |
+
tensor (Tensor): YCbCr image data in PyTorch format
|
420 |
+
|
421 |
+
"""
|
422 |
+
if only_use_y_channel:
|
423 |
+
weight = Tensor([[65.481], [128.553], [24.966]]).to(tensor)
|
424 |
+
tensor = torch.matmul(tensor.permute(0, 2, 3, 1), weight).permute(0, 3, 1, 2) + 16.0
|
425 |
+
else:
|
426 |
+
weight = Tensor([[65.481, -37.797, 112.0],
|
427 |
+
[128.553, -74.203, -93.786],
|
428 |
+
[24.966, 112.0, -18.214]]).to(tensor)
|
429 |
+
bias = Tensor([16, 128, 128]).view(1, 3, 1, 1).to(tensor)
|
430 |
+
tensor = torch.matmul(tensor.permute(0, 2, 3, 1), weight).permute(0, 3, 1, 2) + bias
|
431 |
+
|
432 |
+
tensor /= 255.
|
433 |
+
|
434 |
+
return tensor
|
435 |
+
|
436 |
+
|
437 |
+
def bgr_to_ycbcr_torch(tensor: Tensor, only_use_y_channel: bool) -> Tensor:
|
438 |
+
"""Implementation of bgr2ycbcr function in Matlab under PyTorch
|
439 |
+
|
440 |
+
References from:`https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion`
|
441 |
+
|
442 |
+
Args:
|
443 |
+
tensor (Tensor): Image data in PyTorch format
|
444 |
+
only_use_y_channel (bool): Extract only Y channel
|
445 |
+
|
446 |
+
Returns:
|
447 |
+
tensor (Tensor): YCbCr image data in PyTorch format
|
448 |
+
|
449 |
+
"""
|
450 |
+
if only_use_y_channel:
|
451 |
+
weight = Tensor([[24.966], [128.553], [65.481]]).to(tensor)
|
452 |
+
tensor = torch.matmul(tensor.permute(0, 2, 3, 1), weight).permute(0, 3, 1, 2) + 16.0
|
453 |
+
else:
|
454 |
+
weight = Tensor([[24.966, 112.0, -18.214],
|
455 |
+
[128.553, -74.203, -93.786],
|
456 |
+
[65.481, -37.797, 112.0]]).to(tensor)
|
457 |
+
bias = Tensor([16, 128, 128]).view(1, 3, 1, 1).to(tensor)
|
458 |
+
tensor = torch.matmul(tensor.permute(0, 2, 3, 1), weight).permute(0, 3, 1, 2) + bias
|
459 |
+
|
460 |
+
tensor /= 255.
|
461 |
+
|
462 |
+
return tensor
|
463 |
+
|
464 |
+
|
465 |
+
def center_crop(image: np.ndarray, image_size: int) -> np.ndarray:
|
466 |
+
"""Crop small image patches from one image center area.
|
467 |
+
|
468 |
+
Args:
|
469 |
+
image (np.ndarray): The input image for `OpenCV.imread`.
|
470 |
+
image_size (int): The size of the captured image area.
|
471 |
+
|
472 |
+
Returns:
|
473 |
+
patch_image (np.ndarray): Small patch image
|
474 |
+
|
475 |
+
"""
|
476 |
+
image_height, image_width = image.shape[:2]
|
477 |
+
|
478 |
+
# Just need to find the top and left coordinates of the image
|
479 |
+
top = (image_height - image_size) // 2
|
480 |
+
left = (image_width - image_size) // 2
|
481 |
+
|
482 |
+
# Crop image patch
|
483 |
+
patch_image = image[top:top + image_size, left:left + image_size, ...]
|
484 |
+
|
485 |
+
return patch_image
|
486 |
+
|
487 |
+
|
488 |
+
def random_crop(image: np.ndarray, image_size: int) -> np.ndarray:
|
489 |
+
"""Crop small image patches from one image.
|
490 |
+
|
491 |
+
Args:
|
492 |
+
image (np.ndarray): The input image for `OpenCV.imread`.
|
493 |
+
image_size (int): The size of the captured image area.
|
494 |
+
|
495 |
+
Returns:
|
496 |
+
patch_image (np.ndarray): Small patch image
|
497 |
+
|
498 |
+
"""
|
499 |
+
image_height, image_width = image.shape[:2]
|
500 |
+
|
501 |
+
# Just need to find the top and left coordinates of the image
|
502 |
+
top = random.randint(0, image_height - image_size)
|
503 |
+
left = random.randint(0, image_width - image_size)
|
504 |
+
|
505 |
+
# Crop image patch
|
506 |
+
patch_image = image[top:top + image_size, left:left + image_size, ...]
|
507 |
+
|
508 |
+
return patch_image
|
509 |
+
|
510 |
+
|
511 |
+
def random_rotate(image,
|
512 |
+
angles: list,
|
513 |
+
center: tuple[int, int] = None,
|
514 |
+
scale_factor: float = 1.0) -> np.ndarray:
|
515 |
+
"""Rotate an image by a random angle
|
516 |
+
|
517 |
+
Args:
|
518 |
+
image (np.ndarray): Image read with OpenCV
|
519 |
+
angles (list): Rotation angle range
|
520 |
+
center (optional, tuple[int, int]): High resolution image selection center point. Default: ``None``
|
521 |
+
scale_factor (optional, float): scaling factor. Default: 1.0
|
522 |
+
|
523 |
+
Returns:
|
524 |
+
rotated_image (np.ndarray): image after rotation
|
525 |
+
|
526 |
+
"""
|
527 |
+
image_height, image_width = image.shape[:2]
|
528 |
+
|
529 |
+
if center is None:
|
530 |
+
center = (image_width // 2, image_height // 2)
|
531 |
+
|
532 |
+
# Random select specific angle
|
533 |
+
angle = random.choice(angles)
|
534 |
+
matrix = cv2.getRotationMatrix2D(center, angle, scale_factor)
|
535 |
+
rotated_image = cv2.warpAffine(image, matrix, (image_width, image_height))
|
536 |
+
|
537 |
+
return rotated_image
|
538 |
+
|
539 |
+
|
540 |
+
def random_horizontally_flip(image: np.ndarray, p: float = 0.5) -> np.ndarray:
|
541 |
+
"""Flip the image upside down randomly
|
542 |
+
|
543 |
+
Args:
|
544 |
+
image (np.ndarray): Image read with OpenCV
|
545 |
+
p (optional, float): Horizontally flip probability. Default: 0.5
|
546 |
+
|
547 |
+
Returns:
|
548 |
+
horizontally_flip_image (np.ndarray): image after horizontally flip
|
549 |
+
|
550 |
+
"""
|
551 |
+
if random.random() < p:
|
552 |
+
horizontally_flip_image = cv2.flip(image, 1)
|
553 |
+
else:
|
554 |
+
horizontally_flip_image = image
|
555 |
+
|
556 |
+
return horizontally_flip_image
|
557 |
+
|
558 |
+
|
559 |
+
def random_vertically_flip(image: np.ndarray, p: float = 0.5) -> np.ndarray:
|
560 |
+
"""Flip an image horizontally randomly
|
561 |
+
|
562 |
+
Args:
|
563 |
+
image (np.ndarray): Image read with OpenCV
|
564 |
+
p (optional, float): Vertically flip probability. Default: 0.5
|
565 |
+
|
566 |
+
Returns:
|
567 |
+
vertically_flip_image (np.ndarray): image after vertically flip
|
568 |
+
|
569 |
+
"""
|
570 |
+
if random.random() < p:
|
571 |
+
vertically_flip_image = cv2.flip(image, 0)
|
572 |
+
else:
|
573 |
+
vertically_flip_image = image
|
574 |
+
|
575 |
+
return vertically_flip_image
|
inference.py
ADDED
@@ -0,0 +1,101 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2022 Dakewe Biotech Corporation. All Rights Reserved.
|
2 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
3 |
+
# you may not use this file except in compliance with the License.
|
4 |
+
# You may obtain a copy of the License at
|
5 |
+
#
|
6 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
7 |
+
#
|
8 |
+
# Unless required by applicable law or agreed to in writing, software
|
9 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
10 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
11 |
+
# See the License for the specific language governing permissions and
|
12 |
+
# limitations under the License.
|
13 |
+
# ==============================================================================
|
14 |
+
import argparse
|
15 |
+
import os
|
16 |
+
|
17 |
+
import cv2
|
18 |
+
import torch
|
19 |
+
from torch import nn
|
20 |
+
|
21 |
+
import imgproc
|
22 |
+
import model
|
23 |
+
from utils import load_state_dict
|
24 |
+
|
25 |
+
model_names = sorted(
|
26 |
+
name for name in model.__dict__ if
|
27 |
+
name.islower() and not name.startswith("__") and callable(model.__dict__[name]))
|
28 |
+
|
29 |
+
|
30 |
+
def choice_device(device_type: str) -> torch.device:
|
31 |
+
# Select model processing equipment type
|
32 |
+
if device_type == "cuda":
|
33 |
+
device = torch.device("cuda", 0)
|
34 |
+
else:
|
35 |
+
device = torch.device("cpu")
|
36 |
+
return device
|
37 |
+
|
38 |
+
|
39 |
+
def build_model(model_arch_name: str, device: torch.device) -> nn.Module:
|
40 |
+
# Initialize the super-resolution model
|
41 |
+
sr_model = model.__dict__[model_arch_name](in_channels=3,
|
42 |
+
out_channels=3,
|
43 |
+
channels=64,
|
44 |
+
num_rcb=16)
|
45 |
+
sr_model = sr_model.to(device=device)
|
46 |
+
|
47 |
+
return sr_model
|
48 |
+
|
49 |
+
|
50 |
+
def main(args):
|
51 |
+
device = choice_device(args.device_type)
|
52 |
+
|
53 |
+
# Initialize the model
|
54 |
+
sr_model = build_model(args.model_arch_name, device)
|
55 |
+
print(f"Build `{args.model_arch_name}` model successfully.")
|
56 |
+
|
57 |
+
# Load model weights
|
58 |
+
sr_model = load_state_dict(sr_model, args.model_weights_path)
|
59 |
+
print(f"Load `{args.model_arch_name}` model weights `{os.path.abspath(args.model_weights_path)}` successfully.")
|
60 |
+
|
61 |
+
# Start the verification mode of the model.
|
62 |
+
sr_model.eval()
|
63 |
+
|
64 |
+
lr_tensor = imgproc.preprocess_one_image(args.inputs_path, device)
|
65 |
+
|
66 |
+
# Use the model to generate super-resolved images
|
67 |
+
with torch.no_grad():
|
68 |
+
sr_tensor = sr_model(lr_tensor)
|
69 |
+
|
70 |
+
# Save image
|
71 |
+
sr_image = imgproc.tensor_to_image(sr_tensor, False, False)
|
72 |
+
sr_image = cv2.cvtColor(sr_image, cv2.COLOR_RGB2BGR)
|
73 |
+
cv2.imwrite(args.output_path, sr_image)
|
74 |
+
|
75 |
+
print(f"SR image save to `{args.output_path}`")
|
76 |
+
|
77 |
+
|
78 |
+
if __name__ == "__main__":
|
79 |
+
parser = argparse.ArgumentParser(description="Using the model generator super-resolution images.")
|
80 |
+
parser.add_argument("--model_arch_name",
|
81 |
+
type=str,
|
82 |
+
default="srresnet_x4")
|
83 |
+
parser.add_argument("--inputs_path",
|
84 |
+
type=str,
|
85 |
+
default="./figure/comic_lr.png",
|
86 |
+
help="Low-resolution image path.")
|
87 |
+
parser.add_argument("--output_path",
|
88 |
+
type=str,
|
89 |
+
default="./figure/comic_sr.png",
|
90 |
+
help="Super-resolution image path.")
|
91 |
+
parser.add_argument("--model_weights_path",
|
92 |
+
type=str,
|
93 |
+
default="./results/pretrained_models/SRGAN_x4-ImageNet-8c4a7569.pth.tar",
|
94 |
+
help="Model weights file path.")
|
95 |
+
parser.add_argument("--device_type",
|
96 |
+
type=str,
|
97 |
+
default="cpu",
|
98 |
+
choices=["cpu", "cuda"])
|
99 |
+
args = parser.parse_args()
|
100 |
+
|
101 |
+
main(args)
|
model.py
ADDED
@@ -0,0 +1,251 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2022 Dakewe Biotech Corporation. All Rights Reserved.
|
2 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
3 |
+
# you may not use this file except in compliance with the License.
|
4 |
+
# You may obtain a copy of the License at
|
5 |
+
#
|
6 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
7 |
+
#
|
8 |
+
# Unless required by applicable law or agreed to in writing, software
|
9 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
10 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
11 |
+
# See the License for the specific language governing permissions and
|
12 |
+
# limitations under the License.
|
13 |
+
# ==============================================================================
|
14 |
+
import math
|
15 |
+
from typing import Any
|
16 |
+
|
17 |
+
import torch
|
18 |
+
from torch import Tensor
|
19 |
+
from torch import nn
|
20 |
+
from torch.nn import functional as F_torch
|
21 |
+
from torchvision import models
|
22 |
+
from torchvision import transforms
|
23 |
+
from torchvision.models.feature_extraction import create_feature_extractor
|
24 |
+
|
25 |
+
__all__ = [
|
26 |
+
"SRResNet", "Discriminator",
|
27 |
+
"srresnet_x4", "discriminator", "content_loss",
|
28 |
+
]
|
29 |
+
|
30 |
+
|
31 |
+
class SRResNet(nn.Module):
|
32 |
+
def __init__(
|
33 |
+
self,
|
34 |
+
in_channels: int,
|
35 |
+
out_channels: int,
|
36 |
+
channels: int,
|
37 |
+
num_rcb: int,
|
38 |
+
upscale_factor: int
|
39 |
+
) -> None:
|
40 |
+
super(SRResNet, self).__init__()
|
41 |
+
# Low frequency information extraction layer
|
42 |
+
self.conv1 = nn.Sequential(
|
43 |
+
nn.Conv2d(in_channels, channels, (9, 9), (1, 1), (4, 4)),
|
44 |
+
nn.PReLU(),
|
45 |
+
)
|
46 |
+
|
47 |
+
# High frequency information extraction block
|
48 |
+
trunk = []
|
49 |
+
for _ in range(num_rcb):
|
50 |
+
trunk.append(_ResidualConvBlock(channels))
|
51 |
+
self.trunk = nn.Sequential(*trunk)
|
52 |
+
|
53 |
+
# High-frequency information linear fusion layer
|
54 |
+
self.conv2 = nn.Sequential(
|
55 |
+
nn.Conv2d(channels, channels, (3, 3), (1, 1), (1, 1), bias=False),
|
56 |
+
nn.BatchNorm2d(channels),
|
57 |
+
)
|
58 |
+
|
59 |
+
# zoom block
|
60 |
+
upsampling = []
|
61 |
+
if upscale_factor == 2 or upscale_factor == 4 or upscale_factor == 8:
|
62 |
+
for _ in range(int(math.log(upscale_factor, 2))):
|
63 |
+
upsampling.append(_UpsampleBlock(channels, 2))
|
64 |
+
elif upscale_factor == 3:
|
65 |
+
upsampling.append(_UpsampleBlock(channels, 3))
|
66 |
+
self.upsampling = nn.Sequential(*upsampling)
|
67 |
+
|
68 |
+
# reconstruction block
|
69 |
+
self.conv3 = nn.Conv2d(channels, out_channels, (9, 9), (1, 1), (4, 4))
|
70 |
+
|
71 |
+
# Initialize neural network weights
|
72 |
+
self._initialize_weights()
|
73 |
+
|
74 |
+
def forward(self, x: Tensor) -> Tensor:
|
75 |
+
return self._forward_impl(x)
|
76 |
+
|
77 |
+
# Support torch.script function
|
78 |
+
def _forward_impl(self, x: Tensor) -> Tensor:
|
79 |
+
out1 = self.conv1(x)
|
80 |
+
out = self.trunk(out1)
|
81 |
+
out2 = self.conv2(out)
|
82 |
+
out = torch.add(out1, out2)
|
83 |
+
out = self.upsampling(out)
|
84 |
+
out = self.conv3(out)
|
85 |
+
|
86 |
+
out = torch.clamp_(out, 0.0, 1.0)
|
87 |
+
|
88 |
+
return out
|
89 |
+
|
90 |
+
def _initialize_weights(self) -> None:
|
91 |
+
for module in self.modules():
|
92 |
+
if isinstance(module, nn.Conv2d):
|
93 |
+
nn.init.kaiming_normal_(module.weight)
|
94 |
+
if module.bias is not None:
|
95 |
+
nn.init.constant_(module.bias, 0)
|
96 |
+
elif isinstance(module, nn.BatchNorm2d):
|
97 |
+
nn.init.constant_(module.weight, 1)
|
98 |
+
|
99 |
+
|
100 |
+
class Discriminator(nn.Module):
|
101 |
+
def __init__(self) -> None:
|
102 |
+
super(Discriminator, self).__init__()
|
103 |
+
self.features = nn.Sequential(
|
104 |
+
# input size. (3) x 96 x 96
|
105 |
+
nn.Conv2d(3, 64, (3, 3), (1, 1), (1, 1), bias=True),
|
106 |
+
nn.LeakyReLU(0.2, True),
|
107 |
+
# state size. (64) x 48 x 48
|
108 |
+
nn.Conv2d(64, 64, (3, 3), (2, 2), (1, 1), bias=False),
|
109 |
+
nn.BatchNorm2d(64),
|
110 |
+
nn.LeakyReLU(0.2, True),
|
111 |
+
nn.Conv2d(64, 128, (3, 3), (1, 1), (1, 1), bias=False),
|
112 |
+
nn.BatchNorm2d(128),
|
113 |
+
nn.LeakyReLU(0.2, True),
|
114 |
+
# state size. (128) x 24 x 24
|
115 |
+
nn.Conv2d(128, 128, (3, 3), (2, 2), (1, 1), bias=False),
|
116 |
+
nn.BatchNorm2d(128),
|
117 |
+
nn.LeakyReLU(0.2, True),
|
118 |
+
nn.Conv2d(128, 256, (3, 3), (1, 1), (1, 1), bias=False),
|
119 |
+
nn.BatchNorm2d(256),
|
120 |
+
nn.LeakyReLU(0.2, True),
|
121 |
+
# state size. (256) x 12 x 12
|
122 |
+
nn.Conv2d(256, 256, (3, 3), (2, 2), (1, 1), bias=False),
|
123 |
+
nn.BatchNorm2d(256),
|
124 |
+
nn.LeakyReLU(0.2, True),
|
125 |
+
nn.Conv2d(256, 512, (3, 3), (1, 1), (1, 1), bias=False),
|
126 |
+
nn.BatchNorm2d(512),
|
127 |
+
nn.LeakyReLU(0.2, True),
|
128 |
+
# state size. (512) x 6 x 6
|
129 |
+
nn.Conv2d(512, 512, (3, 3), (2, 2), (1, 1), bias=False),
|
130 |
+
nn.BatchNorm2d(512),
|
131 |
+
nn.LeakyReLU(0.2, True),
|
132 |
+
)
|
133 |
+
|
134 |
+
self.classifier = nn.Sequential(
|
135 |
+
nn.Linear(512 * 6 * 6, 1024),
|
136 |
+
nn.LeakyReLU(0.2, True),
|
137 |
+
nn.Linear(1024, 1),
|
138 |
+
)
|
139 |
+
|
140 |
+
def forward(self, x: Tensor) -> Tensor:
|
141 |
+
# Input image size must equal 96
|
142 |
+
assert x.shape[2] == 96 and x.shape[3] == 96, "Image shape must equal 96x96"
|
143 |
+
|
144 |
+
out = self.features(x)
|
145 |
+
out = torch.flatten(out, 1)
|
146 |
+
out = self.classifier(out)
|
147 |
+
|
148 |
+
return out
|
149 |
+
|
150 |
+
|
151 |
+
class _ResidualConvBlock(nn.Module):
|
152 |
+
def __init__(self, channels: int) -> None:
|
153 |
+
super(_ResidualConvBlock, self).__init__()
|
154 |
+
self.rcb = nn.Sequential(
|
155 |
+
nn.Conv2d(channels, channels, (3, 3), (1, 1), (1, 1), bias=False),
|
156 |
+
nn.BatchNorm2d(channels),
|
157 |
+
nn.PReLU(),
|
158 |
+
nn.Conv2d(channels, channels, (3, 3), (1, 1), (1, 1), bias=False),
|
159 |
+
nn.BatchNorm2d(channels),
|
160 |
+
)
|
161 |
+
|
162 |
+
def forward(self, x: Tensor) -> Tensor:
|
163 |
+
identity = x
|
164 |
+
|
165 |
+
out = self.rcb(x)
|
166 |
+
|
167 |
+
out = torch.add(out, identity)
|
168 |
+
|
169 |
+
return out
|
170 |
+
|
171 |
+
|
172 |
+
class _UpsampleBlock(nn.Module):
|
173 |
+
def __init__(self, channels: int, upscale_factor: int) -> None:
|
174 |
+
super(_UpsampleBlock, self).__init__()
|
175 |
+
self.upsample_block = nn.Sequential(
|
176 |
+
nn.Conv2d(channels, channels * upscale_factor * upscale_factor, (3, 3), (1, 1), (1, 1)),
|
177 |
+
nn.PixelShuffle(2),
|
178 |
+
nn.PReLU(),
|
179 |
+
)
|
180 |
+
|
181 |
+
def forward(self, x: Tensor) -> Tensor:
|
182 |
+
out = self.upsample_block(x)
|
183 |
+
|
184 |
+
return out
|
185 |
+
|
186 |
+
|
187 |
+
class _ContentLoss(nn.Module):
|
188 |
+
"""Constructs a content loss function based on the VGG19 network.
|
189 |
+
Using high-level feature mapping layers from the latter layers will focus more on the texture content of the image.
|
190 |
+
|
191 |
+
Paper reference list:
|
192 |
+
-`Photo-Realistic Single Image Super-Resolution Using a Generative Adversarial Network <https://arxiv.org/pdf/1609.04802.pdf>` paper.
|
193 |
+
-`ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks <https://arxiv.org/pdf/1809.00219.pdf>` paper.
|
194 |
+
-`Perceptual Extreme Super Resolution Network with Receptive Field Block <https://arxiv.org/pdf/2005.12597.pdf>` paper.
|
195 |
+
|
196 |
+
"""
|
197 |
+
|
198 |
+
def __init__(
|
199 |
+
self,
|
200 |
+
feature_model_extractor_node: str,
|
201 |
+
feature_model_normalize_mean: list,
|
202 |
+
feature_model_normalize_std: list
|
203 |
+
) -> None:
|
204 |
+
super(_ContentLoss, self).__init__()
|
205 |
+
# Get the name of the specified feature extraction node
|
206 |
+
self.feature_model_extractor_node = feature_model_extractor_node
|
207 |
+
# Load the VGG19 model trained on the ImageNet dataset.
|
208 |
+
model = models.vgg19(weights=models.VGG19_Weights.IMAGENET1K_V1)
|
209 |
+
# Extract the thirty-sixth layer output in the VGG19 model as the content loss.
|
210 |
+
self.feature_extractor = create_feature_extractor(model, [feature_model_extractor_node])
|
211 |
+
# set to validation mode
|
212 |
+
self.feature_extractor.eval()
|
213 |
+
|
214 |
+
# The preprocessing method of the input data.
|
215 |
+
# This is the VGG model preprocessing method of the ImageNet dataset.
|
216 |
+
self.normalize = transforms.Normalize(feature_model_normalize_mean, feature_model_normalize_std)
|
217 |
+
|
218 |
+
# Freeze model parameters.
|
219 |
+
for model_parameters in self.feature_extractor.parameters():
|
220 |
+
model_parameters.requires_grad = False
|
221 |
+
|
222 |
+
def forward(self, sr_tensor: Tensor, gt_tensor: Tensor) -> Tensor:
|
223 |
+
# Standardized operations
|
224 |
+
sr_tensor = self.normalize(sr_tensor)
|
225 |
+
gt_tensor = self.normalize(gt_tensor)
|
226 |
+
|
227 |
+
sr_feature = self.feature_extractor(sr_tensor)[self.feature_model_extractor_node]
|
228 |
+
gt_feature = self.feature_extractor(gt_tensor)[self.feature_model_extractor_node]
|
229 |
+
|
230 |
+
# Find the feature map difference between the two images
|
231 |
+
loss = F_torch.mse_loss(sr_feature, gt_feature)
|
232 |
+
|
233 |
+
return loss
|
234 |
+
|
235 |
+
|
236 |
+
def srresnet_x4(**kwargs: Any) -> SRResNet:
|
237 |
+
model = SRResNet(upscale_factor=4, **kwargs)
|
238 |
+
|
239 |
+
return model
|
240 |
+
|
241 |
+
|
242 |
+
def discriminator() -> Discriminator:
|
243 |
+
model = Discriminator()
|
244 |
+
|
245 |
+
return model
|
246 |
+
|
247 |
+
|
248 |
+
def content_loss(**kwargs: Any) -> _ContentLoss:
|
249 |
+
content_loss = _ContentLoss(**kwargs)
|
250 |
+
|
251 |
+
return content_loss
|
utils.py
ADDED
@@ -0,0 +1,168 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2022 Dakewe Biotech Corporation. All Rights Reserved.
|
2 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
3 |
+
# you may not use this file except in compliance with the License.
|
4 |
+
# You may obtain a copy of the License at
|
5 |
+
#
|
6 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
7 |
+
#
|
8 |
+
# Unless required by applicable law or agreed to in writing, software
|
9 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
10 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
11 |
+
# See the License for the specific language governing permissions and
|
12 |
+
# limitations under the License.
|
13 |
+
# ==============================================================================
|
14 |
+
import os
|
15 |
+
import shutil
|
16 |
+
from enum import Enum
|
17 |
+
from typing import Any
|
18 |
+
|
19 |
+
import torch
|
20 |
+
from torch import nn
|
21 |
+
from torch.nn import Module
|
22 |
+
from torch.optim import Optimizer
|
23 |
+
|
24 |
+
__all__ = [
|
25 |
+
"load_state_dict", "make_directory", "save_checkpoint",
|
26 |
+
"Summary", "AverageMeter", "ProgressMeter"
|
27 |
+
]
|
28 |
+
|
29 |
+
|
30 |
+
def load_state_dict(
|
31 |
+
model: nn.Module,
|
32 |
+
model_weights_path: str,
|
33 |
+
ema_model: nn.Module = None,
|
34 |
+
optimizer: torch.optim.Optimizer = None,
|
35 |
+
scheduler: torch.optim.lr_scheduler = None,
|
36 |
+
load_mode: str = None,
|
37 |
+
) -> tuple[Module, Module, Any, Any, Any, Optimizer | None, Any] | tuple[Module, Any, Any, Any, Optimizer | None, Any] | Module:
|
38 |
+
# Load model weights
|
39 |
+
checkpoint = torch.load(model_weights_path, map_location=lambda storage, loc: storage)
|
40 |
+
|
41 |
+
if load_mode == "resume":
|
42 |
+
# Restore the parameters in the training node to this point
|
43 |
+
start_epoch = checkpoint["epoch"]
|
44 |
+
best_psnr = checkpoint["best_psnr"]
|
45 |
+
best_ssim = checkpoint["best_ssim"]
|
46 |
+
# Load model state dict. Extract the fitted model weights
|
47 |
+
model_state_dict = model.state_dict()
|
48 |
+
state_dict = {k: v for k, v in checkpoint["state_dict"].items() if k in model_state_dict.keys()}
|
49 |
+
# Overwrite the model weights to the current model (base model)
|
50 |
+
model_state_dict.update(state_dict)
|
51 |
+
model.load_state_dict(model_state_dict)
|
52 |
+
# Load the optimizer model
|
53 |
+
optimizer.load_state_dict(checkpoint["optimizer"])
|
54 |
+
|
55 |
+
if scheduler is not None:
|
56 |
+
# Load the scheduler model
|
57 |
+
scheduler.load_state_dict(checkpoint["scheduler"])
|
58 |
+
|
59 |
+
if ema_model is not None:
|
60 |
+
# Load ema model state dict. Extract the fitted model weights
|
61 |
+
ema_model_state_dict = ema_model.state_dict()
|
62 |
+
ema_state_dict = {k: v for k, v in checkpoint["ema_state_dict"].items() if k in ema_model_state_dict.keys()}
|
63 |
+
# Overwrite the model weights to the current model (ema model)
|
64 |
+
ema_model_state_dict.update(ema_state_dict)
|
65 |
+
ema_model.load_state_dict(ema_model_state_dict)
|
66 |
+
|
67 |
+
return model, ema_model, start_epoch, best_psnr, best_ssim, optimizer, scheduler
|
68 |
+
else:
|
69 |
+
# Load model state dict. Extract the fitted model weights
|
70 |
+
model_state_dict = model.state_dict()
|
71 |
+
state_dict = {k: v for k, v in checkpoint["state_dict"].items() if
|
72 |
+
k in model_state_dict.keys() and v.size() == model_state_dict[k].size()}
|
73 |
+
# Overwrite the model weights to the current model
|
74 |
+
model_state_dict.update(state_dict)
|
75 |
+
model.load_state_dict(model_state_dict)
|
76 |
+
|
77 |
+
return model
|
78 |
+
|
79 |
+
|
80 |
+
def make_directory(dir_path: str) -> None:
|
81 |
+
if not os.path.exists(dir_path):
|
82 |
+
os.makedirs(dir_path)
|
83 |
+
|
84 |
+
|
85 |
+
def save_checkpoint(
|
86 |
+
state_dict: dict,
|
87 |
+
file_name: str,
|
88 |
+
samples_dir: str,
|
89 |
+
results_dir: str,
|
90 |
+
best_file_name: str,
|
91 |
+
last_file_name: str,
|
92 |
+
is_best: bool = False,
|
93 |
+
is_last: bool = False,
|
94 |
+
) -> None:
|
95 |
+
checkpoint_path = os.path.join(samples_dir, file_name)
|
96 |
+
torch.save(state_dict, checkpoint_path)
|
97 |
+
|
98 |
+
if is_best:
|
99 |
+
shutil.copyfile(checkpoint_path, os.path.join(results_dir, best_file_name))
|
100 |
+
if is_last:
|
101 |
+
shutil.copyfile(checkpoint_path, os.path.join(results_dir, last_file_name))
|
102 |
+
|
103 |
+
|
104 |
+
class Summary(Enum):
|
105 |
+
NONE = 0
|
106 |
+
AVERAGE = 1
|
107 |
+
SUM = 2
|
108 |
+
COUNT = 3
|
109 |
+
|
110 |
+
|
111 |
+
class AverageMeter(object):
|
112 |
+
def __init__(self, name, fmt=":f", summary_type=Summary.AVERAGE):
|
113 |
+
self.name = name
|
114 |
+
self.fmt = fmt
|
115 |
+
self.summary_type = summary_type
|
116 |
+
self.reset()
|
117 |
+
|
118 |
+
def reset(self):
|
119 |
+
self.val = 0
|
120 |
+
self.avg = 0
|
121 |
+
self.sum = 0
|
122 |
+
self.count = 0
|
123 |
+
|
124 |
+
def update(self, val, n=1):
|
125 |
+
self.val = val
|
126 |
+
self.sum += val * n
|
127 |
+
self.count += n
|
128 |
+
self.avg = self.sum / self.count
|
129 |
+
|
130 |
+
def __str__(self):
|
131 |
+
fmtstr = "{name} {val" + self.fmt + "} ({avg" + self.fmt + "})"
|
132 |
+
return fmtstr.format(**self.__dict__)
|
133 |
+
|
134 |
+
def summary(self):
|
135 |
+
if self.summary_type is Summary.NONE:
|
136 |
+
fmtstr = ""
|
137 |
+
elif self.summary_type is Summary.AVERAGE:
|
138 |
+
fmtstr = "{name} {avg:.2f}"
|
139 |
+
elif self.summary_type is Summary.SUM:
|
140 |
+
fmtstr = "{name} {sum:.2f}"
|
141 |
+
elif self.summary_type is Summary.COUNT:
|
142 |
+
fmtstr = "{name} {count:.2f}"
|
143 |
+
else:
|
144 |
+
raise ValueError(f"Invalid summary type {self.summary_type}")
|
145 |
+
|
146 |
+
return fmtstr.format(**self.__dict__)
|
147 |
+
|
148 |
+
|
149 |
+
class ProgressMeter(object):
|
150 |
+
def __init__(self, num_batches, meters, prefix=""):
|
151 |
+
self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
|
152 |
+
self.meters = meters
|
153 |
+
self.prefix = prefix
|
154 |
+
|
155 |
+
def display(self, batch):
|
156 |
+
entries = [self.prefix + self.batch_fmtstr.format(batch)]
|
157 |
+
entries += [str(meter) for meter in self.meters]
|
158 |
+
print("\t".join(entries))
|
159 |
+
|
160 |
+
def display_summary(self):
|
161 |
+
entries = [" *"]
|
162 |
+
entries += [meter.summary() for meter in self.meters]
|
163 |
+
print(" ".join(entries))
|
164 |
+
|
165 |
+
def _get_batch_fmtstr(self, num_batches):
|
166 |
+
num_digits = len(str(num_batches // 1))
|
167 |
+
fmt = "{:" + str(num_digits) + "d}"
|
168 |
+
return "[" + fmt + "/" + fmt.format(num_batches) + "]"
|
weights/SRGAN_x4-ImageNet-8c4a7569.pth.tar
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:7c5431b5921e1509190aed6aca02c7d5838f4805e8ae6f9fa08c140260b6a2a3
|
3 |
+
size 6285796
|