Thibaud Cheruy committed
Commit 92d45d2 · 1 Parent(s): 56a1b0b

New: Add SRGAN Space

.gitattributes CHANGED
@@ -32,3 +32,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ *.pth.tar filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,2 @@
+ venv/
+ .idea/
README.md CHANGED
@@ -8,6 +8,7 @@ sdk_version: 3.16.1
  app_file: app.py
  pinned: false
  license: apache-2.0
+ python_version: 3.10.3
  ---

  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
__pycache__/imgproc.cpython-310.pyc ADDED
Binary file (13.5 kB).
 
__pycache__/inference.cpython-310.pyc ADDED
Binary file (2.47 kB).
 
__pycache__/model.cpython-310.pyc ADDED
Binary file (6.9 kB).
 
__pycache__/utils.cpython-310.pyc ADDED
Binary file (5.46 kB).
 
app.py ADDED
@@ -0,0 +1,111 @@
+ import gradio as gr
+ import torch
+ import cv2
+
+ import imgproc
+ from imgproc import image_to_tensor
+ from inference import choice_device, build_model
+ from utils import load_state_dict
+
+ model = "srresnet_x4"
+
+ device = choice_device("cpu")
+
+ # Initialize the model
+ sr_model = build_model(model, device)
+ print(f"Built `{model}` model successfully.")
+
+ # Load model weights
+ sr_model = load_state_dict(sr_model, "weights/SRGAN_x4-ImageNet-8c4a7569.pth.tar")
+ print(f"Loaded `{model}` model weights successfully.")
+
+ # Put the model in evaluation (inference) mode
+ sr_model.eval()
+
+ def downscale(image):
+     (height, width, colors) = image.shape
+
+     new_height = int(60 * height / width)
+
+     return cv2.resize(image, (60, new_height), interpolation=cv2.INTER_AREA)
+
+
+ def preprocess(image):
+     image = image / 255.0
+
+     # Convert the image to a PyTorch tensor (NCHW)
+     tensor = image_to_tensor(image, False, False).unsqueeze_(0)
+
+     # Move the tensor to the target device in channels-last memory format
+     tensor = tensor.to(device=device, memory_format=torch.channels_last, non_blocking=True)
+
+     return tensor
+
+ def processHighRes(image):
+     if image is None:
+         raise gr.Error("Please provide an image")
+     downscaled = downscale(image)
+     lr_tensor = preprocess(downscaled)
+
+     # Use the model to generate a super-resolved image
+     with torch.no_grad():
+         sr_tensor = sr_model(lr_tensor)
+
+     # Convert the output tensor back to an image
+     sr_image = imgproc.tensor_to_image(sr_tensor, False, False)
+
+     return [downscaled, sr_image]
+
+ def processLowRes(image):
+     if image is None:
+         raise gr.Error("Please provide an image")
+
+     (height, width, colors) = image.shape
+
+     if height > 150 or width > 150:
+         raise gr.Error("Image is too big (must be at most 150x150 pixels)")
+
+     lr_tensor = preprocess(image)
+
+     # Use the model to generate a super-resolved image
+     with torch.no_grad():
+         sr_tensor = sr_model(lr_tensor)
+
+     # Convert the output tensor back to an image
+     sr_image = imgproc.tensor_to_image(sr_tensor, False, False)
+
+     return [sr_image]
+
+ description = """<p style='text-align: center'> <a href='https://arxiv.org/abs/1609.04802' target='_blank'>Paper</a> | <a href='https://github.com/Lornatang/SRGAN-PyTorch' target='_blank'>GitHub</a></p>"""
+
+ with gr.Blocks() as demo:
+     gr.Markdown("# **<p align='center'>SRGAN: Photo-Realistic Single Image Super-Resolution Using a Generative Adversarial Network</p>**")
+     gr.Markdown(description)
+
+     with gr.Tab("From high res"):
+         high_res_input = gr.Image(label="High-res source image", show_label=True)
+         with gr.Row():
+             low_res_output = gr.Image(label="Low-res image")
+             srgan_output = gr.Image(label="SRGAN Output")
+         high_res_button = gr.Button("Process")
+
+     with gr.Tab("From low res"):
+         low_res_input = gr.Image(label="Low-res source image", show_label=True)
+         srgan_upscale = gr.Image(label="SRGAN Output")
+         low_res_button = gr.Button("Process")
+
+     gr.Examples(
+         examples=["examples/bird.png", "examples/butterfly.png", "examples/comic.png", "examples/gray.png",
+                   "examples/man.png"],
+         inputs=[high_res_input],
+         outputs=[low_res_output, srgan_output],
+         fn=processHighRes
+     )
+
+     high_res_button.click(processHighRes, inputs=[high_res_input], outputs=[low_res_output, srgan_output])
+     low_res_button.click(processLowRes, inputs=[low_res_input], outputs=[srgan_upscale])
+
+     gr.Markdown("<p style='text-align: center'>Made for the 2022-2023 Grenoble-INP Phelma Image Analysis course by Thibaud CHERUY, Clément DEBUY & Yassine El Khanoussi.</p>")
+
+
+ demo.launch()
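A minimal smoke test of this pipeline outside the Gradio UI, as a sketch only: it is not part of the committed app.py, assumes it is placed just above demo.launch() so the functions defined above are in scope, and uses one of the bundled example images.

    # Hypothetical local smoke test -- not part of the Space code.
    test_image = cv2.cvtColor(cv2.imread("examples/bird.png"), cv2.COLOR_BGR2RGB)  # Gradio hands the callbacks RGB arrays
    low_res, upscaled = processHighRes(test_image)
    print(low_res.shape, upscaled.shape)  # the SRGAN output is 4x the low-res size in each spatial dimension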
examples/bird.png ADDED
examples/butterfly.png ADDED
examples/comic.png ADDED
examples/gray.png ADDED
examples/man.png ADDED
imgproc.py ADDED
@@ -0,0 +1,575 @@
1
+ # Copyright 2022 Dakewe Biotech Corporation. All Rights Reserved.
2
+ # Licensed under the Apache License, Version 2.0 (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ #
6
+ # http://www.apache.org/licenses/LICENSE-2.0
7
+ #
8
+ # Unless required by applicable law or agreed to in writing, software
9
+ # distributed under the License is distributed on an "AS IS" BASIS,
10
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11
+ # See the License for the specific language governing permissions and
12
+ # limitations under the License.
13
+ # ==============================================================================
14
+ import math
15
+ import random
16
+ from typing import Any
17
+
18
+ import cv2
19
+ import numpy as np
20
+ import torch
21
+ from numpy import ndarray
22
+ from torch import Tensor
23
+
24
+ __all__ = [
25
+ "image_to_tensor", "tensor_to_image",
26
+ "image_resize", "preprocess_one_image",
27
+ "expand_y", "rgb_to_ycbcr", "bgr_to_ycbcr", "ycbcr_to_bgr", "ycbcr_to_rgb",
28
+ "rgb_to_ycbcr_torch", "bgr_to_ycbcr_torch",
29
+ "center_crop", "random_crop", "random_rotate", "random_vertically_flip", "random_horizontally_flip",
30
+ ]
31
+
32
+
33
+ # Code reference `https://github.com/xinntao/BasicSR/blob/master/basicsr/utils/matlab_functions.py`
34
+ def _cubic(x: Any) -> Any:
35
+ """Implementation of `cubic` function in Matlab under Python language.
36
+
37
+ Args:
38
+ x: Element vector.
39
+
40
+ Returns:
41
+ Bicubic interpolation
42
+
43
+ """
44
+ absx = torch.abs(x)
45
+ absx2 = absx ** 2
46
+ absx3 = absx ** 3
47
+ return (1.5 * absx3 - 2.5 * absx2 + 1) * ((absx <= 1).type_as(absx)) + (
48
+ -0.5 * absx3 + 2.5 * absx2 - 4 * absx + 2) * (
49
+ ((absx > 1) * (absx <= 2)).type_as(absx))
50
+
51
+
52
+ # Code reference `https://github.com/xinntao/BasicSR/blob/master/basicsr/utils/matlab_functions.py`
53
+ def _calculate_weights_indices(in_length: int,
54
+ out_length: int,
55
+ scale: float,
56
+ kernel_width: int,
57
+ antialiasing: bool) -> [np.ndarray, np.ndarray, int, int]:
58
+ """Implementation of `calculate_weights_indices` function in Matlab under Python language.
59
+
60
+ Args:
61
+ in_length (int): Input length.
62
+ out_length (int): Output length.
63
+ scale (float): Scale factor.
64
+ kernel_width (int): Kernel width.
65
+ antialiasing (bool): Whether to apply antialiasing when down-sampling operations.
66
+ Caution: Bicubic down-sampling in PIL uses antialiasing by default.
67
+
68
+ Returns:
69
+ weights, indices, sym_len_s, sym_len_e
70
+
71
+ """
72
+ if (scale < 1) and antialiasing:
73
+ # Use a modified kernel (larger kernel width) to simultaneously
74
+ # interpolate and antialiasing
75
+ kernel_width = kernel_width / scale
76
+
77
+ # Output-space coordinates
78
+ x = torch.linspace(1, out_length, out_length)
79
+
80
+ # Input-space coordinates. Calculate the inverse mapping such that 0.5
81
+ # in output space maps to 0.5 in input space, and 0.5 + scale in output
82
+ # space maps to 1.5 in input space.
83
+ u = x / scale + 0.5 * (1 - 1 / scale)
84
+
85
+ # What is the left-most pixel that can be involved in the computation?
86
+ left = torch.floor(u - kernel_width / 2)
87
+
88
+ # What is the maximum number of pixels that can be involved in the
89
+ # computation? Note: it's OK to use an extra pixel here; if the
90
+ # corresponding weights are all zero, it will be eliminated at the end
91
+ # of this function.
92
+ p = math.ceil(kernel_width) + 2
93
+
94
+ # The indices of the input pixels involved in computing the k-th output
95
+ # pixel are in row k of the indices matrix.
96
+ indices = left.view(out_length, 1).expand(out_length, p) + torch.linspace(0, p - 1, p).view(1, p).expand(
97
+ out_length, p)
98
+
99
+ # The weights used to compute the k-th output pixel are in row k of the
100
+ # weights matrix.
101
+ distance_to_center = u.view(out_length, 1).expand(out_length, p) - indices
102
+
103
+ # apply cubic kernel
104
+ if (scale < 1) and antialiasing:
105
+ weights = scale * _cubic(distance_to_center * scale)
106
+ else:
107
+ weights = _cubic(distance_to_center)
108
+
109
+ # Normalize the weights matrix so that each row sums to 1.
110
+ weights_sum = torch.sum(weights, 1).view(out_length, 1)
111
+ weights = weights / weights_sum.expand(out_length, p)
112
+
113
+ # If a column in weights is all zero, get rid of it. only consider the
114
+ # first and last column.
115
+ weights_zero_tmp = torch.sum((weights == 0), 0)
116
+ if not math.isclose(weights_zero_tmp[0], 0, rel_tol=1e-6):
117
+ indices = indices.narrow(1, 1, p - 2)
118
+ weights = weights.narrow(1, 1, p - 2)
119
+ if not math.isclose(weights_zero_tmp[-1], 0, rel_tol=1e-6):
120
+ indices = indices.narrow(1, 0, p - 2)
121
+ weights = weights.narrow(1, 0, p - 2)
122
+ weights = weights.contiguous()
123
+ indices = indices.contiguous()
124
+ sym_len_s = -indices.min() + 1
125
+ sym_len_e = indices.max() - in_length
126
+ indices = indices + sym_len_s - 1
127
+ return weights, indices, int(sym_len_s), int(sym_len_e)
128
+
129
+
130
+ def image_to_tensor(image: ndarray, range_norm: bool, half: bool) -> Tensor:
131
+ """Convert the image data type to the Tensor (NCWH) data type supported by PyTorch
132
+
133
+ Args:
134
+ image (np.ndarray): The image data read by ``OpenCV.imread``, the data range is [0,255] or [0, 1]
135
+ range_norm (bool): Scale [0, 1] data to between [-1, 1]
136
+ half (bool): Whether to convert torch.float32 similarly to torch.half type
137
+
138
+ Returns:
139
+ tensor (Tensor): Data types supported by PyTorch
140
+
141
+ Examples:
142
+ >>> example_image = cv2.imread("lr_image.bmp")
143
+ >>> example_tensor = image_to_tensor(example_image, range_norm=True, half=False)
144
+
145
+ """
146
+ # Convert image data type to Tensor data type
147
+ tensor = torch.from_numpy(np.ascontiguousarray(image)).permute(2, 0, 1).float()
148
+
149
+ # Scale the image data from [0, 1] to [-1, 1]
150
+ if range_norm:
151
+ tensor = tensor.mul(2.0).sub(1.0)
152
+
153
+ # Convert torch.float32 image data type to torch.half image data type
154
+ if half:
155
+ tensor = tensor.half()
156
+
157
+ return tensor
158
+
159
+
160
+ def tensor_to_image(tensor: Tensor, range_norm: bool, half: bool) -> Any:
161
+ """Convert the Tensor(NCWH) data type supported by PyTorch to the np.ndarray(WHC) image data type
162
+
163
+ Args:
164
+ tensor (Tensor): Data types supported by PyTorch (NCHW), the data range is [0, 1]
165
+ range_norm (bool): Scale [-1, 1] data to between [0, 1]
166
+ half (bool): Whether to convert torch.float32 similarly to torch.half type.
167
+
168
+ Returns:
169
+ image (np.ndarray): Data types supported by PIL or OpenCV
170
+
171
+ Examples:
172
+ >>> example_tensor = torch.randn(1, 3, 64, 64)
+ >>> example_image = tensor_to_image(example_tensor, range_norm=False, half=False)
174
+
175
+ """
176
+ if range_norm:
177
+ tensor = tensor.add(1.0).div(2.0)
178
+ if half:
179
+ tensor = tensor.half()
180
+
181
+ image = tensor.squeeze(0).permute(1, 2, 0).mul(255).clamp(0, 255).cpu().numpy().astype("uint8")
182
+
183
+ return image
184
+
185
+
186
+ def preprocess_one_image(image_path: str, device: torch.device) -> Tensor:
187
+ image = cv2.imread(image_path).astype(np.float32) / 255.0
188
+
189
+ # BGR to RGB
190
+ image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
191
+
192
+ # Convert image data to pytorch format data
193
+ tensor = image_to_tensor(image, False, False).unsqueeze_(0)
194
+
195
+ # Transfer tensor channel image format data to CUDA device
196
+ tensor = tensor.to(device=device, memory_format=torch.channels_last, non_blocking=True)
197
+
198
+ return tensor
199
+
200
+
201
+ # Code reference `https://github.com/xinntao/BasicSR/blob/master/basicsr/utils/matlab_functions.py`
202
+ def image_resize(image: Any, scale_factor: float, antialiasing: bool = True) -> Any:
203
+ """Implementation of `imresize` function in Matlab under Python language.
204
+
205
+ Args:
206
+ image: The input image.
207
+ scale_factor (float): Scale factor. The same scale applies for both height and width.
208
+ antialiasing (bool): Whether to apply antialiasing when down-sampling operations.
209
+ Caution: Bicubic down-sampling in `PIL` uses antialiasing by default. Default: ``True``.
210
+
211
+ Returns:
212
+ out_2 (np.ndarray): Output image with shape (c, h, w), [0, 1] range, w/o round
213
+
214
+ """
215
+ squeeze_flag = False
216
+ if type(image).__module__ == np.__name__: # numpy type
217
+ numpy_type = True
218
+ if image.ndim == 2:
219
+ image = image[:, :, None]
220
+ squeeze_flag = True
221
+ image = torch.from_numpy(image.transpose(2, 0, 1)).float()
222
+ else:
223
+ numpy_type = False
224
+ if image.ndim == 2:
225
+ image = image.unsqueeze(0)
226
+ squeeze_flag = True
227
+
228
+ in_c, in_h, in_w = image.size()
229
+ out_h, out_w = math.ceil(in_h * scale_factor), math.ceil(in_w * scale_factor)
230
+ kernel_width = 4
231
+
232
+ # get weights and indices
233
+ weights_h, indices_h, sym_len_hs, sym_len_he = _calculate_weights_indices(in_h, out_h, scale_factor, kernel_width,
234
+ antialiasing)
235
+ weights_w, indices_w, sym_len_ws, sym_len_we = _calculate_weights_indices(in_w, out_w, scale_factor, kernel_width,
236
+ antialiasing)
237
+ # process H dimension
238
+ # symmetric copying
239
+ img_aug = torch.FloatTensor(in_c, in_h + sym_len_hs + sym_len_he, in_w)
240
+ img_aug.narrow(1, sym_len_hs, in_h).copy_(image)
241
+
242
+ sym_patch = image[:, :sym_len_hs, :]
243
+ inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
244
+ sym_patch_inv = sym_patch.index_select(1, inv_idx)
245
+ img_aug.narrow(1, 0, sym_len_hs).copy_(sym_patch_inv)
246
+
247
+ sym_patch = image[:, -sym_len_he:, :]
248
+ inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
249
+ sym_patch_inv = sym_patch.index_select(1, inv_idx)
250
+ img_aug.narrow(1, sym_len_hs + in_h, sym_len_he).copy_(sym_patch_inv)
251
+
252
+ out_1 = torch.FloatTensor(in_c, out_h, in_w)
253
+ kernel_width = weights_h.size(1)
254
+ for i in range(out_h):
255
+ idx = int(indices_h[i][0])
256
+ for j in range(in_c):
257
+ out_1[j, i, :] = img_aug[j, idx:idx + kernel_width, :].transpose(0, 1).mv(weights_h[i])
258
+
259
+ # process W dimension
260
+ # symmetric copying
261
+ out_1_aug = torch.FloatTensor(in_c, out_h, in_w + sym_len_ws + sym_len_we)
262
+ out_1_aug.narrow(2, sym_len_ws, in_w).copy_(out_1)
263
+
264
+ sym_patch = out_1[:, :, :sym_len_ws]
265
+ inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long()
266
+ sym_patch_inv = sym_patch.index_select(2, inv_idx)
267
+ out_1_aug.narrow(2, 0, sym_len_ws).copy_(sym_patch_inv)
268
+
269
+ sym_patch = out_1[:, :, -sym_len_we:]
270
+ inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long()
271
+ sym_patch_inv = sym_patch.index_select(2, inv_idx)
272
+ out_1_aug.narrow(2, sym_len_ws + in_w, sym_len_we).copy_(sym_patch_inv)
273
+
274
+ out_2 = torch.FloatTensor(in_c, out_h, out_w)
275
+ kernel_width = weights_w.size(1)
276
+ for i in range(out_w):
277
+ idx = int(indices_w[i][0])
278
+ for j in range(in_c):
279
+ out_2[j, :, i] = out_1_aug[j, :, idx:idx + kernel_width].mv(weights_w[i])
280
+
281
+ if squeeze_flag:
282
+ out_2 = out_2.squeeze(0)
283
+ if numpy_type:
284
+ out_2 = out_2.numpy()
285
+ if not squeeze_flag:
286
+ out_2 = out_2.transpose(1, 2, 0)
287
+
288
+ return out_2
289
+
290
+
291
+ def expand_y(image: np.ndarray) -> np.ndarray:
292
+ """Convert BGR channel to YCbCr format,
293
+ and expand Y channel data in YCbCr, from HW to HWC
294
+
295
+ Args:
296
+ image (np.ndarray): BGR image data read by ``OpenCV.imread``
297
+
298
+ Returns:
299
+ y_image (np.ndarray): Y-channel image data in HWC form
300
+
301
+ """
302
+ # Normalize image data to [0, 1]
303
+ image = image.astype(np.float32) / 255.
304
+
305
+ # Convert BGR to YCbCr, and extract only Y channel
306
+ y_image = bgr_to_ycbcr(image, only_use_y_channel=True)
307
+
308
+ # Expand Y channel
309
+ y_image = y_image[..., None]
310
+
311
+ # Normalize the image data to [0, 255]
312
+ y_image = y_image.astype(np.float64) * 255.0
313
+
314
+ return y_image
315
+
316
+
317
+ def rgb_to_ycbcr(image: np.ndarray, only_use_y_channel: bool) -> np.ndarray:
318
+ """Implementation of rgb2ycbcr function in Matlab under Python language
319
+
320
+ Args:
321
+ image (np.ndarray): Image input in RGB format.
322
+ only_use_y_channel (bool): Extract Y channel separately
323
+
324
+ Returns:
325
+ image (np.ndarray): YCbCr image array data
326
+
327
+ """
328
+ if only_use_y_channel:
329
+ image = np.dot(image, [65.481, 128.553, 24.966]) + 16.0
330
+ else:
331
+ image = np.matmul(image, [[65.481, -37.797, 112.0], [128.553, -74.203, -93.786], [24.966, 112.0, -18.214]]) + [
332
+ 16, 128, 128]
333
+
334
+ image /= 255.
335
+ image = image.astype(np.float32)
336
+
337
+ return image
338
+
339
+
340
+ def bgr_to_ycbcr(image: np.ndarray, only_use_y_channel: bool) -> np.ndarray:
341
+ """Implementation of bgr2ycbcr function in Matlab under Python language.
342
+
343
+ Args:
344
+ image (np.ndarray): Image input in BGR format
345
+ only_use_y_channel (bool): Extract Y channel separately
346
+
347
+ Returns:
348
+ image (np.ndarray): YCbCr image array data
349
+
350
+ """
351
+ if only_use_y_channel:
352
+ image = np.dot(image, [24.966, 128.553, 65.481]) + 16.0
353
+ else:
354
+ image = np.matmul(image, [[24.966, 112.0, -18.214], [128.553, -74.203, -93.786], [65.481, -37.797, 112.0]]) + [
355
+ 16, 128, 128]
356
+
357
+ image /= 255.
358
+ image = image.astype(np.float32)
359
+
360
+ return image
361
+
362
+
363
+ def ycbcr_to_rgb(image: np.ndarray) -> np.ndarray:
364
+ """Implementation of ycbcr2rgb function in Matlab under Python language.
365
+
366
+ Args:
367
+ image (np.ndarray): Image input in YCbCr format.
368
+
369
+ Returns:
370
+ image (np.ndarray): RGB image array data
371
+
372
+ """
373
+ image_dtype = image.dtype
374
+ image *= 255.
375
+
376
+ image = np.matmul(image, [[0.00456621, 0.00456621, 0.00456621],
377
+ [0, -0.00153632, 0.00791071],
378
+ [0.00625893, -0.00318811, 0]]) * 255.0 + [-222.921, 135.576, -276.836]
379
+
380
+ image /= 255.
381
+ image = image.astype(image_dtype)
382
+
383
+ return image
384
+
385
+
386
+ def ycbcr_to_bgr(image: np.ndarray) -> np.ndarray:
387
+ """Implementation of ycbcr2bgr function in Matlab under Python language.
388
+
389
+ Args:
390
+ image (np.ndarray): Image input in YCbCr format.
391
+
392
+ Returns:
393
+ image (np.ndarray): BGR image array data
394
+
395
+ """
396
+ image_dtype = image.dtype
397
+ image *= 255.
398
+
399
+ image = np.matmul(image, [[0.00456621, 0.00456621, 0.00456621],
400
+ [0.00791071, -0.00153632, 0],
401
+ [0, -0.00318811, 0.00625893]]) * 255.0 + [-276.836, 135.576, -222.921]
402
+
403
+ image /= 255.
404
+ image = image.astype(image_dtype)
405
+
406
+ return image
407
+
408
+
409
+ def rgb_to_ycbcr_torch(tensor: Tensor, only_use_y_channel: bool) -> Tensor:
410
+ """Implementation of rgb2ycbcr function in Matlab under PyTorch
411
+
412
+ References from:`https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion`
413
+
414
+ Args:
415
+ tensor (Tensor): Image data in PyTorch format
416
+ only_use_y_channel (bool): Extract only Y channel
417
+
418
+ Returns:
419
+ tensor (Tensor): YCbCr image data in PyTorch format
420
+
421
+ """
422
+ if only_use_y_channel:
423
+ weight = Tensor([[65.481], [128.553], [24.966]]).to(tensor)
424
+ tensor = torch.matmul(tensor.permute(0, 2, 3, 1), weight).permute(0, 3, 1, 2) + 16.0
425
+ else:
426
+ weight = Tensor([[65.481, -37.797, 112.0],
427
+ [128.553, -74.203, -93.786],
428
+ [24.966, 112.0, -18.214]]).to(tensor)
429
+ bias = Tensor([16, 128, 128]).view(1, 3, 1, 1).to(tensor)
430
+ tensor = torch.matmul(tensor.permute(0, 2, 3, 1), weight).permute(0, 3, 1, 2) + bias
431
+
432
+ tensor /= 255.
433
+
434
+ return tensor
435
+
436
+
437
+ def bgr_to_ycbcr_torch(tensor: Tensor, only_use_y_channel: bool) -> Tensor:
438
+ """Implementation of bgr2ycbcr function in Matlab under PyTorch
439
+
440
+ References from:`https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion`
441
+
442
+ Args:
443
+ tensor (Tensor): Image data in PyTorch format
444
+ only_use_y_channel (bool): Extract only Y channel
445
+
446
+ Returns:
447
+ tensor (Tensor): YCbCr image data in PyTorch format
448
+
449
+ """
450
+ if only_use_y_channel:
451
+ weight = Tensor([[24.966], [128.553], [65.481]]).to(tensor)
452
+ tensor = torch.matmul(tensor.permute(0, 2, 3, 1), weight).permute(0, 3, 1, 2) + 16.0
453
+ else:
454
+ weight = Tensor([[24.966, 112.0, -18.214],
455
+ [128.553, -74.203, -93.786],
456
+ [65.481, -37.797, 112.0]]).to(tensor)
457
+ bias = Tensor([16, 128, 128]).view(1, 3, 1, 1).to(tensor)
458
+ tensor = torch.matmul(tensor.permute(0, 2, 3, 1), weight).permute(0, 3, 1, 2) + bias
459
+
460
+ tensor /= 255.
461
+
462
+ return tensor
463
+
464
+
465
+ def center_crop(image: np.ndarray, image_size: int) -> np.ndarray:
466
+ """Crop small image patches from one image center area.
467
+
468
+ Args:
469
+ image (np.ndarray): The input image for `OpenCV.imread`.
470
+ image_size (int): The size of the captured image area.
471
+
472
+ Returns:
473
+ patch_image (np.ndarray): Small patch image
474
+
475
+ """
476
+ image_height, image_width = image.shape[:2]
477
+
478
+ # Just need to find the top and left coordinates of the image
479
+ top = (image_height - image_size) // 2
480
+ left = (image_width - image_size) // 2
481
+
482
+ # Crop image patch
483
+ patch_image = image[top:top + image_size, left:left + image_size, ...]
484
+
485
+ return patch_image
486
+
487
+
488
+ def random_crop(image: np.ndarray, image_size: int) -> np.ndarray:
489
+ """Crop small image patches from one image.
490
+
491
+ Args:
492
+ image (np.ndarray): The input image for `OpenCV.imread`.
493
+ image_size (int): The size of the captured image area.
494
+
495
+ Returns:
496
+ patch_image (np.ndarray): Small patch image
497
+
498
+ """
499
+ image_height, image_width = image.shape[:2]
500
+
501
+ # Just need to find the top and left coordinates of the image
502
+ top = random.randint(0, image_height - image_size)
503
+ left = random.randint(0, image_width - image_size)
504
+
505
+ # Crop image patch
506
+ patch_image = image[top:top + image_size, left:left + image_size, ...]
507
+
508
+ return patch_image
509
+
510
+
511
+ def random_rotate(image,
512
+ angles: list,
513
+ center: tuple[int, int] = None,
514
+ scale_factor: float = 1.0) -> np.ndarray:
515
+ """Rotate an image by a random angle
516
+
517
+ Args:
518
+ image (np.ndarray): Image read with OpenCV
519
+ angles (list): Rotation angle range
520
+ center (optional, tuple[int, int]): High resolution image selection center point. Default: ``None``
521
+ scale_factor (optional, float): scaling factor. Default: 1.0
522
+
523
+ Returns:
524
+ rotated_image (np.ndarray): image after rotation
525
+
526
+ """
527
+ image_height, image_width = image.shape[:2]
528
+
529
+ if center is None:
530
+ center = (image_width // 2, image_height // 2)
531
+
532
+ # Random select specific angle
533
+ angle = random.choice(angles)
534
+ matrix = cv2.getRotationMatrix2D(center, angle, scale_factor)
535
+ rotated_image = cv2.warpAffine(image, matrix, (image_width, image_height))
536
+
537
+ return rotated_image
538
+
539
+
540
+ def random_horizontally_flip(image: np.ndarray, p: float = 0.5) -> np.ndarray:
541
+ """Flip the image upside down randomly
542
+
543
+ Args:
544
+ image (np.ndarray): Image read with OpenCV
545
+ p (optional, float): Horizontally flip probability. Default: 0.5
546
+
547
+ Returns:
548
+ horizontally_flip_image (np.ndarray): image after horizontally flip
549
+
550
+ """
551
+ if random.random() < p:
552
+ horizontally_flip_image = cv2.flip(image, 1)
553
+ else:
554
+ horizontally_flip_image = image
555
+
556
+ return horizontally_flip_image
557
+
558
+
559
+ def random_vertically_flip(image: np.ndarray, p: float = 0.5) -> np.ndarray:
560
+ """Flip an image horizontally randomly
561
+
562
+ Args:
563
+ image (np.ndarray): Image read with OpenCV
564
+ p (optional, float): Vertically flip probability. Default: 0.5
565
+
566
+ Returns:
567
+ vertically_flip_image (np.ndarray): image after vertically flip
568
+
569
+ """
570
+ if random.random() < p:
571
+ vertically_flip_image = cv2.flip(image, 0)
572
+ else:
573
+ vertically_flip_image = image
574
+
575
+ return vertically_flip_image
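As a usage illustration for the helpers above, the following sketch (not part of the commit; the file path and the 0.25 scale factor are arbitrary choices) round-trips an image through image_to_tensor / tensor_to_image and applies the MATLAB-style bicubic resize:

    import cv2
    import numpy as np

    import imgproc

    # Read one of the bundled example images as an RGB float array in [0, 1]
    rgb = cv2.cvtColor(cv2.imread("examples/comic.png"), cv2.COLOR_BGR2RGB).astype(np.float32) / 255.0

    tensor = imgproc.image_to_tensor(rgb, range_norm=False, half=False).unsqueeze_(0)  # 1 x 3 x H x W float tensor
    restored = imgproc.tensor_to_image(tensor, range_norm=False, half=False)           # back to HWC uint8 in [0, 255]

    small = imgproc.image_resize(rgb, scale_factor=0.25)  # antialiased bicubic downscale, as in MATLAB's imresize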
inference.py ADDED
@@ -0,0 +1,101 @@
+ # Copyright 2022 Dakewe Biotech Corporation. All Rights Reserved.
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+ import argparse
+ import os
+
+ import cv2
+ import torch
+ from torch import nn
+
+ import imgproc
+ import model
+ from utils import load_state_dict
+
+ model_names = sorted(
+     name for name in model.__dict__ if
+     name.islower() and not name.startswith("__") and callable(model.__dict__[name]))
+
+
+ def choice_device(device_type: str) -> torch.device:
+     # Select the device the model will run on
+     if device_type == "cuda":
+         device = torch.device("cuda", 0)
+     else:
+         device = torch.device("cpu")
+     return device
+
+
+ def build_model(model_arch_name: str, device: torch.device) -> nn.Module:
+     # Initialize the super-resolution model
+     sr_model = model.__dict__[model_arch_name](in_channels=3,
+                                                out_channels=3,
+                                                channels=64,
+                                                num_rcb=16)
+     sr_model = sr_model.to(device=device)
+
+     return sr_model
+
+
+ def main(args):
+     device = choice_device(args.device_type)
+
+     # Initialize the model
+     sr_model = build_model(args.model_arch_name, device)
+     print(f"Built `{args.model_arch_name}` model successfully.")
+
+     # Load model weights
+     sr_model = load_state_dict(sr_model, args.model_weights_path)
+     print(f"Loaded `{args.model_arch_name}` model weights `{os.path.abspath(args.model_weights_path)}` successfully.")
+
+     # Put the model in evaluation (inference) mode
+     sr_model.eval()
+
+     lr_tensor = imgproc.preprocess_one_image(args.inputs_path, device)
+
+     # Use the model to generate a super-resolved image
+     with torch.no_grad():
+         sr_tensor = sr_model(lr_tensor)
+
+     # Save image
+     sr_image = imgproc.tensor_to_image(sr_tensor, False, False)
+     sr_image = cv2.cvtColor(sr_image, cv2.COLOR_RGB2BGR)
+     cv2.imwrite(args.output_path, sr_image)
+
+     print(f"SR image saved to `{args.output_path}`")
+
+
+ if __name__ == "__main__":
+     parser = argparse.ArgumentParser(description="Use the model to generate super-resolution images.")
+     parser.add_argument("--model_arch_name",
+                         type=str,
+                         default="srresnet_x4")
+     parser.add_argument("--inputs_path",
+                         type=str,
+                         default="./figure/comic_lr.png",
+                         help="Low-resolution image path.")
+     parser.add_argument("--output_path",
+                         type=str,
+                         default="./figure/comic_sr.png",
+                         help="Super-resolution image path.")
+     parser.add_argument("--model_weights_path",
+                         type=str,
+                         default="./results/pretrained_models/SRGAN_x4-ImageNet-8c4a7569.pth.tar",
+                         help="Model weights file path.")
+     parser.add_argument("--device_type",
+                         type=str,
+                         default="cpu",
+                         choices=["cpu", "cuda"])
+     args = parser.parse_args()
+
+     main(args)
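A hedged sketch of how this script could be driven programmatically inside the Space repository (equivalent to invoking it with the corresponding command-line flags; the output filename is an arbitrary choice, the other paths are files added in this commit):

    # Hypothetical programmatic use of inference.py -- not part of the commit.
    from argparse import Namespace
    from inference import main

    main(Namespace(model_arch_name="srresnet_x4",
                   inputs_path="examples/comic.png",
                   output_path="comic_sr.png",
                   model_weights_path="weights/SRGAN_x4-ImageNet-8c4a7569.pth.tar",
                   device_type="cpu"))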
model.py ADDED
@@ -0,0 +1,251 @@
1
+ # Copyright 2022 Dakewe Biotech Corporation. All Rights Reserved.
2
+ # Licensed under the Apache License, Version 2.0 (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ #
6
+ # http://www.apache.org/licenses/LICENSE-2.0
7
+ #
8
+ # Unless required by applicable law or agreed to in writing, software
9
+ # distributed under the License is distributed on an "AS IS" BASIS,
10
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11
+ # See the License for the specific language governing permissions and
12
+ # limitations under the License.
13
+ # ==============================================================================
14
+ import math
15
+ from typing import Any
16
+
17
+ import torch
18
+ from torch import Tensor
19
+ from torch import nn
20
+ from torch.nn import functional as F_torch
21
+ from torchvision import models
22
+ from torchvision import transforms
23
+ from torchvision.models.feature_extraction import create_feature_extractor
24
+
25
+ __all__ = [
26
+ "SRResNet", "Discriminator",
27
+ "srresnet_x4", "discriminator", "content_loss",
28
+ ]
29
+
30
+
31
+ class SRResNet(nn.Module):
32
+ def __init__(
33
+ self,
34
+ in_channels: int,
35
+ out_channels: int,
36
+ channels: int,
37
+ num_rcb: int,
38
+ upscale_factor: int
39
+ ) -> None:
40
+ super(SRResNet, self).__init__()
41
+ # Low frequency information extraction layer
42
+ self.conv1 = nn.Sequential(
43
+ nn.Conv2d(in_channels, channels, (9, 9), (1, 1), (4, 4)),
44
+ nn.PReLU(),
45
+ )
46
+
47
+ # High frequency information extraction block
48
+ trunk = []
49
+ for _ in range(num_rcb):
50
+ trunk.append(_ResidualConvBlock(channels))
51
+ self.trunk = nn.Sequential(*trunk)
52
+
53
+ # High-frequency information linear fusion layer
54
+ self.conv2 = nn.Sequential(
55
+ nn.Conv2d(channels, channels, (3, 3), (1, 1), (1, 1), bias=False),
56
+ nn.BatchNorm2d(channels),
57
+ )
58
+
59
+ # zoom block
60
+ upsampling = []
61
+ if upscale_factor == 2 or upscale_factor == 4 or upscale_factor == 8:
62
+ for _ in range(int(math.log(upscale_factor, 2))):
63
+ upsampling.append(_UpsampleBlock(channels, 2))
64
+ elif upscale_factor == 3:
65
+ upsampling.append(_UpsampleBlock(channels, 3))
66
+ self.upsampling = nn.Sequential(*upsampling)
67
+
68
+ # reconstruction block
69
+ self.conv3 = nn.Conv2d(channels, out_channels, (9, 9), (1, 1), (4, 4))
70
+
71
+ # Initialize neural network weights
72
+ self._initialize_weights()
73
+
74
+ def forward(self, x: Tensor) -> Tensor:
75
+ return self._forward_impl(x)
76
+
77
+ # Support torch.script function
78
+ def _forward_impl(self, x: Tensor) -> Tensor:
79
+ out1 = self.conv1(x)
80
+ out = self.trunk(out1)
81
+ out2 = self.conv2(out)
82
+ out = torch.add(out1, out2)
83
+ out = self.upsampling(out)
84
+ out = self.conv3(out)
85
+
86
+ out = torch.clamp_(out, 0.0, 1.0)
87
+
88
+ return out
89
+
90
+ def _initialize_weights(self) -> None:
91
+ for module in self.modules():
92
+ if isinstance(module, nn.Conv2d):
93
+ nn.init.kaiming_normal_(module.weight)
94
+ if module.bias is not None:
95
+ nn.init.constant_(module.bias, 0)
96
+ elif isinstance(module, nn.BatchNorm2d):
97
+ nn.init.constant_(module.weight, 1)
98
+
99
+
100
+ class Discriminator(nn.Module):
101
+ def __init__(self) -> None:
102
+ super(Discriminator, self).__init__()
103
+ self.features = nn.Sequential(
104
+ # input size. (3) x 96 x 96
105
+ nn.Conv2d(3, 64, (3, 3), (1, 1), (1, 1), bias=True),
106
+ nn.LeakyReLU(0.2, True),
107
+ # state size. (64) x 48 x 48
108
+ nn.Conv2d(64, 64, (3, 3), (2, 2), (1, 1), bias=False),
109
+ nn.BatchNorm2d(64),
110
+ nn.LeakyReLU(0.2, True),
111
+ nn.Conv2d(64, 128, (3, 3), (1, 1), (1, 1), bias=False),
112
+ nn.BatchNorm2d(128),
113
+ nn.LeakyReLU(0.2, True),
114
+ # state size. (128) x 24 x 24
115
+ nn.Conv2d(128, 128, (3, 3), (2, 2), (1, 1), bias=False),
116
+ nn.BatchNorm2d(128),
117
+ nn.LeakyReLU(0.2, True),
118
+ nn.Conv2d(128, 256, (3, 3), (1, 1), (1, 1), bias=False),
119
+ nn.BatchNorm2d(256),
120
+ nn.LeakyReLU(0.2, True),
121
+ # state size. (256) x 12 x 12
122
+ nn.Conv2d(256, 256, (3, 3), (2, 2), (1, 1), bias=False),
123
+ nn.BatchNorm2d(256),
124
+ nn.LeakyReLU(0.2, True),
125
+ nn.Conv2d(256, 512, (3, 3), (1, 1), (1, 1), bias=False),
126
+ nn.BatchNorm2d(512),
127
+ nn.LeakyReLU(0.2, True),
128
+ # state size. (512) x 6 x 6
129
+ nn.Conv2d(512, 512, (3, 3), (2, 2), (1, 1), bias=False),
130
+ nn.BatchNorm2d(512),
131
+ nn.LeakyReLU(0.2, True),
132
+ )
133
+
134
+ self.classifier = nn.Sequential(
135
+ nn.Linear(512 * 6 * 6, 1024),
136
+ nn.LeakyReLU(0.2, True),
137
+ nn.Linear(1024, 1),
138
+ )
139
+
140
+ def forward(self, x: Tensor) -> Tensor:
141
+ # Input image size must equal 96
142
+ assert x.shape[2] == 96 and x.shape[3] == 96, "Image shape must equal 96x96"
143
+
144
+ out = self.features(x)
145
+ out = torch.flatten(out, 1)
146
+ out = self.classifier(out)
147
+
148
+ return out
149
+
150
+
151
+ class _ResidualConvBlock(nn.Module):
152
+ def __init__(self, channels: int) -> None:
153
+ super(_ResidualConvBlock, self).__init__()
154
+ self.rcb = nn.Sequential(
155
+ nn.Conv2d(channels, channels, (3, 3), (1, 1), (1, 1), bias=False),
156
+ nn.BatchNorm2d(channels),
157
+ nn.PReLU(),
158
+ nn.Conv2d(channels, channels, (3, 3), (1, 1), (1, 1), bias=False),
159
+ nn.BatchNorm2d(channels),
160
+ )
161
+
162
+ def forward(self, x: Tensor) -> Tensor:
163
+ identity = x
164
+
165
+ out = self.rcb(x)
166
+
167
+ out = torch.add(out, identity)
168
+
169
+ return out
170
+
171
+
172
+ class _UpsampleBlock(nn.Module):
173
+ def __init__(self, channels: int, upscale_factor: int) -> None:
174
+ super(_UpsampleBlock, self).__init__()
175
+ self.upsample_block = nn.Sequential(
176
+ nn.Conv2d(channels, channels * upscale_factor * upscale_factor, (3, 3), (1, 1), (1, 1)),
177
+ nn.PixelShuffle(upscale_factor),
178
+ nn.PReLU(),
179
+ )
180
+
181
+ def forward(self, x: Tensor) -> Tensor:
182
+ out = self.upsample_block(x)
183
+
184
+ return out
185
+
186
+
187
+ class _ContentLoss(nn.Module):
188
+ """Constructs a content loss function based on the VGG19 network.
189
+ Using high-level feature mapping layers from the latter layers will focus more on the texture content of the image.
190
+
191
+ Paper reference list:
192
+ -`Photo-Realistic Single Image Super-Resolution Using a Generative Adversarial Network <https://arxiv.org/pdf/1609.04802.pdf>` paper.
193
+ -`ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks <https://arxiv.org/pdf/1809.00219.pdf>` paper.
194
+ -`Perceptual Extreme Super Resolution Network with Receptive Field Block <https://arxiv.org/pdf/2005.12597.pdf>` paper.
195
+
196
+ """
197
+
198
+ def __init__(
199
+ self,
200
+ feature_model_extractor_node: str,
201
+ feature_model_normalize_mean: list,
202
+ feature_model_normalize_std: list
203
+ ) -> None:
204
+ super(_ContentLoss, self).__init__()
205
+ # Get the name of the specified feature extraction node
206
+ self.feature_model_extractor_node = feature_model_extractor_node
207
+ # Load the VGG19 model trained on the ImageNet dataset.
208
+ model = models.vgg19(weights=models.VGG19_Weights.IMAGENET1K_V1)
209
+ # Extract the thirty-sixth layer output in the VGG19 model as the content loss.
210
+ self.feature_extractor = create_feature_extractor(model, [feature_model_extractor_node])
211
+ # set to validation mode
212
+ self.feature_extractor.eval()
213
+
214
+ # The preprocessing method of the input data.
215
+ # This is the VGG model preprocessing method of the ImageNet dataset.
216
+ self.normalize = transforms.Normalize(feature_model_normalize_mean, feature_model_normalize_std)
217
+
218
+ # Freeze model parameters.
219
+ for model_parameters in self.feature_extractor.parameters():
220
+ model_parameters.requires_grad = False
221
+
222
+ def forward(self, sr_tensor: Tensor, gt_tensor: Tensor) -> Tensor:
223
+ # Standardized operations
224
+ sr_tensor = self.normalize(sr_tensor)
225
+ gt_tensor = self.normalize(gt_tensor)
226
+
227
+ sr_feature = self.feature_extractor(sr_tensor)[self.feature_model_extractor_node]
228
+ gt_feature = self.feature_extractor(gt_tensor)[self.feature_model_extractor_node]
229
+
230
+ # Find the feature map difference between the two images
231
+ loss = F_torch.mse_loss(sr_feature, gt_feature)
232
+
233
+ return loss
234
+
235
+
236
+ def srresnet_x4(**kwargs: Any) -> SRResNet:
237
+ model = SRResNet(upscale_factor=4, **kwargs)
238
+
239
+ return model
240
+
241
+
242
+ def discriminator() -> Discriminator:
243
+ model = Discriminator()
244
+
245
+ return model
246
+
247
+
248
+ def content_loss(**kwargs: Any) -> _ContentLoss:
249
+ content_loss = _ContentLoss(**kwargs)
250
+
251
+ return content_loss
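A small shape check can confirm the x4 upsampling path wires up as described above. This is a hypothetical snippet, not part of the commit; the 24x24 input size is arbitrary and the constructor arguments mirror build_model in inference.py.

    import torch

    from model import srresnet_x4

    net = srresnet_x4(in_channels=3, out_channels=3, channels=64, num_rcb=16)
    net.eval()
    with torch.no_grad():
        out = net(torch.rand(1, 3, 24, 24))
    print(out.shape)  # expected: torch.Size([1, 3, 96, 96]) -- 4x larger, clamped to [0, 1]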
utils.py ADDED
@@ -0,0 +1,168 @@
1
+ # Copyright 2022 Dakewe Biotech Corporation. All Rights Reserved.
2
+ # Licensed under the Apache License, Version 2.0 (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ #
6
+ # http://www.apache.org/licenses/LICENSE-2.0
7
+ #
8
+ # Unless required by applicable law or agreed to in writing, software
9
+ # distributed under the License is distributed on an "AS IS" BASIS,
10
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11
+ # See the License for the specific language governing permissions and
12
+ # limitations under the License.
13
+ # ==============================================================================
14
+ import os
15
+ import shutil
16
+ from enum import Enum
17
+ from typing import Any
18
+
19
+ import torch
20
+ from torch import nn
21
+ from torch.nn import Module
22
+ from torch.optim import Optimizer
23
+
24
+ __all__ = [
25
+ "load_state_dict", "make_directory", "save_checkpoint",
26
+ "Summary", "AverageMeter", "ProgressMeter"
27
+ ]
28
+
29
+
30
+ def load_state_dict(
31
+ model: nn.Module,
32
+ model_weights_path: str,
33
+ ema_model: nn.Module = None,
34
+ optimizer: torch.optim.Optimizer = None,
35
+ scheduler: torch.optim.lr_scheduler = None,
36
+ load_mode: str = None,
37
+ ) -> tuple[Module, Module, Any, Any, Any, Optimizer | None, Any] | tuple[Module, Any, Any, Any, Optimizer | None, Any] | Module:
38
+ # Load model weights
39
+ checkpoint = torch.load(model_weights_path, map_location=lambda storage, loc: storage)
40
+
41
+ if load_mode == "resume":
42
+ # Restore the parameters in the training node to this point
43
+ start_epoch = checkpoint["epoch"]
44
+ best_psnr = checkpoint["best_psnr"]
45
+ best_ssim = checkpoint["best_ssim"]
46
+ # Load model state dict. Extract the fitted model weights
47
+ model_state_dict = model.state_dict()
48
+ state_dict = {k: v for k, v in checkpoint["state_dict"].items() if k in model_state_dict.keys()}
49
+ # Overwrite the model weights to the current model (base model)
50
+ model_state_dict.update(state_dict)
51
+ model.load_state_dict(model_state_dict)
52
+ # Load the optimizer model
53
+ optimizer.load_state_dict(checkpoint["optimizer"])
54
+
55
+ if scheduler is not None:
56
+ # Load the scheduler model
57
+ scheduler.load_state_dict(checkpoint["scheduler"])
58
+
59
+ if ema_model is not None:
60
+ # Load ema model state dict. Extract the fitted model weights
61
+ ema_model_state_dict = ema_model.state_dict()
62
+ ema_state_dict = {k: v for k, v in checkpoint["ema_state_dict"].items() if k in ema_model_state_dict.keys()}
63
+ # Overwrite the model weights to the current model (ema model)
64
+ ema_model_state_dict.update(ema_state_dict)
65
+ ema_model.load_state_dict(ema_model_state_dict)
66
+
67
+ return model, ema_model, start_epoch, best_psnr, best_ssim, optimizer, scheduler
68
+ else:
69
+ # Load model state dict. Extract the fitted model weights
70
+ model_state_dict = model.state_dict()
71
+ state_dict = {k: v for k, v in checkpoint["state_dict"].items() if
72
+ k in model_state_dict.keys() and v.size() == model_state_dict[k].size()}
73
+ # Overwrite the model weights to the current model
74
+ model_state_dict.update(state_dict)
75
+ model.load_state_dict(model_state_dict)
76
+
77
+ return model
78
+
79
+
80
+ def make_directory(dir_path: str) -> None:
81
+ if not os.path.exists(dir_path):
82
+ os.makedirs(dir_path)
83
+
84
+
85
+ def save_checkpoint(
86
+ state_dict: dict,
87
+ file_name: str,
88
+ samples_dir: str,
89
+ results_dir: str,
90
+ best_file_name: str,
91
+ last_file_name: str,
92
+ is_best: bool = False,
93
+ is_last: bool = False,
94
+ ) -> None:
95
+ checkpoint_path = os.path.join(samples_dir, file_name)
96
+ torch.save(state_dict, checkpoint_path)
97
+
98
+ if is_best:
99
+ shutil.copyfile(checkpoint_path, os.path.join(results_dir, best_file_name))
100
+ if is_last:
101
+ shutil.copyfile(checkpoint_path, os.path.join(results_dir, last_file_name))
102
+
103
+
104
+ class Summary(Enum):
105
+ NONE = 0
106
+ AVERAGE = 1
107
+ SUM = 2
108
+ COUNT = 3
109
+
110
+
111
+ class AverageMeter(object):
112
+ def __init__(self, name, fmt=":f", summary_type=Summary.AVERAGE):
113
+ self.name = name
114
+ self.fmt = fmt
115
+ self.summary_type = summary_type
116
+ self.reset()
117
+
118
+ def reset(self):
119
+ self.val = 0
120
+ self.avg = 0
121
+ self.sum = 0
122
+ self.count = 0
123
+
124
+ def update(self, val, n=1):
125
+ self.val = val
126
+ self.sum += val * n
127
+ self.count += n
128
+ self.avg = self.sum / self.count
129
+
130
+ def __str__(self):
131
+ fmtstr = "{name} {val" + self.fmt + "} ({avg" + self.fmt + "})"
132
+ return fmtstr.format(**self.__dict__)
133
+
134
+ def summary(self):
135
+ if self.summary_type is Summary.NONE:
136
+ fmtstr = ""
137
+ elif self.summary_type is Summary.AVERAGE:
138
+ fmtstr = "{name} {avg:.2f}"
139
+ elif self.summary_type is Summary.SUM:
140
+ fmtstr = "{name} {sum:.2f}"
141
+ elif self.summary_type is Summary.COUNT:
142
+ fmtstr = "{name} {count:.2f}"
143
+ else:
144
+ raise ValueError(f"Invalid summary type {self.summary_type}")
145
+
146
+ return fmtstr.format(**self.__dict__)
147
+
148
+
149
+ class ProgressMeter(object):
150
+ def __init__(self, num_batches, meters, prefix=""):
151
+ self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
152
+ self.meters = meters
153
+ self.prefix = prefix
154
+
155
+ def display(self, batch):
156
+ entries = [self.prefix + self.batch_fmtstr.format(batch)]
157
+ entries += [str(meter) for meter in self.meters]
158
+ print("\t".join(entries))
159
+
160
+ def display_summary(self):
161
+ entries = [" *"]
162
+ entries += [meter.summary() for meter in self.meters]
163
+ print(" ".join(entries))
164
+
165
+ def _get_batch_fmtstr(self, num_batches):
166
+ num_digits = len(str(num_batches // 1))
167
+ fmt = "{:" + str(num_digits) + "d}"
168
+ return "[" + fmt + "/" + fmt.format(num_batches) + "]"
weights/SRGAN_x4-ImageNet-8c4a7569.pth.tar ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:7c5431b5921e1509190aed6aca02c7d5838f4805e8ae6f9fa08c140260b6a2a3
+ size 6285796