Upload folder using huggingface_hub
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitignore +139 -0
- LICENSE +201 -0
- MODEL_ZOO.md +13 -0
- README.md +201 -7
- VERSION +1 -0
- basicsr/__init__.py +15 -0
- basicsr/archs/__init__.py +25 -0
- basicsr/archs/ddcolor_arch.py +385 -0
- basicsr/archs/ddcolor_arch_utils/__int__.py +0 -0
- basicsr/archs/ddcolor_arch_utils/convnext.py +155 -0
- basicsr/archs/ddcolor_arch_utils/position_encoding.py +52 -0
- basicsr/archs/ddcolor_arch_utils/transformer.py +368 -0
- basicsr/archs/ddcolor_arch_utils/transformer_utils.py +192 -0
- basicsr/archs/ddcolor_arch_utils/unet.py +208 -0
- basicsr/archs/ddcolor_arch_utils/util.py +63 -0
- basicsr/archs/discriminator_arch.py +28 -0
- basicsr/archs/vgg_arch.py +165 -0
- basicsr/data/__init__.py +101 -0
- basicsr/data/data_sampler.py +48 -0
- basicsr/data/data_util.py +313 -0
- basicsr/data/fmix.py +206 -0
- basicsr/data/lab_dataset.py +159 -0
- basicsr/data/prefetch_dataloader.py +125 -0
- basicsr/data/transforms.py +192 -0
- basicsr/losses/__init__.py +26 -0
- basicsr/losses/loss_util.py +95 -0
- basicsr/losses/losses.py +551 -0
- basicsr/metrics/__init__.py +20 -0
- basicsr/metrics/colorfulness.py +17 -0
- basicsr/metrics/custom_fid.py +260 -0
- basicsr/metrics/metric_util.py +45 -0
- basicsr/metrics/psnr_ssim.py +128 -0
- basicsr/models/__init__.py +30 -0
- basicsr/models/base_model.py +382 -0
- basicsr/models/color_model.py +369 -0
- basicsr/models/lr_scheduler.py +96 -0
- basicsr/train.py +224 -0
- basicsr/utils/__init__.py +37 -0
- basicsr/utils/color_enhance.py +9 -0
- basicsr/utils/diffjpeg.py +515 -0
- basicsr/utils/dist_util.py +82 -0
- basicsr/utils/download_util.py +64 -0
- basicsr/utils/face_util.py +192 -0
- basicsr/utils/file_client.py +167 -0
- basicsr/utils/flow_util.py +170 -0
- basicsr/utils/img_process_util.py +83 -0
- basicsr/utils/img_util.py +227 -0
- basicsr/utils/lmdb_util.py +196 -0
- basicsr/utils/logger.py +209 -0
- basicsr/utils/matlab_functions.py +359 -0
.gitignore
ADDED
@@ -0,0 +1,139 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# ignored folders
|
2 |
+
datasets/*
|
3 |
+
experiments/*
|
4 |
+
results/*
|
5 |
+
tb_logger/*
|
6 |
+
wandb/*
|
7 |
+
tmp/*
|
8 |
+
|
9 |
+
*.DS_Store
|
10 |
+
.idea
|
11 |
+
.vscode
|
12 |
+
.github
|
13 |
+
|
14 |
+
.onnx
|
15 |
+
|
16 |
+
# ignored files
|
17 |
+
version.py
|
18 |
+
|
19 |
+
# ignored files with suffix
|
20 |
+
*.html
|
21 |
+
*.png
|
22 |
+
*.jpeg
|
23 |
+
*.jpg
|
24 |
+
*.gif
|
25 |
+
*.pth
|
26 |
+
*.zip
|
27 |
+
*.npy
|
28 |
+
|
29 |
+
# template
|
30 |
+
|
31 |
+
# Byte-compiled / optimized / DLL files
|
32 |
+
__pycache__/
|
33 |
+
*.py[cod]
|
34 |
+
*$py.class
|
35 |
+
|
36 |
+
# C extensions
|
37 |
+
*.so
|
38 |
+
|
39 |
+
# Distribution / packaging
|
40 |
+
.Python
|
41 |
+
build/
|
42 |
+
develop-eggs/
|
43 |
+
dist/
|
44 |
+
downloads/
|
45 |
+
eggs/
|
46 |
+
.eggs/
|
47 |
+
lib/
|
48 |
+
lib64/
|
49 |
+
parts/
|
50 |
+
sdist/
|
51 |
+
var/
|
52 |
+
wheels/
|
53 |
+
*.egg-info/
|
54 |
+
.installed.cfg
|
55 |
+
*.egg
|
56 |
+
MANIFEST
|
57 |
+
|
58 |
+
# PyInstaller
|
59 |
+
# Usually these files are written by a python script from a template
|
60 |
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
61 |
+
*.manifest
|
62 |
+
*.spec
|
63 |
+
|
64 |
+
# Installer logs
|
65 |
+
pip-log.txt
|
66 |
+
pip-delete-this-directory.txt
|
67 |
+
|
68 |
+
# Unit test / coverage reports
|
69 |
+
htmlcov/
|
70 |
+
.tox/
|
71 |
+
.coverage
|
72 |
+
.coverage.*
|
73 |
+
.cache
|
74 |
+
nosetests.xml
|
75 |
+
coverage.xml
|
76 |
+
*.cover
|
77 |
+
.hypothesis/
|
78 |
+
.pytest_cache/
|
79 |
+
|
80 |
+
# Translations
|
81 |
+
*.mo
|
82 |
+
*.pot
|
83 |
+
|
84 |
+
# Django stuff:
|
85 |
+
*.log
|
86 |
+
local_settings.py
|
87 |
+
db.sqlite3
|
88 |
+
|
89 |
+
# Flask stuff:
|
90 |
+
instance/
|
91 |
+
.webassets-cache
|
92 |
+
|
93 |
+
# Scrapy stuff:
|
94 |
+
.scrapy
|
95 |
+
|
96 |
+
# Sphinx documentation
|
97 |
+
docs/_build/
|
98 |
+
|
99 |
+
# PyBuilder
|
100 |
+
target/
|
101 |
+
|
102 |
+
# Jupyter Notebook
|
103 |
+
.ipynb_checkpoints
|
104 |
+
|
105 |
+
# pyenv
|
106 |
+
.python-version
|
107 |
+
|
108 |
+
# celery beat schedule file
|
109 |
+
celerybeat-schedule
|
110 |
+
|
111 |
+
# SageMath parsed files
|
112 |
+
*.sage.py
|
113 |
+
|
114 |
+
# Environments
|
115 |
+
.env
|
116 |
+
.venv
|
117 |
+
env/
|
118 |
+
venv/
|
119 |
+
ENV/
|
120 |
+
env.bak/
|
121 |
+
venv.bak/
|
122 |
+
|
123 |
+
# Spyder project settings
|
124 |
+
.spyderproject
|
125 |
+
.spyproject
|
126 |
+
|
127 |
+
# Rope project settings
|
128 |
+
.ropeproject
|
129 |
+
|
130 |
+
# mkdocs documentation
|
131 |
+
/site
|
132 |
+
|
133 |
+
# mypy
|
134 |
+
.mypy_cache/
|
135 |
+
|
136 |
+
# meta file
|
137 |
+
data_list/*.txt
|
138 |
+
|
139 |
+
weights/
|
LICENSE
ADDED
@@ -0,0 +1,201 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
Apache License
|
2 |
+
Version 2.0, January 2004
|
3 |
+
http://www.apache.org/licenses/
|
4 |
+
|
5 |
+
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
6 |
+
|
7 |
+
1. Definitions.
|
8 |
+
|
9 |
+
"License" shall mean the terms and conditions for use, reproduction,
|
10 |
+
and distribution as defined by Sections 1 through 9 of this document.
|
11 |
+
|
12 |
+
"Licensor" shall mean the copyright owner or entity authorized by
|
13 |
+
the copyright owner that is granting the License.
|
14 |
+
|
15 |
+
"Legal Entity" shall mean the union of the acting entity and all
|
16 |
+
other entities that control, are controlled by, or are under common
|
17 |
+
control with that entity. For the purposes of this definition,
|
18 |
+
"control" means (i) the power, direct or indirect, to cause the
|
19 |
+
direction or management of such entity, whether by contract or
|
20 |
+
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
21 |
+
outstanding shares, or (iii) beneficial ownership of such entity.
|
22 |
+
|
23 |
+
"You" (or "Your") shall mean an individual or Legal Entity
|
24 |
+
exercising permissions granted by this License.
|
25 |
+
|
26 |
+
"Source" form shall mean the preferred form for making modifications,
|
27 |
+
including but not limited to software source code, documentation
|
28 |
+
source, and configuration files.
|
29 |
+
|
30 |
+
"Object" form shall mean any form resulting from mechanical
|
31 |
+
transformation or translation of a Source form, including but
|
32 |
+
not limited to compiled object code, generated documentation,
|
33 |
+
and conversions to other media types.
|
34 |
+
|
35 |
+
"Work" shall mean the work of authorship, whether in Source or
|
36 |
+
Object form, made available under the License, as indicated by a
|
37 |
+
copyright notice that is included in or attached to the work
|
38 |
+
(an example is provided in the Appendix below).
|
39 |
+
|
40 |
+
"Derivative Works" shall mean any work, whether in Source or Object
|
41 |
+
form, that is based on (or derived from) the Work and for which the
|
42 |
+
editorial revisions, annotations, elaborations, or other modifications
|
43 |
+
represent, as a whole, an original work of authorship. For the purposes
|
44 |
+
of this License, Derivative Works shall not include works that remain
|
45 |
+
separable from, or merely link (or bind by name) to the interfaces of,
|
46 |
+
the Work and Derivative Works thereof.
|
47 |
+
|
48 |
+
"Contribution" shall mean any work of authorship, including
|
49 |
+
the original version of the Work and any modifications or additions
|
50 |
+
to that Work or Derivative Works thereof, that is intentionally
|
51 |
+
submitted to Licensor for inclusion in the Work by the copyright owner
|
52 |
+
or by an individual or Legal Entity authorized to submit on behalf of
|
53 |
+
the copyright owner. For the purposes of this definition, "submitted"
|
54 |
+
means any form of electronic, verbal, or written communication sent
|
55 |
+
to the Licensor or its representatives, including but not limited to
|
56 |
+
communication on electronic mailing lists, source code control systems,
|
57 |
+
and issue tracking systems that are managed by, or on behalf of, the
|
58 |
+
Licensor for the purpose of discussing and improving the Work, but
|
59 |
+
excluding communication that is conspicuously marked or otherwise
|
60 |
+
designated in writing by the copyright owner as "Not a Contribution."
|
61 |
+
|
62 |
+
"Contributor" shall mean Licensor and any individual or Legal Entity
|
63 |
+
on behalf of whom a Contribution has been received by Licensor and
|
64 |
+
subsequently incorporated within the Work.
|
65 |
+
|
66 |
+
2. Grant of Copyright License. Subject to the terms and conditions of
|
67 |
+
this License, each Contributor hereby grants to You a perpetual,
|
68 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
69 |
+
copyright license to reproduce, prepare Derivative Works of,
|
70 |
+
publicly display, publicly perform, sublicense, and distribute the
|
71 |
+
Work and such Derivative Works in Source or Object form.
|
72 |
+
|
73 |
+
3. Grant of Patent License. Subject to the terms and conditions of
|
74 |
+
this License, each Contributor hereby grants to You a perpetual,
|
75 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
76 |
+
(except as stated in this section) patent license to make, have made,
|
77 |
+
use, offer to sell, sell, import, and otherwise transfer the Work,
|
78 |
+
where such license applies only to those patent claims licensable
|
79 |
+
by such Contributor that are necessarily infringed by their
|
80 |
+
Contribution(s) alone or by combination of their Contribution(s)
|
81 |
+
with the Work to which such Contribution(s) was submitted. If You
|
82 |
+
institute patent litigation against any entity (including a
|
83 |
+
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
84 |
+
or a Contribution incorporated within the Work constitutes direct
|
85 |
+
or contributory patent infringement, then any patent licenses
|
86 |
+
granted to You under this License for that Work shall terminate
|
87 |
+
as of the date such litigation is filed.
|
88 |
+
|
89 |
+
4. Redistribution. You may reproduce and distribute copies of the
|
90 |
+
Work or Derivative Works thereof in any medium, with or without
|
91 |
+
modifications, and in Source or Object form, provided that You
|
92 |
+
meet the following conditions:
|
93 |
+
|
94 |
+
(a) You must give any other recipients of the Work or
|
95 |
+
Derivative Works a copy of this License; and
|
96 |
+
|
97 |
+
(b) You must cause any modified files to carry prominent notices
|
98 |
+
stating that You changed the files; and
|
99 |
+
|
100 |
+
(c) You must retain, in the Source form of any Derivative Works
|
101 |
+
that You distribute, all copyright, patent, trademark, and
|
102 |
+
attribution notices from the Source form of the Work,
|
103 |
+
excluding those notices that do not pertain to any part of
|
104 |
+
the Derivative Works; and
|
105 |
+
|
106 |
+
(d) If the Work includes a "NOTICE" text file as part of its
|
107 |
+
distribution, then any Derivative Works that You distribute must
|
108 |
+
include a readable copy of the attribution notices contained
|
109 |
+
within such NOTICE file, excluding those notices that do not
|
110 |
+
pertain to any part of the Derivative Works, in at least one
|
111 |
+
of the following places: within a NOTICE text file distributed
|
112 |
+
as part of the Derivative Works; within the Source form or
|
113 |
+
documentation, if provided along with the Derivative Works; or,
|
114 |
+
within a display generated by the Derivative Works, if and
|
115 |
+
wherever such third-party notices normally appear. The contents
|
116 |
+
of the NOTICE file are for informational purposes only and
|
117 |
+
do not modify the License. You may add Your own attribution
|
118 |
+
notices within Derivative Works that You distribute, alongside
|
119 |
+
or as an addendum to the NOTICE text from the Work, provided
|
120 |
+
that such additional attribution notices cannot be construed
|
121 |
+
as modifying the License.
|
122 |
+
|
123 |
+
You may add Your own copyright statement to Your modifications and
|
124 |
+
may provide additional or different license terms and conditions
|
125 |
+
for use, reproduction, or distribution of Your modifications, or
|
126 |
+
for any such Derivative Works as a whole, provided Your use,
|
127 |
+
reproduction, and distribution of the Work otherwise complies with
|
128 |
+
the conditions stated in this License.
|
129 |
+
|
130 |
+
5. Submission of Contributions. Unless You explicitly state otherwise,
|
131 |
+
any Contribution intentionally submitted for inclusion in the Work
|
132 |
+
by You to the Licensor shall be under the terms and conditions of
|
133 |
+
this License, without any additional terms or conditions.
|
134 |
+
Notwithstanding the above, nothing herein shall supersede or modify
|
135 |
+
the terms of any separate license agreement you may have executed
|
136 |
+
with Licensor regarding such Contributions.
|
137 |
+
|
138 |
+
6. Trademarks. This License does not grant permission to use the trade
|
139 |
+
names, trademarks, service marks, or product names of the Licensor,
|
140 |
+
except as required for reasonable and customary use in describing the
|
141 |
+
origin of the Work and reproducing the content of the NOTICE file.
|
142 |
+
|
143 |
+
7. Disclaimer of Warranty. Unless required by applicable law or
|
144 |
+
agreed to in writing, Licensor provides the Work (and each
|
145 |
+
Contributor provides its Contributions) on an "AS IS" BASIS,
|
146 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
147 |
+
implied, including, without limitation, any warranties or conditions
|
148 |
+
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
149 |
+
PARTICULAR PURPOSE. You are solely responsible for determining the
|
150 |
+
appropriateness of using or redistributing the Work and assume any
|
151 |
+
risks associated with Your exercise of permissions under this License.
|
152 |
+
|
153 |
+
8. Limitation of Liability. In no event and under no legal theory,
|
154 |
+
whether in tort (including negligence), contract, or otherwise,
|
155 |
+
unless required by applicable law (such as deliberate and grossly
|
156 |
+
negligent acts) or agreed to in writing, shall any Contributor be
|
157 |
+
liable to You for damages, including any direct, indirect, special,
|
158 |
+
incidental, or consequential damages of any character arising as a
|
159 |
+
result of this License or out of the use or inability to use the
|
160 |
+
Work (including but not limited to damages for loss of goodwill,
|
161 |
+
work stoppage, computer failure or malfunction, or any and all
|
162 |
+
other commercial damages or losses), even if such Contributor
|
163 |
+
has been advised of the possibility of such damages.
|
164 |
+
|
165 |
+
9. Accepting Warranty or Additional Liability. While redistributing
|
166 |
+
the Work or Derivative Works thereof, You may choose to offer,
|
167 |
+
and charge a fee for, acceptance of support, warranty, indemnity,
|
168 |
+
or other liability obligations and/or rights consistent with this
|
169 |
+
License. However, in accepting such obligations, You may act only
|
170 |
+
on Your own behalf and on Your sole responsibility, not on behalf
|
171 |
+
of any other Contributor, and only if You agree to indemnify,
|
172 |
+
defend, and hold each Contributor harmless for any liability
|
173 |
+
incurred by, or claims asserted against, such Contributor by reason
|
174 |
+
of your accepting any such warranty or additional liability.
|
175 |
+
|
176 |
+
END OF TERMS AND CONDITIONS
|
177 |
+
|
178 |
+
APPENDIX: How to apply the Apache License to your work.
|
179 |
+
|
180 |
+
To apply the Apache License to your work, attach the following
|
181 |
+
boilerplate notice, with the fields enclosed by brackets "[]"
|
182 |
+
replaced with your own identifying information. (Don't include
|
183 |
+
the brackets!) The text should be enclosed in the appropriate
|
184 |
+
comment syntax for the file format. We also recommend that a
|
185 |
+
file or class name and description of purpose be included on the
|
186 |
+
same "printed page" as the copyright notice for easier
|
187 |
+
identification within third-party archives.
|
188 |
+
|
189 |
+
Copyright [yyyy] [name of copyright owner]
|
190 |
+
|
191 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
192 |
+
you may not use this file except in compliance with the License.
|
193 |
+
You may obtain a copy of the License at
|
194 |
+
|
195 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
196 |
+
|
197 |
+
Unless required by applicable law or agreed to in writing, software
|
198 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
199 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
200 |
+
See the License for the specific language governing permissions and
|
201 |
+
limitations under the License.
|
MODEL_ZOO.md
ADDED
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
## DDColor Model Zoo
|
2 |
+
|
3 |
+
| Model | Description | Note |
|
4 |
+
| ---------------------- | :------------------ | :-----|
|
5 |
+
| [ddcolor_paper.pth](https://huggingface.co/piddnad/DDColor-models/resolve/main/ddcolor_paper.pth) | DDColor-L trained on ImageNet | paper model, use it only if you want to reproduce some of the images in the paper.
|
6 |
+
| [ddcolor_modelscope.pth](https://huggingface.co/piddnad/DDColor-models/resolve/main/ddcolor_modelscope.pth) (***default***) | DDColor-L trained on ImageNet | We trained this model using the same data cleaning scheme as [BigColor](https://github.com/KIMGEONUNG/BigColor/issues/2#issuecomment-1196287574), so it can get the best qualitative results with little degrading FID performance. Use this model by default if you want to test images outside the ImageNet. It can also be easily downloaded through ModelScope [in this way](README.md#inference-with-modelscope-library).
|
7 |
+
| [ddcolor_artistic.pth](https://huggingface.co/piddnad/DDColor-models/resolve/main/ddcolor_artistic.pth) | DDColor-L trained on ImageNet + private data | We trained this model with an extended dataset containing many high-quality artistic images. Also, we didn't use colorfulness loss during training, so there may be fewer unreasonable color artifacts. Use this model if you want to try different colorization results.
|
8 |
+
| [ddcolor_paper_tiny.pth](https://huggingface.co/piddnad/DDColor-models/resolve/main/ddcolor_paper_tiny.pth) | DDColor-T trained on ImageNet | The most lightweight version of ddcolor model, using the same training scheme as ddcolor_paper.
|
9 |
+
|
10 |
+
## Discussions
|
11 |
+
|
12 |
+
* About Colorfulness Loss (CL): CL can encourage more "colorful" results and help improve CF scores, however, it sometimes leads to the generation of unpleasant color blocks (eg. red color artifacts). If something goes wrong, I personally recommend trying to remove it during training.
|
13 |
+
|
README.md
CHANGED
@@ -1,12 +1,206 @@
|
|
1 |
---
|
2 |
title: DDColor
|
3 |
-
|
4 |
-
colorFrom: gray
|
5 |
-
colorTo: pink
|
6 |
sdk: gradio
|
7 |
-
sdk_version: 5.
|
8 |
-
app_file: app.py
|
9 |
-
pinned: false
|
10 |
---
|
|
|
|
|
|
|
|
|
|
|
|
|
11 |
|
12 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
---
|
2 |
title: DDColor
|
3 |
+
app_file: gradio_app.py
|
|
|
|
|
4 |
sdk: gradio
|
5 |
+
sdk_version: 5.21.0
|
|
|
|
|
6 |
---
|
7 |
+
# 🎨 DDColor
|
8 |
+
[](https://arxiv.org/abs/2212.11613)
|
9 |
+
[](https://huggingface.co/piddnad/DDColor-models)
|
10 |
+
[](https://www.modelscope.cn/models/damo/cv_ddcolor_image-colorization/summary)
|
11 |
+
[](https://replicate.com/piddnad/ddcolor)
|
12 |
+

|
13 |
|
14 |
+
Official PyTorch implementation of ICCV 2023 Paper "DDColor: Towards Photo-Realistic Image Colorization via Dual Decoders".
|
15 |
+
|
16 |
+
> Xiaoyang Kang, Tao Yang, Wenqi Ouyang, Peiran Ren, Lingzhi Li, Xuansong Xie
|
17 |
+
> *DAMO Academy, Alibaba Group*
|
18 |
+
|
19 |
+
🪄 DDColor can provide vivid and natural colorization for historical black and white old photos.
|
20 |
+
|
21 |
+
<p align="center">
|
22 |
+
<img src="assets/teaser.png" width="100%">
|
23 |
+
</p>
|
24 |
+
|
25 |
+
🎲 It can even colorize/recolor landscapes from anime games, transforming your animated scenery into a realistic real-life style! (Image source: Genshin Impact)
|
26 |
+
|
27 |
+
<p align="center">
|
28 |
+
<img src="assets/anime_landscapes.png" width="100%">
|
29 |
+
</p>
|
30 |
+
|
31 |
+
|
32 |
+
## News
|
33 |
+
- [2024-01-28] Support inference via 🤗 Hugging Face! Thanks @[Niels](https://github.com/NielsRogge) for the suggestion and example code and @[Skwara](https://github.com/Skwarson96) for fixing bug.
|
34 |
+
- [2024-01-18] Add Replicate demo and API! Thanks @[Chenxi](https://github.com/chenxwh).
|
35 |
+
- [2023-12-13] Release the DDColor-tiny pre-trained model!
|
36 |
+
- [2023-09-07] Add the Model Zoo and release three pretrained models!
|
37 |
+
- [2023-05-15] Code release for training and inference!
|
38 |
+
- [2023-05-05] The online demo is available!
|
39 |
+
|
40 |
+
|
41 |
+
## Online Demo
|
42 |
+
Try our online demos at [ModelScope](https://www.modelscope.cn/models/damo/cv_ddcolor_image-colorization/summary) and [Replicate](https://replicate.com/piddnad/ddcolor).
|
43 |
+
|
44 |
+
|
45 |
+
## Methods
|
46 |
+
*In short:* DDColor uses multi-scale visual features to optimize **learnable color tokens** (i.e. color queries) and achieves state-of-the-art performance on automatic image colorization.
|
47 |
+
|
48 |
+
<p align="center">
|
49 |
+
<img src="assets/network_arch.jpg" width="100%">
|
50 |
+
</p>
|
51 |
+
|
52 |
+
|
53 |
+
## Installation
|
54 |
+
### Requirements
|
55 |
+
- Python >= 3.7
|
56 |
+
- PyTorch >= 1.7
|
57 |
+
|
58 |
+
### Installation with conda (recommended)
|
59 |
+
|
60 |
+
```sh
|
61 |
+
conda create -n ddcolor python=3.9
|
62 |
+
conda activate ddcolor
|
63 |
+
pip install torch==2.2.0 torchvision==0.17.0 torchaudio==2.2.0 --index-url https://download.pytorch.org/whl/cu118
|
64 |
+
|
65 |
+
pip install -r requirements.txt
|
66 |
+
|
67 |
+
# Install basicsr, only required for training
|
68 |
+
python3 setup.py develop
|
69 |
+
```
|
70 |
+
|
71 |
+
## Quick Start
|
72 |
+
### Inference Using Local Script (No `basicsr` Required)
|
73 |
+
1. Download the pretrained model:
|
74 |
+
|
75 |
+
```python
|
76 |
+
from modelscope.hub.snapshot_download import snapshot_download
|
77 |
+
|
78 |
+
model_dir = snapshot_download('damo/cv_ddcolor_image-colorization', cache_dir='./modelscope')
|
79 |
+
print('model assets saved to %s' % model_dir)
|
80 |
+
```
|
81 |
+
|
82 |
+
2. Run inference with
|
83 |
+
|
84 |
+
```sh
|
85 |
+
python infer.py --model_path ./modelscope/damo/cv_ddcolor_image-colorization/pytorch_model.pt --input ./assets/test_images
|
86 |
+
```
|
87 |
+
or
|
88 |
+
```sh
|
89 |
+
sh scripts/inference.sh
|
90 |
+
```
|
91 |
+
|
92 |
+
### Inference Using Hugging Face
|
93 |
+
Load the model via Hugging Face Hub:
|
94 |
+
|
95 |
+
```python
|
96 |
+
from infer_hf import DDColorHF
|
97 |
+
|
98 |
+
ddcolor_paper_tiny = DDColorHF.from_pretrained("piddnad/ddcolor_paper_tiny")
|
99 |
+
ddcolor_paper = DDColorHF.from_pretrained("piddnad/ddcolor_paper")
|
100 |
+
ddcolor_modelscope = DDColorHF.from_pretrained("piddnad/ddcolor_modelscope")
|
101 |
+
ddcolor_artistic = DDColorHF.from_pretrained("piddnad/ddcolor_artistic")
|
102 |
+
```
|
103 |
+
|
104 |
+
Check `infer_hf.py` for the details of the inference, or directly perform model inference by running:
|
105 |
+
|
106 |
+
```sh
|
107 |
+
python infer_hf.py --model_name ddcolor_modelscope --input ./assets/test_images
|
108 |
+
# model_name: [ddcolor_paper | ddcolor_modelscope | ddcolor_artistic | ddcolor_paper_tiny]
|
109 |
+
```
|
110 |
+
|
111 |
+
### Inference Using ModelScope
|
112 |
+
1. Install modelscope:
|
113 |
+
|
114 |
+
```sh
|
115 |
+
pip install modelscope
|
116 |
+
```
|
117 |
+
|
118 |
+
2. Run inference:
|
119 |
+
|
120 |
+
```python
|
121 |
+
import cv2
|
122 |
+
from modelscope.outputs import OutputKeys
|
123 |
+
from modelscope.pipelines import pipeline
|
124 |
+
from modelscope.utils.constant import Tasks
|
125 |
+
|
126 |
+
img_colorization = pipeline(Tasks.image_colorization, model='damo/cv_ddcolor_image-colorization')
|
127 |
+
result = img_colorization('https://modelscope.oss-cn-beijing.aliyuncs.com/test/images/audrey_hepburn.jpg')
|
128 |
+
cv2.imwrite('result.png', result[OutputKeys.OUTPUT_IMG])
|
129 |
+
```
|
130 |
+
|
131 |
+
This code will automatically download the `ddcolor_modelscope` model (see [ModelZoo](#model-zoo)) and performs inference. The model file `pytorch_model.pt` can be found in the local path `~/.cache/modelscope/hub/damo`.
|
132 |
+
|
133 |
+
### Gradio Demo
|
134 |
+
Install the gradio and other required libraries:
|
135 |
+
|
136 |
+
```sh
|
137 |
+
pip install gradio gradio_imageslider timm
|
138 |
+
```
|
139 |
+
|
140 |
+
Then, you can run the demo with the following command:
|
141 |
+
|
142 |
+
```sh
|
143 |
+
python gradio_app.py
|
144 |
+
```
|
145 |
+
|
146 |
+
## Model Zoo
|
147 |
+
We provide several different versions of pretrained models, please check out [Model Zoo](MODEL_ZOO.md).
|
148 |
+
|
149 |
+
|
150 |
+
## Train
|
151 |
+
1. Dataset Preparation: Download the [ImageNet](https://www.image-net.org/) dataset or create a custom dataset. Use this script to obtain the dataset list file:
|
152 |
+
|
153 |
+
```sh
|
154 |
+
python data_list/get_meta_file.py
|
155 |
+
```
|
156 |
+
|
157 |
+
2. Download the pretrained weights for [ConvNeXt](https://dl.fbaipublicfiles.com/convnext/convnext_large_22k_224.pth) and [InceptionV3](https://download.pytorch.org/models/inception_v3_google-1a9a5a14.pth) and place them in the `pretrain` folder.
|
158 |
+
|
159 |
+
3. Specify 'meta_info_file' and other options in `options/train/train_ddcolor.yml`.
|
160 |
+
|
161 |
+
4. Start training:
|
162 |
+
|
163 |
+
```sh
|
164 |
+
sh scripts/train.sh
|
165 |
+
```
|
166 |
+
|
167 |
+
## ONNX export
|
168 |
+
Support for ONNX model exports is available.
|
169 |
+
|
170 |
+
1. Install dependencies:
|
171 |
+
|
172 |
+
```sh
|
173 |
+
pip install onnx==1.16.1 onnxruntime==1.19.2 onnxsim==0.4.36
|
174 |
+
```
|
175 |
+
|
176 |
+
2. Usage example:
|
177 |
+
|
178 |
+
```sh
|
179 |
+
python export.py
|
180 |
+
usage: export.py [-h] [--input_size INPUT_SIZE] [--batch_size BATCH_SIZE] --model_path MODEL_PATH [--model_size MODEL_SIZE]
|
181 |
+
[--decoder_type DECODER_TYPE] [--export_path EXPORT_PATH] [--opset OPSET]
|
182 |
+
```
|
183 |
+
|
184 |
+
Demo of ONNX export using a `ddcolor_paper_tiny` model is available [here](notebooks/colorization_pipeline_onnxruntime.ipynb).
|
185 |
+
|
186 |
+
|
187 |
+
## Citation
|
188 |
+
|
189 |
+
If our work is helpful for your research, please consider citing:
|
190 |
+
|
191 |
+
```
|
192 |
+
@inproceedings{kang2023ddcolor,
|
193 |
+
title={DDColor: Towards Photo-Realistic Image Colorization via Dual Decoders},
|
194 |
+
author={Kang, Xiaoyang and Yang, Tao and Ouyang, Wenqi and Ren, Peiran and Li, Lingzhi and Xie, Xuansong},
|
195 |
+
booktitle={Proceedings of the IEEE/CVF International Conference on Computer Vision},
|
196 |
+
pages={328--338},
|
197 |
+
year={2023}
|
198 |
+
}
|
199 |
+
```
|
200 |
+
|
201 |
+
## Acknowledgments
|
202 |
+
We thank the authors of BasicSR for the awesome training pipeline.
|
203 |
+
|
204 |
+
> Xintao Wang, Ke Yu, Kelvin C.K. Chan, Chao Dong and Chen Change Loy. BasicSR: Open Source Image and Video Restoration Toolbox. https://github.com/xinntao/BasicSR, 2020.
|
205 |
+
|
206 |
+
Some codes are adapted from [ColorFormer](https://github.com/jixiaozhong/ColorFormer), [BigColor](https://github.com/KIMGEONUNG/BigColor), [ConvNeXt](https://github.com/facebookresearch/ConvNeXt), [Mask2Former](https://github.com/facebookresearch/Mask2Former), and [DETR](https://github.com/facebookresearch/detr). Thanks for their excellent work!
|
VERSION
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
1.3.4.6
|
basicsr/__init__.py
ADDED
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# https://github.com/xinntao/BasicSR
|
2 |
+
# flake8: noqa
|
3 |
+
from .archs import *
|
4 |
+
from .data import *
|
5 |
+
from .losses import *
|
6 |
+
from .metrics import *
|
7 |
+
from .models import *
|
8 |
+
# from .ops import *
|
9 |
+
# from .test import *
|
10 |
+
from .train import *
|
11 |
+
from .utils import *
|
12 |
+
try:
|
13 |
+
from .version import __gitsha__, __version__
|
14 |
+
except:
|
15 |
+
pass
|
basicsr/archs/__init__.py
ADDED
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import importlib
|
2 |
+
from copy import deepcopy
|
3 |
+
from os import path as osp
|
4 |
+
|
5 |
+
from basicsr.utils import get_root_logger, scandir
|
6 |
+
from basicsr.utils.registry import ARCH_REGISTRY
|
7 |
+
|
8 |
+
__all__ = ['build_network']
|
9 |
+
|
10 |
+
# automatically scan and import arch modules for registry
|
11 |
+
# scan all the files under the 'archs' folder and collect files ending with
|
12 |
+
# '_arch.py'
|
13 |
+
arch_folder = osp.dirname(osp.abspath(__file__))
|
14 |
+
arch_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(arch_folder) if v.endswith('_arch.py')]
|
15 |
+
# import all the arch modules
|
16 |
+
_arch_modules = [importlib.import_module(f'basicsr.archs.{file_name}') for file_name in arch_filenames]
|
17 |
+
|
18 |
+
|
19 |
+
def build_network(opt):
|
20 |
+
opt = deepcopy(opt)
|
21 |
+
network_type = opt.pop('type')
|
22 |
+
net = ARCH_REGISTRY.get(network_type)(**opt)
|
23 |
+
logger = get_root_logger()
|
24 |
+
logger.info(f'Network [{net.__class__.__name__}] is created.')
|
25 |
+
return net
|
basicsr/archs/ddcolor_arch.py
ADDED
@@ -0,0 +1,385 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
|
4 |
+
from basicsr.archs.ddcolor_arch_utils.unet import Hook, CustomPixelShuffle_ICNR, UnetBlockWide, NormType, custom_conv_layer
|
5 |
+
from basicsr.archs.ddcolor_arch_utils.convnext import ConvNeXt
|
6 |
+
from basicsr.archs.ddcolor_arch_utils.transformer_utils import SelfAttentionLayer, CrossAttentionLayer, FFNLayer, MLP
|
7 |
+
from basicsr.archs.ddcolor_arch_utils.position_encoding import PositionEmbeddingSine
|
8 |
+
from basicsr.archs.ddcolor_arch_utils.transformer import Transformer
|
9 |
+
from basicsr.utils.registry import ARCH_REGISTRY
|
10 |
+
|
11 |
+
|
12 |
+
@ARCH_REGISTRY.register()
|
13 |
+
class DDColor(nn.Module):
|
14 |
+
|
15 |
+
def __init__(self,
|
16 |
+
encoder_name='convnext-l',
|
17 |
+
decoder_name='MultiScaleColorDecoder',
|
18 |
+
num_input_channels=3,
|
19 |
+
input_size=(256, 256),
|
20 |
+
nf=512,
|
21 |
+
num_output_channels=3,
|
22 |
+
last_norm='Weight',
|
23 |
+
do_normalize=False,
|
24 |
+
num_queries=256,
|
25 |
+
num_scales=3,
|
26 |
+
dec_layers=9,
|
27 |
+
encoder_from_pretrain=False):
|
28 |
+
super().__init__()
|
29 |
+
|
30 |
+
self.encoder = Encoder(encoder_name, ['norm0', 'norm1', 'norm2', 'norm3'], from_pretrain=encoder_from_pretrain)
|
31 |
+
self.encoder.eval()
|
32 |
+
test_input = torch.randn(1, num_input_channels, *input_size)
|
33 |
+
self.encoder(test_input)
|
34 |
+
|
35 |
+
self.decoder = Decoder(
|
36 |
+
self.encoder.hooks,
|
37 |
+
nf=nf,
|
38 |
+
last_norm=last_norm,
|
39 |
+
num_queries=num_queries,
|
40 |
+
num_scales=num_scales,
|
41 |
+
dec_layers=dec_layers,
|
42 |
+
decoder_name=decoder_name
|
43 |
+
)
|
44 |
+
self.refine_net = nn.Sequential(custom_conv_layer(num_queries + 3, num_output_channels, ks=1, use_activ=False, norm_type=NormType.Spectral))
|
45 |
+
|
46 |
+
self.do_normalize = do_normalize
|
47 |
+
self.register_buffer('mean', torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
|
48 |
+
self.register_buffer('std', torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))
|
49 |
+
|
50 |
+
def normalize(self, img):
|
51 |
+
return (img - self.mean) / self.std
|
52 |
+
|
53 |
+
def denormalize(self, img):
|
54 |
+
return img * self.std + self.mean
|
55 |
+
|
56 |
+
def forward(self, x):
|
57 |
+
if x.shape[1] == 3:
|
58 |
+
x = self.normalize(x)
|
59 |
+
|
60 |
+
self.encoder(x)
|
61 |
+
out_feat = self.decoder()
|
62 |
+
coarse_input = torch.cat([out_feat, x], dim=1)
|
63 |
+
out = self.refine_net(coarse_input)
|
64 |
+
|
65 |
+
if self.do_normalize:
|
66 |
+
out = self.denormalize(out)
|
67 |
+
return out
|
68 |
+
|
69 |
+
|
70 |
+
class Decoder(nn.Module):
|
71 |
+
|
72 |
+
def __init__(self,
|
73 |
+
hooks,
|
74 |
+
nf=512,
|
75 |
+
blur=True,
|
76 |
+
last_norm='Weight',
|
77 |
+
num_queries=256,
|
78 |
+
num_scales=3,
|
79 |
+
dec_layers=9,
|
80 |
+
decoder_name='MultiScaleColorDecoder'):
|
81 |
+
super().__init__()
|
82 |
+
self.hooks = hooks
|
83 |
+
self.nf = nf
|
84 |
+
self.blur = blur
|
85 |
+
self.last_norm = getattr(NormType, last_norm)
|
86 |
+
self.decoder_name = decoder_name
|
87 |
+
|
88 |
+
self.layers = self.make_layers()
|
89 |
+
embed_dim = nf // 2
|
90 |
+
|
91 |
+
self.last_shuf = CustomPixelShuffle_ICNR(embed_dim, embed_dim, blur=self.blur, norm_type=self.last_norm, scale=4)
|
92 |
+
|
93 |
+
if self.decoder_name == 'MultiScaleColorDecoder':
|
94 |
+
self.color_decoder = MultiScaleColorDecoder(
|
95 |
+
in_channels=[512, 512, 256],
|
96 |
+
num_queries=num_queries,
|
97 |
+
num_scales=num_scales,
|
98 |
+
dec_layers=dec_layers,
|
99 |
+
)
|
100 |
+
else:
|
101 |
+
self.color_decoder = SingleColorDecoder(
|
102 |
+
in_channels=hooks[-1].feature.shape[1],
|
103 |
+
num_queries=num_queries,
|
104 |
+
)
|
105 |
+
|
106 |
+
|
107 |
+
def forward(self):
|
108 |
+
encode_feat = self.hooks[-1].feature
|
109 |
+
out0 = self.layers[0](encode_feat)
|
110 |
+
out1 = self.layers[1](out0)
|
111 |
+
out2 = self.layers[2](out1)
|
112 |
+
out3 = self.last_shuf(out2)
|
113 |
+
|
114 |
+
if self.decoder_name == 'MultiScaleColorDecoder':
|
115 |
+
out = self.color_decoder([out0, out1, out2], out3)
|
116 |
+
else:
|
117 |
+
out = self.color_decoder(out3, encode_feat)
|
118 |
+
|
119 |
+
return out
|
120 |
+
|
121 |
+
def make_layers(self):
|
122 |
+
decoder_layers = []
|
123 |
+
|
124 |
+
e_in_c = self.hooks[-1].feature.shape[1]
|
125 |
+
in_c = e_in_c
|
126 |
+
|
127 |
+
out_c = self.nf
|
128 |
+
setup_hooks = self.hooks[-2::-1]
|
129 |
+
for layer_index, hook in enumerate(setup_hooks):
|
130 |
+
feature_c = hook.feature.shape[1]
|
131 |
+
if layer_index == len(setup_hooks) - 1:
|
132 |
+
out_c = out_c // 2
|
133 |
+
decoder_layers.append(
|
134 |
+
UnetBlockWide(
|
135 |
+
in_c, feature_c, out_c, hook, blur=self.blur, self_attention=False, norm_type=NormType.Spectral))
|
136 |
+
in_c = out_c
|
137 |
+
return nn.Sequential(*decoder_layers)
|
138 |
+
|
139 |
+
|
140 |
+
class Encoder(nn.Module):
|
141 |
+
|
142 |
+
def __init__(self, encoder_name, hook_names, from_pretrain, **kwargs):
|
143 |
+
super().__init__()
|
144 |
+
|
145 |
+
if encoder_name == 'convnext-t' or encoder_name == 'convnext':
|
146 |
+
self.arch = ConvNeXt()
|
147 |
+
elif encoder_name == 'convnext-s':
|
148 |
+
self.arch = ConvNeXt(depths=[3, 3, 27, 3], dims=[96, 192, 384, 768])
|
149 |
+
elif encoder_name == 'convnext-b':
|
150 |
+
self.arch = ConvNeXt(depths=[3, 3, 27, 3], dims=[128, 256, 512, 1024])
|
151 |
+
elif encoder_name == 'convnext-l':
|
152 |
+
self.arch = ConvNeXt(depths=[3, 3, 27, 3], dims=[192, 384, 768, 1536])
|
153 |
+
else:
|
154 |
+
raise NotImplementedError
|
155 |
+
|
156 |
+
self.encoder_name = encoder_name
|
157 |
+
self.hook_names = hook_names
|
158 |
+
self.hooks = self.setup_hooks()
|
159 |
+
|
160 |
+
if from_pretrain:
|
161 |
+
self.load_pretrain_model()
|
162 |
+
|
163 |
+
def setup_hooks(self):
|
164 |
+
hooks = [Hook(self.arch._modules[name]) for name in self.hook_names]
|
165 |
+
return hooks
|
166 |
+
|
167 |
+
def forward(self, x):
|
168 |
+
return self.arch(x)
|
169 |
+
|
170 |
+
def load_pretrain_model(self):
|
171 |
+
if self.encoder_name == 'convnext-t' or self.encoder_name == 'convnext':
|
172 |
+
self.load('pretrain/convnext_tiny_22k_224.pth')
|
173 |
+
elif self.encoder_name == 'convnext-s':
|
174 |
+
self.load('pretrain/convnext_small_22k_224.pth')
|
175 |
+
elif self.encoder_name == 'convnext-b':
|
176 |
+
self.load('pretrain/convnext_base_22k_224.pth')
|
177 |
+
elif self.encoder_name == 'convnext-l':
|
178 |
+
self.load('pretrain/convnext_large_22k_224.pth')
|
179 |
+
else:
|
180 |
+
raise NotImplementedError
|
181 |
+
print('Loaded pretrained convnext model.')
|
182 |
+
|
183 |
+
def load(self, path):
|
184 |
+
from basicsr.utils import get_root_logger
|
185 |
+
logger = get_root_logger()
|
186 |
+
if not path:
|
187 |
+
logger.info("No checkpoint found. Initializing model from scratch")
|
188 |
+
return
|
189 |
+
logger.info("[Encoder] Loading from {} ...".format(path))
|
190 |
+
checkpoint = torch.load(path, map_location=torch.device("cpu"))
|
191 |
+
checkpoint_state_dict = checkpoint['model'] if 'model' in checkpoint.keys() else checkpoint
|
192 |
+
incompatible = self.arch.load_state_dict(checkpoint_state_dict, strict=False)
|
193 |
+
|
194 |
+
if incompatible.missing_keys:
|
195 |
+
msg = "Some model parameters or buffers are not found in the checkpoint:\n"
|
196 |
+
msg += str(incompatible.missing_keys)
|
197 |
+
logger.warning(msg)
|
198 |
+
if incompatible.unexpected_keys:
|
199 |
+
msg = "The checkpoint state_dict contains keys that are not used by the model:\n"
|
200 |
+
msg += str(incompatible.unexpected_keys)
|
201 |
+
logger.warning(msg)
|
202 |
+
|
203 |
+
|
204 |
+
class MultiScaleColorDecoder(nn.Module):
|
205 |
+
|
206 |
+
def __init__(
|
207 |
+
self,
|
208 |
+
in_channels,
|
209 |
+
hidden_dim=256,
|
210 |
+
num_queries=100,
|
211 |
+
nheads=8,
|
212 |
+
dim_feedforward=2048,
|
213 |
+
dec_layers=9,
|
214 |
+
pre_norm=False,
|
215 |
+
color_embed_dim=256,
|
216 |
+
enforce_input_project=True,
|
217 |
+
num_scales=3
|
218 |
+
):
|
219 |
+
super().__init__()
|
220 |
+
|
221 |
+
# positional encoding
|
222 |
+
N_steps = hidden_dim // 2
|
223 |
+
self.pe_layer = PositionEmbeddingSine(N_steps, normalize=True)
|
224 |
+
|
225 |
+
# define Transformer decoder here
|
226 |
+
self.num_heads = nheads
|
227 |
+
self.num_layers = dec_layers
|
228 |
+
self.transformer_self_attention_layers = nn.ModuleList()
|
229 |
+
self.transformer_cross_attention_layers = nn.ModuleList()
|
230 |
+
self.transformer_ffn_layers = nn.ModuleList()
|
231 |
+
|
232 |
+
for _ in range(self.num_layers):
|
233 |
+
self.transformer_self_attention_layers.append(
|
234 |
+
SelfAttentionLayer(
|
235 |
+
d_model=hidden_dim,
|
236 |
+
nhead=nheads,
|
237 |
+
dropout=0.0,
|
238 |
+
normalize_before=pre_norm,
|
239 |
+
)
|
240 |
+
)
|
241 |
+
self.transformer_cross_attention_layers.append(
|
242 |
+
CrossAttentionLayer(
|
243 |
+
d_model=hidden_dim,
|
244 |
+
nhead=nheads,
|
245 |
+
dropout=0.0,
|
246 |
+
normalize_before=pre_norm,
|
247 |
+
)
|
248 |
+
)
|
249 |
+
self.transformer_ffn_layers.append(
|
250 |
+
FFNLayer(
|
251 |
+
d_model=hidden_dim,
|
252 |
+
dim_feedforward=dim_feedforward,
|
253 |
+
dropout=0.0,
|
254 |
+
normalize_before=pre_norm,
|
255 |
+
)
|
256 |
+
)
|
257 |
+
|
258 |
+
self.decoder_norm = nn.LayerNorm(hidden_dim)
|
259 |
+
|
260 |
+
self.num_queries = num_queries
|
261 |
+
# learnable color query features
|
262 |
+
self.query_feat = nn.Embedding(num_queries, hidden_dim)
|
263 |
+
# learnable color query p.e.
|
264 |
+
self.query_embed = nn.Embedding(num_queries, hidden_dim)
|
265 |
+
|
266 |
+
# level embedding
|
267 |
+
self.num_feature_levels = num_scales
|
268 |
+
self.level_embed = nn.Embedding(self.num_feature_levels, hidden_dim)
|
269 |
+
|
270 |
+
# input projections
|
271 |
+
self.input_proj = nn.ModuleList()
|
272 |
+
for i in range(self.num_feature_levels):
|
273 |
+
if in_channels[i] != hidden_dim or enforce_input_project:
|
274 |
+
self.input_proj.append(nn.Conv2d(in_channels[i], hidden_dim, kernel_size=1))
|
275 |
+
nn.init.kaiming_uniform_(self.input_proj[-1].weight, a=1)
|
276 |
+
if self.input_proj[-1].bias is not None:
|
277 |
+
nn.init.constant_(self.input_proj[-1].bias, 0)
|
278 |
+
else:
|
279 |
+
self.input_proj.append(nn.Sequential())
|
280 |
+
|
281 |
+
# output FFNs
|
282 |
+
self.color_embed = MLP(hidden_dim, hidden_dim, color_embed_dim, 3)
|
283 |
+
|
284 |
+
def forward(self, x, img_features):
|
285 |
+
# x is a list of multi-scale feature
|
286 |
+
assert len(x) == self.num_feature_levels
|
287 |
+
src = []
|
288 |
+
pos = []
|
289 |
+
|
290 |
+
for i in range(self.num_feature_levels):
|
291 |
+
pos.append(self.pe_layer(x[i], None).flatten(2))
|
292 |
+
src.append(self.input_proj[i](x[i]).flatten(2) + self.level_embed.weight[i][None, :, None])
|
293 |
+
|
294 |
+
# flatten NxCxHxW to HWxNxC
|
295 |
+
pos[-1] = pos[-1].permute(2, 0, 1)
|
296 |
+
src[-1] = src[-1].permute(2, 0, 1)
|
297 |
+
|
298 |
+
_, bs, _ = src[0].shape
|
299 |
+
|
300 |
+
# QxNxC
|
301 |
+
query_embed = self.query_embed.weight.unsqueeze(1).repeat(1, bs, 1)
|
302 |
+
output = self.query_feat.weight.unsqueeze(1).repeat(1, bs, 1)
|
303 |
+
|
304 |
+
for i in range(self.num_layers):
|
305 |
+
level_index = i % self.num_feature_levels
|
306 |
+
# attention: cross-attention first
|
307 |
+
output = self.transformer_cross_attention_layers[i](
|
308 |
+
output, src[level_index],
|
309 |
+
memory_mask=None,
|
310 |
+
memory_key_padding_mask=None,
|
311 |
+
pos=pos[level_index], query_pos=query_embed
|
312 |
+
)
|
313 |
+
output = self.transformer_self_attention_layers[i](
|
314 |
+
output, tgt_mask=None,
|
315 |
+
tgt_key_padding_mask=None,
|
316 |
+
query_pos=query_embed
|
317 |
+
)
|
318 |
+
# FFN
|
319 |
+
output = self.transformer_ffn_layers[i](
|
320 |
+
output
|
321 |
+
)
|
322 |
+
|
323 |
+
decoder_output = self.decoder_norm(output)
|
324 |
+
decoder_output = decoder_output.transpose(0, 1) # [N, bs, C] -> [bs, N, C]
|
325 |
+
color_embed = self.color_embed(decoder_output)
|
326 |
+
out = torch.einsum("bqc,bchw->bqhw", color_embed, img_features)
|
327 |
+
|
328 |
+
return out
|
329 |
+
|
330 |
+
|
331 |
+
class SingleColorDecoder(nn.Module):
|
332 |
+
|
333 |
+
def __init__(
|
334 |
+
self,
|
335 |
+
in_channels=768,
|
336 |
+
hidden_dim=256,
|
337 |
+
num_queries=256, # 100
|
338 |
+
nheads=8,
|
339 |
+
dropout=0.1,
|
340 |
+
dim_feedforward=2048,
|
341 |
+
enc_layers=0,
|
342 |
+
dec_layers=6,
|
343 |
+
pre_norm=False,
|
344 |
+
deep_supervision=True,
|
345 |
+
enforce_input_project=True,
|
346 |
+
):
|
347 |
+
|
348 |
+
super().__init__()
|
349 |
+
|
350 |
+
N_steps = hidden_dim // 2
|
351 |
+
self.pe_layer = PositionEmbeddingSine(N_steps, normalize=True)
|
352 |
+
|
353 |
+
transformer = Transformer(
|
354 |
+
d_model=hidden_dim,
|
355 |
+
dropout=dropout,
|
356 |
+
nhead=nheads,
|
357 |
+
dim_feedforward=dim_feedforward,
|
358 |
+
num_encoder_layers=enc_layers,
|
359 |
+
num_decoder_layers=dec_layers,
|
360 |
+
normalize_before=pre_norm,
|
361 |
+
return_intermediate_dec=deep_supervision,
|
362 |
+
)
|
363 |
+
self.num_queries = num_queries
|
364 |
+
self.transformer = transformer
|
365 |
+
|
366 |
+
self.query_embed = nn.Embedding(num_queries, hidden_dim)
|
367 |
+
|
368 |
+
if in_channels != hidden_dim or enforce_input_project:
|
369 |
+
self.input_proj = nn.Conv2d(in_channels, hidden_dim, kernel_size=1)
|
370 |
+
nn.init.kaiming_uniform_(self.input_proj.weight, a=1)
|
371 |
+
if self.input_proj.bias is not None:
|
372 |
+
nn.init.constant_(self.input_proj.bias, 0)
|
373 |
+
else:
|
374 |
+
self.input_proj = nn.Sequential()
|
375 |
+
|
376 |
+
|
377 |
+
def forward(self, img_features, encode_feat):
|
378 |
+
pos = self.pe_layer(encode_feat)
|
379 |
+
src = encode_feat
|
380 |
+
mask = None
|
381 |
+
hs, memory = self.transformer(self.input_proj(src), mask, self.query_embed.weight, pos)
|
382 |
+
color_embed = hs[-1]
|
383 |
+
color_preds = torch.einsum('bqc,bchw->bqhw', color_embed, img_features)
|
384 |
+
return color_preds
|
385 |
+
|
basicsr/archs/ddcolor_arch_utils/__int__.py
ADDED
File without changes
|
basicsr/archs/ddcolor_arch_utils/convnext.py
ADDED
@@ -0,0 +1,155 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
|
3 |
+
# All rights reserved.
|
4 |
+
|
5 |
+
# This source code is licensed under the license found in the
|
6 |
+
# LICENSE file in the root directory of this source tree.
|
7 |
+
|
8 |
+
|
9 |
+
import torch
|
10 |
+
import torch.nn as nn
|
11 |
+
import torch.nn.functional as F
|
12 |
+
from timm.models.layers import trunc_normal_, DropPath
|
13 |
+
|
14 |
+
class Block(nn.Module):
|
15 |
+
r""" ConvNeXt Block. There are two equivalent implementations:
|
16 |
+
(1) DwConv -> LayerNorm (channels_first) -> 1x1 Conv -> GELU -> 1x1 Conv; all in (N, C, H, W)
|
17 |
+
(2) DwConv -> Permute to (N, H, W, C); LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back
|
18 |
+
We use (2) as we find it slightly faster in PyTorch
|
19 |
+
|
20 |
+
Args:
|
21 |
+
dim (int): Number of input channels.
|
22 |
+
drop_path (float): Stochastic depth rate. Default: 0.0
|
23 |
+
layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6.
|
24 |
+
"""
|
25 |
+
def __init__(self, dim, drop_path=0., layer_scale_init_value=1e-6):
|
26 |
+
super().__init__()
|
27 |
+
self.dwconv = nn.Conv2d(dim, dim, kernel_size=7, padding=3, groups=dim) # depthwise conv
|
28 |
+
self.norm = LayerNorm(dim, eps=1e-6)
|
29 |
+
self.pwconv1 = nn.Linear(dim, 4 * dim) # pointwise/1x1 convs, implemented with linear layers
|
30 |
+
self.act = nn.GELU()
|
31 |
+
self.pwconv2 = nn.Linear(4 * dim, dim)
|
32 |
+
self.gamma = nn.Parameter(layer_scale_init_value * torch.ones((dim)),
|
33 |
+
requires_grad=True) if layer_scale_init_value > 0 else None
|
34 |
+
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
|
35 |
+
|
36 |
+
def forward(self, x):
|
37 |
+
input = x
|
38 |
+
x = self.dwconv(x)
|
39 |
+
x = x.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C)
|
40 |
+
x = self.norm(x)
|
41 |
+
x = self.pwconv1(x)
|
42 |
+
x = self.act(x)
|
43 |
+
x = self.pwconv2(x)
|
44 |
+
if self.gamma is not None:
|
45 |
+
x = self.gamma * x
|
46 |
+
x = x.permute(0, 3, 1, 2) # (N, H, W, C) -> (N, C, H, W)
|
47 |
+
|
48 |
+
x = input + self.drop_path(x)
|
49 |
+
return x
|
50 |
+
|
51 |
+
class ConvNeXt(nn.Module):
|
52 |
+
r""" ConvNeXt
|
53 |
+
A PyTorch impl of : `A ConvNet for the 2020s` -
|
54 |
+
https://arxiv.org/pdf/2201.03545.pdf
|
55 |
+
Args:
|
56 |
+
in_chans (int): Number of input image channels. Default: 3
|
57 |
+
num_classes (int): Number of classes for classification head. Default: 1000
|
58 |
+
depths (tuple(int)): Number of blocks at each stage. Default: [3, 3, 9, 3]
|
59 |
+
dims (int): Feature dimension at each stage. Default: [96, 192, 384, 768]
|
60 |
+
drop_path_rate (float): Stochastic depth rate. Default: 0.
|
61 |
+
layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6.
|
62 |
+
head_init_scale (float): Init scaling value for classifier weights and biases. Default: 1.
|
63 |
+
"""
|
64 |
+
def __init__(self, in_chans=3, num_classes=1000,
|
65 |
+
depths=[3, 3, 9, 3], dims=[96, 192, 384, 768], drop_path_rate=0.,
|
66 |
+
layer_scale_init_value=1e-6, head_init_scale=1.,
|
67 |
+
):
|
68 |
+
super().__init__()
|
69 |
+
|
70 |
+
self.downsample_layers = nn.ModuleList() # stem and 3 intermediate downsampling conv layers
|
71 |
+
stem = nn.Sequential(
|
72 |
+
nn.Conv2d(in_chans, dims[0], kernel_size=4, stride=4),
|
73 |
+
LayerNorm(dims[0], eps=1e-6, data_format="channels_first")
|
74 |
+
)
|
75 |
+
self.downsample_layers.append(stem)
|
76 |
+
for i in range(3):
|
77 |
+
downsample_layer = nn.Sequential(
|
78 |
+
LayerNorm(dims[i], eps=1e-6, data_format="channels_first"),
|
79 |
+
nn.Conv2d(dims[i], dims[i+1], kernel_size=2, stride=2),
|
80 |
+
)
|
81 |
+
self.downsample_layers.append(downsample_layer)
|
82 |
+
|
83 |
+
self.stages = nn.ModuleList() # 4 feature resolution stages, each consisting of multiple residual blocks
|
84 |
+
dp_rates=[x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
|
85 |
+
cur = 0
|
86 |
+
for i in range(4):
|
87 |
+
stage = nn.Sequential(
|
88 |
+
*[Block(dim=dims[i], drop_path=dp_rates[cur + j],
|
89 |
+
layer_scale_init_value=layer_scale_init_value) for j in range(depths[i])]
|
90 |
+
)
|
91 |
+
self.stages.append(stage)
|
92 |
+
cur += depths[i]
|
93 |
+
|
94 |
+
# add norm layers for each output
|
95 |
+
out_indices = (0, 1, 2, 3)
|
96 |
+
for i in out_indices:
|
97 |
+
layer = LayerNorm(dims[i], eps=1e-6, data_format="channels_first")
|
98 |
+
# layer = nn.Identity()
|
99 |
+
layer_name = f'norm{i}'
|
100 |
+
self.add_module(layer_name, layer)
|
101 |
+
|
102 |
+
self.norm = nn.LayerNorm(dims[-1], eps=1e-6) # final norm layer
|
103 |
+
# self.head_cls = nn.Linear(dims[-1], 4)
|
104 |
+
|
105 |
+
self.apply(self._init_weights)
|
106 |
+
# self.head_cls.weight.data.mul_(head_init_scale)
|
107 |
+
# self.head_cls.bias.data.mul_(head_init_scale)
|
108 |
+
|
109 |
+
def _init_weights(self, m):
|
110 |
+
if isinstance(m, (nn.Conv2d, nn.Linear)):
|
111 |
+
trunc_normal_(m.weight, std=.02)
|
112 |
+
nn.init.constant_(m.bias, 0)
|
113 |
+
|
114 |
+
def forward_features(self, x):
|
115 |
+
for i in range(4):
|
116 |
+
x = self.downsample_layers[i](x)
|
117 |
+
x = self.stages[i](x)
|
118 |
+
|
119 |
+
# add extra norm
|
120 |
+
norm_layer = getattr(self, f'norm{i}')
|
121 |
+
# x = norm_layer(x)
|
122 |
+
norm_layer(x)
|
123 |
+
|
124 |
+
return self.norm(x.mean([-2, -1])) # global average pooling, (N, C, H, W) -> (N, C)
|
125 |
+
|
126 |
+
def forward(self, x):
|
127 |
+
x = self.forward_features(x)
|
128 |
+
# x = self.head_cls(x)
|
129 |
+
return x
|
130 |
+
|
131 |
+
class LayerNorm(nn.Module):
|
132 |
+
r""" LayerNorm that supports two data formats: channels_last (default) or channels_first.
|
133 |
+
The ordering of the dimensions in the inputs. channels_last corresponds to inputs with
|
134 |
+
shape (batch_size, height, width, channels) while channels_first corresponds to inputs
|
135 |
+
with shape (batch_size, channels, height, width).
|
136 |
+
"""
|
137 |
+
def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"):
|
138 |
+
super().__init__()
|
139 |
+
self.weight = nn.Parameter(torch.ones(normalized_shape))
|
140 |
+
self.bias = nn.Parameter(torch.zeros(normalized_shape))
|
141 |
+
self.eps = eps
|
142 |
+
self.data_format = data_format
|
143 |
+
if self.data_format not in ["channels_last", "channels_first"]:
|
144 |
+
raise NotImplementedError
|
145 |
+
self.normalized_shape = (normalized_shape, )
|
146 |
+
|
147 |
+
def forward(self, x):
|
148 |
+
if self.data_format == "channels_last": # B H W C
|
149 |
+
return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
|
150 |
+
elif self.data_format == "channels_first": # B C H W
|
151 |
+
u = x.mean(1, keepdim=True)
|
152 |
+
s = (x - u).pow(2).mean(1, keepdim=True)
|
153 |
+
x = (x - u) / torch.sqrt(s + self.eps)
|
154 |
+
x = self.weight[:, None, None] * x + self.bias[:, None, None]
|
155 |
+
return x
|
basicsr/archs/ddcolor_arch_utils/position_encoding.py
ADDED
@@ -0,0 +1,52 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
2 |
+
# Modified from: https://github.com/facebookresearch/detr/blob/master/models/position_encoding.py
|
3 |
+
"""
|
4 |
+
Various positional encodings for the transformer.
|
5 |
+
"""
|
6 |
+
import math
|
7 |
+
|
8 |
+
import torch
|
9 |
+
from torch import nn
|
10 |
+
|
11 |
+
|
12 |
+
class PositionEmbeddingSine(nn.Module):
|
13 |
+
"""
|
14 |
+
This is a more standard version of the position embedding, very similar to the one
|
15 |
+
used by the Attention is all you need paper, generalized to work on images.
|
16 |
+
"""
|
17 |
+
|
18 |
+
def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None):
|
19 |
+
super().__init__()
|
20 |
+
self.num_pos_feats = num_pos_feats
|
21 |
+
self.temperature = temperature
|
22 |
+
self.normalize = normalize
|
23 |
+
if scale is not None and normalize is False:
|
24 |
+
raise ValueError("normalize should be True if scale is passed")
|
25 |
+
if scale is None:
|
26 |
+
scale = 2 * math.pi
|
27 |
+
self.scale = scale
|
28 |
+
|
29 |
+
def forward(self, x, mask=None):
|
30 |
+
if mask is None:
|
31 |
+
mask = torch.zeros((x.size(0), x.size(2), x.size(3)), device=x.device, dtype=torch.bool)
|
32 |
+
not_mask = ~mask
|
33 |
+
y_embed = not_mask.cumsum(1, dtype=torch.float32)
|
34 |
+
x_embed = not_mask.cumsum(2, dtype=torch.float32)
|
35 |
+
if self.normalize:
|
36 |
+
eps = 1e-6
|
37 |
+
y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
|
38 |
+
x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
|
39 |
+
|
40 |
+
dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
|
41 |
+
dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
|
42 |
+
|
43 |
+
pos_x = x_embed[:, :, :, None] / dim_t
|
44 |
+
pos_y = y_embed[:, :, :, None] / dim_t
|
45 |
+
pos_x = torch.stack(
|
46 |
+
(pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4
|
47 |
+
).flatten(3)
|
48 |
+
pos_y = torch.stack(
|
49 |
+
(pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4
|
50 |
+
).flatten(3)
|
51 |
+
pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
|
52 |
+
return pos
|
basicsr/archs/ddcolor_arch_utils/transformer.py
ADDED
@@ -0,0 +1,368 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
2 |
+
# Modified from: https://github.com/facebookresearch/detr/blob/master/models/transformer.py
|
3 |
+
"""
|
4 |
+
Transformer class.
|
5 |
+
Copy-paste from torch.nn.Transformer with modifications:
|
6 |
+
* positional encodings are passed in MHattention
|
7 |
+
* extra LN at the end of encoder is removed
|
8 |
+
* decoder returns a stack of activations from all decoding layers
|
9 |
+
"""
|
10 |
+
import copy
|
11 |
+
from typing import List, Optional
|
12 |
+
|
13 |
+
import torch
|
14 |
+
import torch.nn.functional as F
|
15 |
+
from torch import Tensor, nn
|
16 |
+
|
17 |
+
|
18 |
+
class Transformer(nn.Module):
|
19 |
+
def __init__(
|
20 |
+
self,
|
21 |
+
d_model=512,
|
22 |
+
nhead=8,
|
23 |
+
num_encoder_layers=6,
|
24 |
+
num_decoder_layers=6,
|
25 |
+
dim_feedforward=2048,
|
26 |
+
dropout=0.1,
|
27 |
+
activation="relu",
|
28 |
+
normalize_before=False,
|
29 |
+
return_intermediate_dec=False,
|
30 |
+
):
|
31 |
+
super().__init__()
|
32 |
+
|
33 |
+
encoder_layer = TransformerEncoderLayer(
|
34 |
+
d_model, nhead, dim_feedforward, dropout, activation, normalize_before
|
35 |
+
)
|
36 |
+
encoder_norm = nn.LayerNorm(d_model) if normalize_before else None
|
37 |
+
self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm)
|
38 |
+
|
39 |
+
decoder_layer = TransformerDecoderLayer(
|
40 |
+
d_model, nhead, dim_feedforward, dropout, activation, normalize_before
|
41 |
+
)
|
42 |
+
decoder_norm = nn.LayerNorm(d_model)
|
43 |
+
self.decoder = TransformerDecoder(
|
44 |
+
decoder_layer,
|
45 |
+
num_decoder_layers,
|
46 |
+
decoder_norm,
|
47 |
+
return_intermediate=return_intermediate_dec,
|
48 |
+
)
|
49 |
+
|
50 |
+
self._reset_parameters()
|
51 |
+
|
52 |
+
self.d_model = d_model
|
53 |
+
self.nhead = nhead
|
54 |
+
|
55 |
+
def _reset_parameters(self):
|
56 |
+
for p in self.parameters():
|
57 |
+
if p.dim() > 1:
|
58 |
+
nn.init.xavier_uniform_(p)
|
59 |
+
|
60 |
+
def forward(self, src, mask, query_embed, pos_embed):
|
61 |
+
# flatten NxCxHxW to HWxNxC
|
62 |
+
bs, c, h, w = src.shape
|
63 |
+
src = src.flatten(2).permute(2, 0, 1)
|
64 |
+
pos_embed = pos_embed.flatten(2).permute(2, 0, 1)
|
65 |
+
query_embed = query_embed.unsqueeze(1).repeat(1, bs, 1)
|
66 |
+
if mask is not None:
|
67 |
+
mask = mask.flatten(1)
|
68 |
+
|
69 |
+
tgt = torch.zeros_like(query_embed)
|
70 |
+
memory = self.encoder(src, src_key_padding_mask=mask, pos=pos_embed)
|
71 |
+
hs = self.decoder(
|
72 |
+
tgt, memory, memory_key_padding_mask=mask, pos=pos_embed, query_pos=query_embed
|
73 |
+
)
|
74 |
+
return hs.transpose(1, 2), memory.permute(1, 2, 0).view(bs, c, h, w)
|
75 |
+
|
76 |
+
|
77 |
+
class TransformerEncoder(nn.Module):
|
78 |
+
def __init__(self, encoder_layer, num_layers, norm=None):
|
79 |
+
super().__init__()
|
80 |
+
self.layers = _get_clones(encoder_layer, num_layers)
|
81 |
+
self.num_layers = num_layers
|
82 |
+
self.norm = norm
|
83 |
+
|
84 |
+
def forward(
|
85 |
+
self,
|
86 |
+
src,
|
87 |
+
mask: Optional[Tensor] = None,
|
88 |
+
src_key_padding_mask: Optional[Tensor] = None,
|
89 |
+
pos: Optional[Tensor] = None,
|
90 |
+
):
|
91 |
+
output = src
|
92 |
+
|
93 |
+
for layer in self.layers:
|
94 |
+
output = layer(
|
95 |
+
output, src_mask=mask, src_key_padding_mask=src_key_padding_mask, pos=pos
|
96 |
+
)
|
97 |
+
|
98 |
+
if self.norm is not None:
|
99 |
+
output = self.norm(output)
|
100 |
+
|
101 |
+
return output
|
102 |
+
|
103 |
+
|
104 |
+
class TransformerDecoder(nn.Module):
|
105 |
+
def __init__(self, decoder_layer, num_layers, norm=None, return_intermediate=False):
|
106 |
+
super().__init__()
|
107 |
+
self.layers = _get_clones(decoder_layer, num_layers)
|
108 |
+
self.num_layers = num_layers
|
109 |
+
self.norm = norm
|
110 |
+
self.return_intermediate = return_intermediate
|
111 |
+
|
112 |
+
def forward(
|
113 |
+
self,
|
114 |
+
tgt,
|
115 |
+
memory,
|
116 |
+
tgt_mask: Optional[Tensor] = None,
|
117 |
+
memory_mask: Optional[Tensor] = None,
|
118 |
+
tgt_key_padding_mask: Optional[Tensor] = None,
|
119 |
+
memory_key_padding_mask: Optional[Tensor] = None,
|
120 |
+
pos: Optional[Tensor] = None,
|
121 |
+
query_pos: Optional[Tensor] = None,
|
122 |
+
):
|
123 |
+
output = tgt
|
124 |
+
|
125 |
+
intermediate = []
|
126 |
+
|
127 |
+
for layer in self.layers:
|
128 |
+
output = layer(
|
129 |
+
output,
|
130 |
+
memory,
|
131 |
+
tgt_mask=tgt_mask,
|
132 |
+
memory_mask=memory_mask,
|
133 |
+
tgt_key_padding_mask=tgt_key_padding_mask,
|
134 |
+
memory_key_padding_mask=memory_key_padding_mask,
|
135 |
+
pos=pos,
|
136 |
+
query_pos=query_pos,
|
137 |
+
)
|
138 |
+
if self.return_intermediate:
|
139 |
+
intermediate.append(self.norm(output))
|
140 |
+
|
141 |
+
if self.norm is not None:
|
142 |
+
output = self.norm(output)
|
143 |
+
if self.return_intermediate:
|
144 |
+
intermediate.pop()
|
145 |
+
intermediate.append(output)
|
146 |
+
|
147 |
+
if self.return_intermediate:
|
148 |
+
return torch.stack(intermediate)
|
149 |
+
|
150 |
+
return output.unsqueeze(0)
|
151 |
+
|
152 |
+
|
153 |
+
class TransformerEncoderLayer(nn.Module):
|
154 |
+
def __init__(
|
155 |
+
self,
|
156 |
+
d_model,
|
157 |
+
nhead,
|
158 |
+
dim_feedforward=2048,
|
159 |
+
dropout=0.1,
|
160 |
+
activation="relu",
|
161 |
+
normalize_before=False,
|
162 |
+
):
|
163 |
+
super().__init__()
|
164 |
+
self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
|
165 |
+
# Implementation of Feedforward model
|
166 |
+
self.linear1 = nn.Linear(d_model, dim_feedforward)
|
167 |
+
self.dropout = nn.Dropout(dropout)
|
168 |
+
self.linear2 = nn.Linear(dim_feedforward, d_model)
|
169 |
+
|
170 |
+
self.norm1 = nn.LayerNorm(d_model)
|
171 |
+
self.norm2 = nn.LayerNorm(d_model)
|
172 |
+
self.dropout1 = nn.Dropout(dropout)
|
173 |
+
self.dropout2 = nn.Dropout(dropout)
|
174 |
+
|
175 |
+
self.activation = _get_activation_fn(activation)
|
176 |
+
self.normalize_before = normalize_before
|
177 |
+
|
178 |
+
def with_pos_embed(self, tensor, pos: Optional[Tensor]):
|
179 |
+
return tensor if pos is None else tensor + pos
|
180 |
+
|
181 |
+
def forward_post(
|
182 |
+
self,
|
183 |
+
src,
|
184 |
+
src_mask: Optional[Tensor] = None,
|
185 |
+
src_key_padding_mask: Optional[Tensor] = None,
|
186 |
+
pos: Optional[Tensor] = None,
|
187 |
+
):
|
188 |
+
q = k = self.with_pos_embed(src, pos)
|
189 |
+
src2 = self.self_attn(
|
190 |
+
q, k, value=src, attn_mask=src_mask, key_padding_mask=src_key_padding_mask
|
191 |
+
)[0]
|
192 |
+
src = src + self.dropout1(src2)
|
193 |
+
src = self.norm1(src)
|
194 |
+
src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
|
195 |
+
src = src + self.dropout2(src2)
|
196 |
+
src = self.norm2(src)
|
197 |
+
return src
|
198 |
+
|
199 |
+
def forward_pre(
|
200 |
+
self,
|
201 |
+
src,
|
202 |
+
src_mask: Optional[Tensor] = None,
|
203 |
+
src_key_padding_mask: Optional[Tensor] = None,
|
204 |
+
pos: Optional[Tensor] = None,
|
205 |
+
):
|
206 |
+
src2 = self.norm1(src)
|
207 |
+
q = k = self.with_pos_embed(src2, pos)
|
208 |
+
src2 = self.self_attn(
|
209 |
+
q, k, value=src2, attn_mask=src_mask, key_padding_mask=src_key_padding_mask
|
210 |
+
)[0]
|
211 |
+
src = src + self.dropout1(src2)
|
212 |
+
src2 = self.norm2(src)
|
213 |
+
src2 = self.linear2(self.dropout(self.activation(self.linear1(src2))))
|
214 |
+
src = src + self.dropout2(src2)
|
215 |
+
return src
|
216 |
+
|
217 |
+
def forward(
|
218 |
+
self,
|
219 |
+
src,
|
220 |
+
src_mask: Optional[Tensor] = None,
|
221 |
+
src_key_padding_mask: Optional[Tensor] = None,
|
222 |
+
pos: Optional[Tensor] = None,
|
223 |
+
):
|
224 |
+
if self.normalize_before:
|
225 |
+
return self.forward_pre(src, src_mask, src_key_padding_mask, pos)
|
226 |
+
return self.forward_post(src, src_mask, src_key_padding_mask, pos)
|
227 |
+
|
228 |
+
|
229 |
+
class TransformerDecoderLayer(nn.Module):
|
230 |
+
def __init__(
|
231 |
+
self,
|
232 |
+
d_model,
|
233 |
+
nhead,
|
234 |
+
dim_feedforward=2048,
|
235 |
+
dropout=0.1,
|
236 |
+
activation="relu",
|
237 |
+
normalize_before=False,
|
238 |
+
):
|
239 |
+
super().__init__()
|
240 |
+
self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
|
241 |
+
self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
|
242 |
+
# Implementation of Feedforward model
|
243 |
+
self.linear1 = nn.Linear(d_model, dim_feedforward)
|
244 |
+
self.dropout = nn.Dropout(dropout)
|
245 |
+
self.linear2 = nn.Linear(dim_feedforward, d_model)
|
246 |
+
|
247 |
+
self.norm1 = nn.LayerNorm(d_model)
|
248 |
+
self.norm2 = nn.LayerNorm(d_model)
|
249 |
+
self.norm3 = nn.LayerNorm(d_model)
|
250 |
+
self.dropout1 = nn.Dropout(dropout)
|
251 |
+
self.dropout2 = nn.Dropout(dropout)
|
252 |
+
self.dropout3 = nn.Dropout(dropout)
|
253 |
+
|
254 |
+
self.activation = _get_activation_fn(activation)
|
255 |
+
self.normalize_before = normalize_before
|
256 |
+
|
257 |
+
def with_pos_embed(self, tensor, pos: Optional[Tensor]):
|
258 |
+
return tensor if pos is None else tensor + pos
|
259 |
+
|
260 |
+
def forward_post(
|
261 |
+
self,
|
262 |
+
tgt,
|
263 |
+
memory,
|
264 |
+
tgt_mask: Optional[Tensor] = None,
|
265 |
+
memory_mask: Optional[Tensor] = None,
|
266 |
+
tgt_key_padding_mask: Optional[Tensor] = None,
|
267 |
+
memory_key_padding_mask: Optional[Tensor] = None,
|
268 |
+
pos: Optional[Tensor] = None,
|
269 |
+
query_pos: Optional[Tensor] = None,
|
270 |
+
):
|
271 |
+
q = k = self.with_pos_embed(tgt, query_pos)
|
272 |
+
tgt2 = self.self_attn(
|
273 |
+
q, k, value=tgt, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask
|
274 |
+
)[0]
|
275 |
+
tgt = tgt + self.dropout1(tgt2)
|
276 |
+
tgt = self.norm1(tgt)
|
277 |
+
tgt2 = self.multihead_attn(
|
278 |
+
query=self.with_pos_embed(tgt, query_pos),
|
279 |
+
key=self.with_pos_embed(memory, pos),
|
280 |
+
value=memory,
|
281 |
+
attn_mask=memory_mask,
|
282 |
+
key_padding_mask=memory_key_padding_mask,
|
283 |
+
)[0]
|
284 |
+
tgt = tgt + self.dropout2(tgt2)
|
285 |
+
tgt = self.norm2(tgt)
|
286 |
+
tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
|
287 |
+
tgt = tgt + self.dropout3(tgt2)
|
288 |
+
tgt = self.norm3(tgt)
|
289 |
+
return tgt
|
290 |
+
|
291 |
+
def forward_pre(
|
292 |
+
self,
|
293 |
+
tgt,
|
294 |
+
memory,
|
295 |
+
tgt_mask: Optional[Tensor] = None,
|
296 |
+
memory_mask: Optional[Tensor] = None,
|
297 |
+
tgt_key_padding_mask: Optional[Tensor] = None,
|
298 |
+
memory_key_padding_mask: Optional[Tensor] = None,
|
299 |
+
pos: Optional[Tensor] = None,
|
300 |
+
query_pos: Optional[Tensor] = None,
|
301 |
+
):
|
302 |
+
tgt2 = self.norm1(tgt)
|
303 |
+
q = k = self.with_pos_embed(tgt2, query_pos)
|
304 |
+
tgt2 = self.self_attn(
|
305 |
+
q, k, value=tgt2, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask
|
306 |
+
)[0]
|
307 |
+
tgt = tgt + self.dropout1(tgt2)
|
308 |
+
tgt2 = self.norm2(tgt)
|
309 |
+
tgt2 = self.multihead_attn(
|
310 |
+
query=self.with_pos_embed(tgt2, query_pos),
|
311 |
+
key=self.with_pos_embed(memory, pos),
|
312 |
+
value=memory,
|
313 |
+
attn_mask=memory_mask,
|
314 |
+
key_padding_mask=memory_key_padding_mask,
|
315 |
+
)[0]
|
316 |
+
tgt = tgt + self.dropout2(tgt2)
|
317 |
+
tgt2 = self.norm3(tgt)
|
318 |
+
tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
|
319 |
+
tgt = tgt + self.dropout3(tgt2)
|
320 |
+
return tgt
|
321 |
+
|
322 |
+
def forward(
|
323 |
+
self,
|
324 |
+
tgt,
|
325 |
+
memory,
|
326 |
+
tgt_mask: Optional[Tensor] = None,
|
327 |
+
memory_mask: Optional[Tensor] = None,
|
328 |
+
tgt_key_padding_mask: Optional[Tensor] = None,
|
329 |
+
memory_key_padding_mask: Optional[Tensor] = None,
|
330 |
+
pos: Optional[Tensor] = None,
|
331 |
+
query_pos: Optional[Tensor] = None,
|
332 |
+
):
|
333 |
+
if self.normalize_before:
|
334 |
+
return self.forward_pre(
|
335 |
+
tgt,
|
336 |
+
memory,
|
337 |
+
tgt_mask,
|
338 |
+
memory_mask,
|
339 |
+
tgt_key_padding_mask,
|
340 |
+
memory_key_padding_mask,
|
341 |
+
pos,
|
342 |
+
query_pos,
|
343 |
+
)
|
344 |
+
return self.forward_post(
|
345 |
+
tgt,
|
346 |
+
memory,
|
347 |
+
tgt_mask,
|
348 |
+
memory_mask,
|
349 |
+
tgt_key_padding_mask,
|
350 |
+
memory_key_padding_mask,
|
351 |
+
pos,
|
352 |
+
query_pos,
|
353 |
+
)
|
354 |
+
|
355 |
+
|
356 |
+
def _get_clones(module, N):
|
357 |
+
return nn.ModuleList([copy.deepcopy(module) for i in range(N)])
|
358 |
+
|
359 |
+
|
360 |
+
def _get_activation_fn(activation):
|
361 |
+
"""Return an activation function given a string"""
|
362 |
+
if activation == "relu":
|
363 |
+
return F.relu
|
364 |
+
if activation == "gelu":
|
365 |
+
return F.gelu
|
366 |
+
if activation == "glu":
|
367 |
+
return F.glu
|
368 |
+
raise RuntimeError(f"activation should be relu/gelu, not {activation}.")
|
basicsr/archs/ddcolor_arch_utils/transformer_utils.py
ADDED
@@ -0,0 +1,192 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Optional
|
2 |
+
from torch import nn, Tensor
|
3 |
+
from torch.nn import functional as F
|
4 |
+
|
5 |
+
class SelfAttentionLayer(nn.Module):
|
6 |
+
|
7 |
+
def __init__(self, d_model, nhead, dropout=0.0,
|
8 |
+
activation="relu", normalize_before=False):
|
9 |
+
super().__init__()
|
10 |
+
self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
|
11 |
+
|
12 |
+
self.norm = nn.LayerNorm(d_model)
|
13 |
+
self.dropout = nn.Dropout(dropout)
|
14 |
+
|
15 |
+
self.activation = _get_activation_fn(activation)
|
16 |
+
self.normalize_before = normalize_before
|
17 |
+
|
18 |
+
self._reset_parameters()
|
19 |
+
|
20 |
+
def _reset_parameters(self):
|
21 |
+
for p in self.parameters():
|
22 |
+
if p.dim() > 1:
|
23 |
+
nn.init.xavier_uniform_(p)
|
24 |
+
|
25 |
+
def with_pos_embed(self, tensor, pos: Optional[Tensor]):
|
26 |
+
return tensor if pos is None else tensor + pos
|
27 |
+
|
28 |
+
def forward_post(self, tgt,
|
29 |
+
tgt_mask: Optional[Tensor] = None,
|
30 |
+
tgt_key_padding_mask: Optional[Tensor] = None,
|
31 |
+
query_pos: Optional[Tensor] = None):
|
32 |
+
q = k = self.with_pos_embed(tgt, query_pos)
|
33 |
+
tgt2 = self.self_attn(q, k, value=tgt, attn_mask=tgt_mask,
|
34 |
+
key_padding_mask=tgt_key_padding_mask)[0]
|
35 |
+
tgt = tgt + self.dropout(tgt2)
|
36 |
+
tgt = self.norm(tgt)
|
37 |
+
|
38 |
+
return tgt
|
39 |
+
|
40 |
+
def forward_pre(self, tgt,
|
41 |
+
tgt_mask: Optional[Tensor] = None,
|
42 |
+
tgt_key_padding_mask: Optional[Tensor] = None,
|
43 |
+
query_pos: Optional[Tensor] = None):
|
44 |
+
tgt2 = self.norm(tgt)
|
45 |
+
q = k = self.with_pos_embed(tgt2, query_pos)
|
46 |
+
tgt2 = self.self_attn(q, k, value=tgt2, attn_mask=tgt_mask,
|
47 |
+
key_padding_mask=tgt_key_padding_mask)[0]
|
48 |
+
tgt = tgt + self.dropout(tgt2)
|
49 |
+
|
50 |
+
return tgt
|
51 |
+
|
52 |
+
def forward(self, tgt,
|
53 |
+
tgt_mask: Optional[Tensor] = None,
|
54 |
+
tgt_key_padding_mask: Optional[Tensor] = None,
|
55 |
+
query_pos: Optional[Tensor] = None):
|
56 |
+
if self.normalize_before:
|
57 |
+
return self.forward_pre(tgt, tgt_mask,
|
58 |
+
tgt_key_padding_mask, query_pos)
|
59 |
+
return self.forward_post(tgt, tgt_mask,
|
60 |
+
tgt_key_padding_mask, query_pos)
|
61 |
+
|
62 |
+
|
63 |
+
class CrossAttentionLayer(nn.Module):
|
64 |
+
|
65 |
+
def __init__(self, d_model, nhead, dropout=0.0,
|
66 |
+
activation="relu", normalize_before=False):
|
67 |
+
super().__init__()
|
68 |
+
self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
|
69 |
+
|
70 |
+
self.norm = nn.LayerNorm(d_model)
|
71 |
+
self.dropout = nn.Dropout(dropout)
|
72 |
+
|
73 |
+
self.activation = _get_activation_fn(activation)
|
74 |
+
self.normalize_before = normalize_before
|
75 |
+
|
76 |
+
self._reset_parameters()
|
77 |
+
|
78 |
+
def _reset_parameters(self):
|
79 |
+
for p in self.parameters():
|
80 |
+
if p.dim() > 1:
|
81 |
+
nn.init.xavier_uniform_(p)
|
82 |
+
|
83 |
+
def with_pos_embed(self, tensor, pos: Optional[Tensor]):
|
84 |
+
return tensor if pos is None else tensor + pos
|
85 |
+
|
86 |
+
def forward_post(self, tgt, memory,
|
87 |
+
memory_mask: Optional[Tensor] = None,
|
88 |
+
memory_key_padding_mask: Optional[Tensor] = None,
|
89 |
+
pos: Optional[Tensor] = None,
|
90 |
+
query_pos: Optional[Tensor] = None):
|
91 |
+
tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt, query_pos),
|
92 |
+
key=self.with_pos_embed(memory, pos),
|
93 |
+
value=memory, attn_mask=memory_mask,
|
94 |
+
key_padding_mask=memory_key_padding_mask)[0]
|
95 |
+
tgt = tgt + self.dropout(tgt2)
|
96 |
+
tgt = self.norm(tgt)
|
97 |
+
|
98 |
+
return tgt
|
99 |
+
|
100 |
+
def forward_pre(self, tgt, memory,
|
101 |
+
memory_mask: Optional[Tensor] = None,
|
102 |
+
memory_key_padding_mask: Optional[Tensor] = None,
|
103 |
+
pos: Optional[Tensor] = None,
|
104 |
+
query_pos: Optional[Tensor] = None):
|
105 |
+
tgt2 = self.norm(tgt)
|
106 |
+
tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt2, query_pos),
|
107 |
+
key=self.with_pos_embed(memory, pos),
|
108 |
+
value=memory, attn_mask=memory_mask,
|
109 |
+
key_padding_mask=memory_key_padding_mask)[0]
|
110 |
+
tgt = tgt + self.dropout(tgt2)
|
111 |
+
|
112 |
+
return tgt
|
113 |
+
|
114 |
+
def forward(self, tgt, memory,
|
115 |
+
memory_mask: Optional[Tensor] = None,
|
116 |
+
memory_key_padding_mask: Optional[Tensor] = None,
|
117 |
+
pos: Optional[Tensor] = None,
|
118 |
+
query_pos: Optional[Tensor] = None):
|
119 |
+
if self.normalize_before:
|
120 |
+
return self.forward_pre(tgt, memory, memory_mask,
|
121 |
+
memory_key_padding_mask, pos, query_pos)
|
122 |
+
return self.forward_post(tgt, memory, memory_mask,
|
123 |
+
memory_key_padding_mask, pos, query_pos)
|
124 |
+
|
125 |
+
|
126 |
+
class FFNLayer(nn.Module):
|
127 |
+
|
128 |
+
def __init__(self, d_model, dim_feedforward=2048, dropout=0.0,
|
129 |
+
activation="relu", normalize_before=False):
|
130 |
+
super().__init__()
|
131 |
+
# Implementation of Feedforward model
|
132 |
+
self.linear1 = nn.Linear(d_model, dim_feedforward)
|
133 |
+
self.dropout = nn.Dropout(dropout)
|
134 |
+
self.linear2 = nn.Linear(dim_feedforward, d_model)
|
135 |
+
|
136 |
+
self.norm = nn.LayerNorm(d_model)
|
137 |
+
|
138 |
+
self.activation = _get_activation_fn(activation)
|
139 |
+
self.normalize_before = normalize_before
|
140 |
+
|
141 |
+
self._reset_parameters()
|
142 |
+
|
143 |
+
def _reset_parameters(self):
|
144 |
+
for p in self.parameters():
|
145 |
+
if p.dim() > 1:
|
146 |
+
nn.init.xavier_uniform_(p)
|
147 |
+
|
148 |
+
def with_pos_embed(self, tensor, pos: Optional[Tensor]):
|
149 |
+
return tensor if pos is None else tensor + pos
|
150 |
+
|
151 |
+
def forward_post(self, tgt):
|
152 |
+
tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
|
153 |
+
tgt = tgt + self.dropout(tgt2)
|
154 |
+
tgt = self.norm(tgt)
|
155 |
+
return tgt
|
156 |
+
|
157 |
+
def forward_pre(self, tgt):
|
158 |
+
tgt2 = self.norm(tgt)
|
159 |
+
tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
|
160 |
+
tgt = tgt + self.dropout(tgt2)
|
161 |
+
return tgt
|
162 |
+
|
163 |
+
def forward(self, tgt):
|
164 |
+
if self.normalize_before:
|
165 |
+
return self.forward_pre(tgt)
|
166 |
+
return self.forward_post(tgt)
|
167 |
+
|
168 |
+
|
169 |
+
def _get_activation_fn(activation):
|
170 |
+
"""Return an activation function given a string"""
|
171 |
+
if activation == "relu":
|
172 |
+
return F.relu
|
173 |
+
if activation == "gelu":
|
174 |
+
return F.gelu
|
175 |
+
if activation == "glu":
|
176 |
+
return F.glu
|
177 |
+
raise RuntimeError(F"activation should be relu/gelu, not {activation}.")
|
178 |
+
|
179 |
+
|
180 |
+
class MLP(nn.Module):
|
181 |
+
""" Very simple multi-layer perceptron (also called FFN)"""
|
182 |
+
|
183 |
+
def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
|
184 |
+
super().__init__()
|
185 |
+
self.num_layers = num_layers
|
186 |
+
h = [hidden_dim] * (num_layers - 1)
|
187 |
+
self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))
|
188 |
+
|
189 |
+
def forward(self, x):
|
190 |
+
for i, layer in enumerate(self.layers):
|
191 |
+
x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
|
192 |
+
return x
|
basicsr/archs/ddcolor_arch_utils/unet.py
ADDED
@@ -0,0 +1,208 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from enum import Enum
|
2 |
+
import torch
|
3 |
+
import torch.nn as nn
|
4 |
+
from torch.nn import functional as F
|
5 |
+
import collections
|
6 |
+
|
7 |
+
|
8 |
+
NormType = Enum('NormType', 'Batch BatchZero Weight Spectral')
|
9 |
+
|
10 |
+
|
11 |
+
class Hook:
|
12 |
+
feature = None
|
13 |
+
|
14 |
+
def __init__(self, module):
|
15 |
+
self.hook = module.register_forward_hook(self.hook_fn)
|
16 |
+
|
17 |
+
def hook_fn(self, module, input, output):
|
18 |
+
if isinstance(output, torch.Tensor):
|
19 |
+
self.feature = output
|
20 |
+
elif isinstance(output, collections.OrderedDict):
|
21 |
+
self.feature = output['out']
|
22 |
+
|
23 |
+
def remove(self):
|
24 |
+
self.hook.remove()
|
25 |
+
|
26 |
+
|
27 |
+
class SelfAttention(nn.Module):
|
28 |
+
"Self attention layer for nd."
|
29 |
+
|
30 |
+
def __init__(self, n_channels: int):
|
31 |
+
super().__init__()
|
32 |
+
self.query = conv1d(n_channels, n_channels // 8)
|
33 |
+
self.key = conv1d(n_channels, n_channels // 8)
|
34 |
+
self.value = conv1d(n_channels, n_channels)
|
35 |
+
self.gamma = nn.Parameter(torch.tensor([0.]))
|
36 |
+
|
37 |
+
def forward(self, x):
|
38 |
+
#Notation from https://arxiv.org/pdf/1805.08318.pdf
|
39 |
+
size = x.size()
|
40 |
+
x = x.view(*size[:2], -1)
|
41 |
+
f, g, h = self.query(x), self.key(x), self.value(x)
|
42 |
+
beta = F.softmax(torch.bmm(f.permute(0, 2, 1).contiguous(), g), dim=1)
|
43 |
+
o = self.gamma * torch.bmm(h, beta) + x
|
44 |
+
return o.view(*size).contiguous()
|
45 |
+
|
46 |
+
|
47 |
+
def batchnorm_2d(nf: int, norm_type: NormType = NormType.Batch):
|
48 |
+
"A batchnorm2d layer with `nf` features initialized depending on `norm_type`."
|
49 |
+
bn = nn.BatchNorm2d(nf)
|
50 |
+
with torch.no_grad():
|
51 |
+
bn.bias.fill_(1e-3)
|
52 |
+
bn.weight.fill_(0. if norm_type == NormType.BatchZero else 1.)
|
53 |
+
return bn
|
54 |
+
|
55 |
+
|
56 |
+
def init_default(m: nn.Module, func=nn.init.kaiming_normal_) -> None:
|
57 |
+
"Initialize `m` weights with `func` and set `bias` to 0."
|
58 |
+
if func:
|
59 |
+
if hasattr(m, 'weight'): func(m.weight)
|
60 |
+
if hasattr(m, 'bias') and hasattr(m.bias, 'data'): m.bias.data.fill_(0.)
|
61 |
+
return m
|
62 |
+
|
63 |
+
|
64 |
+
def icnr(x, scale=2, init=nn.init.kaiming_normal_):
|
65 |
+
"ICNR init of `x`, with `scale` and `init` function."
|
66 |
+
ni, nf, h, w = x.shape
|
67 |
+
ni2 = int(ni / (scale**2))
|
68 |
+
k = init(torch.zeros([ni2, nf, h, w])).transpose(0, 1)
|
69 |
+
k = k.contiguous().view(ni2, nf, -1)
|
70 |
+
k = k.repeat(1, 1, scale**2)
|
71 |
+
k = k.contiguous().view([nf, ni, h, w]).transpose(0, 1)
|
72 |
+
x.data.copy_(k)
|
73 |
+
|
74 |
+
|
75 |
+
def conv1d(ni: int, no: int, ks: int = 1, stride: int = 1, padding: int = 0, bias: bool = False):
|
76 |
+
"Create and initialize a `nn.Conv1d` layer with spectral normalization."
|
77 |
+
conv = nn.Conv1d(ni, no, ks, stride=stride, padding=padding, bias=bias)
|
78 |
+
nn.init.kaiming_normal_(conv.weight)
|
79 |
+
if bias: conv.bias.data.zero_()
|
80 |
+
return nn.utils.spectral_norm(conv)
|
81 |
+
|
82 |
+
|
83 |
+
def custom_conv_layer(
|
84 |
+
ni: int,
|
85 |
+
nf: int,
|
86 |
+
ks: int = 3,
|
87 |
+
stride: int = 1,
|
88 |
+
padding: int = None,
|
89 |
+
bias: bool = None,
|
90 |
+
is_1d: bool = False,
|
91 |
+
norm_type=NormType.Batch,
|
92 |
+
use_activ: bool = True,
|
93 |
+
transpose: bool = False,
|
94 |
+
init=nn.init.kaiming_normal_,
|
95 |
+
self_attention: bool = False,
|
96 |
+
extra_bn: bool = False,
|
97 |
+
):
|
98 |
+
"Create a sequence of convolutional (`ni` to `nf`), ReLU (if `use_activ`) and batchnorm (if `bn`) layers."
|
99 |
+
if padding is None:
|
100 |
+
padding = (ks - 1) // 2 if not transpose else 0
|
101 |
+
bn = norm_type in (NormType.Batch, NormType.BatchZero) or extra_bn == True
|
102 |
+
if bias is None:
|
103 |
+
bias = not bn
|
104 |
+
conv_func = nn.ConvTranspose2d if transpose else nn.Conv1d if is_1d else nn.Conv2d
|
105 |
+
conv = init_default(
|
106 |
+
conv_func(ni, nf, kernel_size=ks, bias=bias, stride=stride, padding=padding),
|
107 |
+
init,
|
108 |
+
)
|
109 |
+
|
110 |
+
if norm_type == NormType.Weight:
|
111 |
+
conv = nn.utils.weight_norm(conv)
|
112 |
+
elif norm_type == NormType.Spectral:
|
113 |
+
conv = nn.utils.spectral_norm(conv)
|
114 |
+
layers = [conv]
|
115 |
+
if use_activ:
|
116 |
+
layers.append(nn.ReLU(True))
|
117 |
+
if bn:
|
118 |
+
layers.append((nn.BatchNorm1d if is_1d else nn.BatchNorm2d)(nf))
|
119 |
+
if self_attention:
|
120 |
+
layers.append(SelfAttention(nf))
|
121 |
+
return nn.Sequential(*layers)
|
122 |
+
|
123 |
+
|
124 |
+
def conv_layer(ni: int,
|
125 |
+
nf: int,
|
126 |
+
ks: int = 3,
|
127 |
+
stride: int = 1,
|
128 |
+
padding: int = None,
|
129 |
+
bias: bool = None,
|
130 |
+
is_1d: bool = False,
|
131 |
+
norm_type=NormType.Batch,
|
132 |
+
use_activ: bool = True,
|
133 |
+
transpose: bool = False,
|
134 |
+
init=nn.init.kaiming_normal_,
|
135 |
+
self_attention: bool = False):
|
136 |
+
"Create a sequence of convolutional (`ni` to `nf`), ReLU (if `use_activ`) and batchnorm (if `bn`) layers."
|
137 |
+
if padding is None: padding = (ks - 1) // 2 if not transpose else 0
|
138 |
+
bn = norm_type in (NormType.Batch, NormType.BatchZero)
|
139 |
+
if bias is None: bias = not bn
|
140 |
+
conv_func = nn.ConvTranspose2d if transpose else nn.Conv1d if is_1d else nn.Conv2d
|
141 |
+
conv = init_default(conv_func(ni, nf, kernel_size=ks, bias=bias, stride=stride, padding=padding), init)
|
142 |
+
if norm_type == NormType.Weight: conv = nn.utils.weight_norm(conv)
|
143 |
+
elif norm_type == NormType.Spectral: conv = nn.utils.spectral_norm(conv)
|
144 |
+
layers = [conv]
|
145 |
+
if use_activ: layers.append(nn.ReLU(True))
|
146 |
+
if bn: layers.append((nn.BatchNorm1d if is_1d else nn.BatchNorm2d)(nf))
|
147 |
+
if self_attention: layers.append(SelfAttention(nf))
|
148 |
+
return nn.Sequential(*layers)
|
149 |
+
|
150 |
+
|
151 |
+
def _conv(ni: int, nf: int, ks: int = 3, stride: int = 1, **kwargs):
|
152 |
+
return conv_layer(ni, nf, ks=ks, stride=stride, norm_type=NormType.Spectral, **kwargs)
|
153 |
+
|
154 |
+
|
155 |
+
class CustomPixelShuffle_ICNR(nn.Module):
|
156 |
+
"Upsample by `scale` from `ni` filters to `nf` (default `ni`), using `nn.PixelShuffle`, `icnr` init, and `weight_norm`."
|
157 |
+
|
158 |
+
def __init__(self,
|
159 |
+
ni: int,
|
160 |
+
nf: int = None,
|
161 |
+
scale: int = 2,
|
162 |
+
blur: bool = True,
|
163 |
+
norm_type=NormType.Spectral,
|
164 |
+
extra_bn=False):
|
165 |
+
super().__init__()
|
166 |
+
self.conv = custom_conv_layer(
|
167 |
+
ni, nf * (scale**2), ks=1, use_activ=False, norm_type=norm_type, extra_bn=extra_bn)
|
168 |
+
icnr(self.conv[0].weight)
|
169 |
+
self.shuf = nn.PixelShuffle(scale)
|
170 |
+
self.do_blur = blur
|
171 |
+
# Blurring over (h*w) kernel
|
172 |
+
# "Super-Resolution using Convolutional Neural Networks without Any Checkerboard Artifacts"
|
173 |
+
# - https://arxiv.org/abs/1806.02658
|
174 |
+
self.pad = nn.ReplicationPad2d((1, 0, 1, 0))
|
175 |
+
self.blur = nn.AvgPool2d(2, stride=1)
|
176 |
+
self.relu = nn.ReLU(True)
|
177 |
+
|
178 |
+
def forward(self, x):
|
179 |
+
x = self.shuf(self.relu(self.conv(x)))
|
180 |
+
return self.blur(self.pad(x)) if self.do_blur else x
|
181 |
+
|
182 |
+
|
183 |
+
class UnetBlockWide(nn.Module):
|
184 |
+
"A quasi-UNet block, using `PixelShuffle_ICNR upsampling`."
|
185 |
+
|
186 |
+
def __init__(self,
|
187 |
+
up_in_c: int,
|
188 |
+
x_in_c: int,
|
189 |
+
n_out: int,
|
190 |
+
hook,
|
191 |
+
blur: bool = False,
|
192 |
+
self_attention: bool = False,
|
193 |
+
norm_type=NormType.Spectral):
|
194 |
+
super().__init__()
|
195 |
+
|
196 |
+
self.hook = hook
|
197 |
+
up_out = n_out
|
198 |
+
self.shuf = CustomPixelShuffle_ICNR(up_in_c, up_out, blur=blur, norm_type=norm_type, extra_bn=True)
|
199 |
+
self.bn = batchnorm_2d(x_in_c)
|
200 |
+
ni = up_out + x_in_c
|
201 |
+
self.conv = custom_conv_layer(ni, n_out, norm_type=norm_type, self_attention=self_attention, extra_bn=True)
|
202 |
+
self.relu = nn.ReLU()
|
203 |
+
|
204 |
+
def forward(self, up_in):
|
205 |
+
s = self.hook.feature
|
206 |
+
up_out = self.shuf(up_in)
|
207 |
+
cat_x = self.relu(torch.cat([up_out, self.bn(s)], dim=1))
|
208 |
+
return self.conv(cat_x)
|
basicsr/archs/ddcolor_arch_utils/util.py
ADDED
@@ -0,0 +1,63 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import torch
|
3 |
+
from skimage import color
|
4 |
+
|
5 |
+
|
6 |
+
def rgb2lab(img_rgb):
|
7 |
+
img_lab = color.rgb2lab(img_rgb)
|
8 |
+
return img_lab[:, :, :1], img_lab[:, :, 1:]
|
9 |
+
|
10 |
+
|
11 |
+
def tensor_lab2rgb(labs, illuminant="D65", observer="2"):
|
12 |
+
"""
|
13 |
+
Args:
|
14 |
+
lab : (B, C, H, W)
|
15 |
+
Returns:
|
16 |
+
tuple : (B, C, H, W)
|
17 |
+
"""
|
18 |
+
illuminants = \
|
19 |
+
{"A": {'2': (1.098466069456375, 1, 0.3558228003436005),
|
20 |
+
'10': (1.111420406956693, 1, 0.3519978321919493)},
|
21 |
+
"D50": {'2': (0.9642119944211994, 1, 0.8251882845188288),
|
22 |
+
'10': (0.9672062750333777, 1, 0.8142801513128616)},
|
23 |
+
"D55": {'2': (0.956797052643698, 1, 0.9214805860173273),
|
24 |
+
'10': (0.9579665682254781, 1, 0.9092525159847462)},
|
25 |
+
"D65": {'2': (0.95047, 1., 1.08883), # This was: `lab_ref_white`
|
26 |
+
'10': (0.94809667673716, 1, 1.0730513595166162)},
|
27 |
+
"D75": {'2': (0.9497220898840717, 1, 1.226393520724154),
|
28 |
+
'10': (0.9441713925645873, 1, 1.2064272211720228)},
|
29 |
+
"E": {'2': (1.0, 1.0, 1.0),
|
30 |
+
'10': (1.0, 1.0, 1.0)}}
|
31 |
+
xyz_from_rgb = np.array([[0.412453, 0.357580, 0.180423], [0.212671, 0.715160, 0.072169],
|
32 |
+
[0.019334, 0.119193, 0.950227]])
|
33 |
+
|
34 |
+
rgb_from_xyz = np.array([[3.240481340, -0.96925495, 0.055646640], [-1.53715152, 1.875990000, -0.20404134],
|
35 |
+
[-0.49853633, 0.041555930, 1.057311070]])
|
36 |
+
B, C, H, W = labs.shape
|
37 |
+
arrs = labs.permute((0, 2, 3, 1)).contiguous() # (B, 3, H, W) -> (B, H, W, 3)
|
38 |
+
L, a, b = arrs[:, :, :, 0:1], arrs[:, :, :, 1:2], arrs[:, :, :, 2:]
|
39 |
+
y = (L + 16.) / 116.
|
40 |
+
x = (a / 500.) + y
|
41 |
+
z = y - (b / 200.)
|
42 |
+
invalid = z.data < 0
|
43 |
+
z[invalid] = 0
|
44 |
+
xyz = torch.cat([x, y, z], dim=3)
|
45 |
+
mask = xyz.data > 0.2068966
|
46 |
+
mask_xyz = xyz.clone()
|
47 |
+
mask_xyz[mask] = torch.pow(xyz[mask], 3.0)
|
48 |
+
mask_xyz[~mask] = (xyz[~mask] - 16.0 / 116.) / 7.787
|
49 |
+
xyz_ref_white = illuminants[illuminant][observer]
|
50 |
+
for i in range(C):
|
51 |
+
mask_xyz[:, :, :, i] = mask_xyz[:, :, :, i] * xyz_ref_white[i]
|
52 |
+
|
53 |
+
rgb_trans = torch.mm(mask_xyz.view(-1, 3), torch.from_numpy(rgb_from_xyz).type_as(xyz)).view(B, H, W, C)
|
54 |
+
rgb = rgb_trans.permute((0, 3, 1, 2)).contiguous()
|
55 |
+
mask = rgb.data > 0.0031308
|
56 |
+
mask_rgb = rgb.clone()
|
57 |
+
mask_rgb[mask] = 1.055 * torch.pow(rgb[mask], 1 / 2.4) - 0.055
|
58 |
+
mask_rgb[~mask] = rgb[~mask] * 12.92
|
59 |
+
neg_mask = mask_rgb.data < 0
|
60 |
+
large_mask = mask_rgb.data > 1
|
61 |
+
mask_rgb[neg_mask] = 0
|
62 |
+
mask_rgb[large_mask] = 1
|
63 |
+
return mask_rgb
|
basicsr/archs/discriminator_arch.py
ADDED
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
from torchvision import models
|
4 |
+
import numpy as np
|
5 |
+
|
6 |
+
from basicsr.archs.ddcolor_arch_utils.unet import _conv
|
7 |
+
from basicsr.utils.registry import ARCH_REGISTRY
|
8 |
+
|
9 |
+
|
10 |
+
@ARCH_REGISTRY.register()
|
11 |
+
class DynamicUNetDiscriminator(nn.Module):
|
12 |
+
|
13 |
+
def __init__(self, n_channels: int = 3, nf: int = 256, n_blocks: int = 3):
|
14 |
+
super().__init__()
|
15 |
+
layers = [_conv(n_channels, nf, ks=4, stride=2)]
|
16 |
+
for i in range(n_blocks):
|
17 |
+
layers += [
|
18 |
+
_conv(nf, nf, ks=3, stride=1),
|
19 |
+
_conv(nf, nf * 2, ks=4, stride=2, self_attention=(i == 0)),
|
20 |
+
]
|
21 |
+
nf *= 2
|
22 |
+
layers += [_conv(nf, nf, ks=3, stride=1), _conv(nf, 1, ks=4, bias=False, padding=0, use_activ=False)]
|
23 |
+
self.layers = nn.Sequential(*layers)
|
24 |
+
|
25 |
+
def forward(self, x):
|
26 |
+
out = self.layers(x)
|
27 |
+
out = out.view(out.size(0), -1)
|
28 |
+
return out
|
basicsr/archs/vgg_arch.py
ADDED
@@ -0,0 +1,165 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import torch
|
3 |
+
from collections import OrderedDict
|
4 |
+
from torch import nn as nn
|
5 |
+
from torchvision.models import vgg as vgg
|
6 |
+
|
7 |
+
from basicsr.utils.registry import ARCH_REGISTRY
|
8 |
+
|
9 |
+
VGG_PRETRAIN_PATH = {
|
10 |
+
'vgg19': './pretrain/vgg19-dcbb9e9d.pth',
|
11 |
+
'vgg16_bn': './pretrain/vgg16_bn-6c64b313.pth'
|
12 |
+
}
|
13 |
+
|
14 |
+
NAMES = {
|
15 |
+
'vgg11': [
|
16 |
+
'conv1_1', 'relu1_1', 'pool1', 'conv2_1', 'relu2_1', 'pool2', 'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2',
|
17 |
+
'pool3', 'conv4_1', 'relu4_1', 'conv4_2', 'relu4_2', 'pool4', 'conv5_1', 'relu5_1', 'conv5_2', 'relu5_2',
|
18 |
+
'pool5'
|
19 |
+
],
|
20 |
+
'vgg13': [
|
21 |
+
'conv1_1', 'relu1_1', 'conv1_2', 'relu1_2', 'pool1', 'conv2_1', 'relu2_1', 'conv2_2', 'relu2_2', 'pool2',
|
22 |
+
'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2', 'pool3', 'conv4_1', 'relu4_1', 'conv4_2', 'relu4_2', 'pool4',
|
23 |
+
'conv5_1', 'relu5_1', 'conv5_2', 'relu5_2', 'pool5'
|
24 |
+
],
|
25 |
+
'vgg16': [
|
26 |
+
'conv1_1', 'relu1_1', 'conv1_2', 'relu1_2', 'pool1', 'conv2_1', 'relu2_1', 'conv2_2', 'relu2_2', 'pool2',
|
27 |
+
'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2', 'conv3_3', 'relu3_3', 'pool3', 'conv4_1', 'relu4_1', 'conv4_2',
|
28 |
+
'relu4_2', 'conv4_3', 'relu4_3', 'pool4', 'conv5_1', 'relu5_1', 'conv5_2', 'relu5_2', 'conv5_3', 'relu5_3',
|
29 |
+
'pool5'
|
30 |
+
],
|
31 |
+
'vgg19': [
|
32 |
+
'conv1_1', 'relu1_1', 'conv1_2', 'relu1_2', 'pool1', 'conv2_1', 'relu2_1', 'conv2_2', 'relu2_2', 'pool2',
|
33 |
+
'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2', 'conv3_3', 'relu3_3', 'conv3_4', 'relu3_4', 'pool3', 'conv4_1',
|
34 |
+
'relu4_1', 'conv4_2', 'relu4_2', 'conv4_3', 'relu4_3', 'conv4_4', 'relu4_4', 'pool4', 'conv5_1', 'relu5_1',
|
35 |
+
'conv5_2', 'relu5_2', 'conv5_3', 'relu5_3', 'conv5_4', 'relu5_4', 'pool5'
|
36 |
+
]
|
37 |
+
}
|
38 |
+
|
39 |
+
|
40 |
+
def insert_bn(names):
|
41 |
+
"""Insert bn layer after each conv.
|
42 |
+
|
43 |
+
Args:
|
44 |
+
names (list): The list of layer names.
|
45 |
+
|
46 |
+
Returns:
|
47 |
+
list: The list of layer names with bn layers.
|
48 |
+
"""
|
49 |
+
names_bn = []
|
50 |
+
for name in names:
|
51 |
+
names_bn.append(name)
|
52 |
+
if 'conv' in name:
|
53 |
+
position = name.replace('conv', '')
|
54 |
+
names_bn.append('bn' + position)
|
55 |
+
return names_bn
|
56 |
+
|
57 |
+
|
58 |
+
@ARCH_REGISTRY.register()
|
59 |
+
class VGGFeatureExtractor(nn.Module):
|
60 |
+
"""VGG network for feature extraction.
|
61 |
+
|
62 |
+
In this implementation, we allow users to choose whether use normalization
|
63 |
+
in the input feature and the type of vgg network. Note that the pretrained
|
64 |
+
path must fit the vgg type.
|
65 |
+
|
66 |
+
Args:
|
67 |
+
layer_name_list (list[str]): Forward function returns the corresponding
|
68 |
+
features according to the layer_name_list.
|
69 |
+
Example: {'relu1_1', 'relu2_1', 'relu3_1'}.
|
70 |
+
vgg_type (str): Set the type of vgg network. Default: 'vgg19'.
|
71 |
+
use_input_norm (bool): If True, normalize the input image. Importantly,
|
72 |
+
the input feature must in the range [0, 1]. Default: True.
|
73 |
+
range_norm (bool): If True, norm images with range [-1, 1] to [0, 1].
|
74 |
+
Default: False.
|
75 |
+
requires_grad (bool): If true, the parameters of VGG network will be
|
76 |
+
optimized. Default: False.
|
77 |
+
remove_pooling (bool): If true, the max pooling operations in VGG net
|
78 |
+
will be removed. Default: False.
|
79 |
+
pooling_stride (int): The stride of max pooling operation. Default: 2.
|
80 |
+
"""
|
81 |
+
|
82 |
+
def __init__(self,
|
83 |
+
layer_name_list,
|
84 |
+
vgg_type='vgg19',
|
85 |
+
use_input_norm=True,
|
86 |
+
range_norm=False,
|
87 |
+
requires_grad=False,
|
88 |
+
remove_pooling=False,
|
89 |
+
pooling_stride=2):
|
90 |
+
super(VGGFeatureExtractor, self).__init__()
|
91 |
+
|
92 |
+
self.layer_name_list = layer_name_list
|
93 |
+
self.use_input_norm = use_input_norm
|
94 |
+
self.range_norm = range_norm
|
95 |
+
|
96 |
+
self.names = NAMES[vgg_type.replace('_bn', '')]
|
97 |
+
if 'bn' in vgg_type:
|
98 |
+
self.names = insert_bn(self.names)
|
99 |
+
|
100 |
+
# only borrow layers that will be used to avoid unused params
|
101 |
+
max_idx = 0
|
102 |
+
for v in layer_name_list:
|
103 |
+
idx = self.names.index(v)
|
104 |
+
if idx > max_idx:
|
105 |
+
max_idx = idx
|
106 |
+
|
107 |
+
if os.path.exists(VGG_PRETRAIN_PATH[vgg_type]):
|
108 |
+
vgg_net = getattr(vgg, vgg_type)(pretrained=False)
|
109 |
+
state_dict = torch.load(VGG_PRETRAIN_PATH[vgg_type], map_location=lambda storage, loc: storage)
|
110 |
+
vgg_net.load_state_dict(state_dict)
|
111 |
+
else:
|
112 |
+
vgg_net = getattr(vgg, vgg_type)(pretrained=True)
|
113 |
+
|
114 |
+
features = vgg_net.features[:max_idx + 1]
|
115 |
+
|
116 |
+
modified_net = OrderedDict()
|
117 |
+
for k, v in zip(self.names, features):
|
118 |
+
if 'pool' in k:
|
119 |
+
# if remove_pooling is true, pooling operation will be removed
|
120 |
+
if remove_pooling:
|
121 |
+
continue
|
122 |
+
else:
|
123 |
+
# in some cases, we may want to change the default stride
|
124 |
+
modified_net[k] = nn.MaxPool2d(kernel_size=2, stride=pooling_stride)
|
125 |
+
else:
|
126 |
+
modified_net[k] = v
|
127 |
+
|
128 |
+
self.vgg_net = nn.Sequential(modified_net)
|
129 |
+
|
130 |
+
if not requires_grad:
|
131 |
+
self.vgg_net.eval()
|
132 |
+
for param in self.parameters():
|
133 |
+
param.requires_grad = False
|
134 |
+
else:
|
135 |
+
self.vgg_net.train()
|
136 |
+
for param in self.parameters():
|
137 |
+
param.requires_grad = True
|
138 |
+
|
139 |
+
if self.use_input_norm:
|
140 |
+
# the mean is for image with range [0, 1]
|
141 |
+
self.register_buffer('mean', torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
|
142 |
+
# the std is for image with range [0, 1]
|
143 |
+
self.register_buffer('std', torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))
|
144 |
+
|
145 |
+
def forward(self, x):
|
146 |
+
"""Forward function.
|
147 |
+
|
148 |
+
Args:
|
149 |
+
x (Tensor): Input tensor with shape (n, c, h, w).
|
150 |
+
|
151 |
+
Returns:
|
152 |
+
Tensor: Forward results.
|
153 |
+
"""
|
154 |
+
if self.range_norm:
|
155 |
+
x = (x + 1) / 2
|
156 |
+
if self.use_input_norm:
|
157 |
+
x = (x - self.mean) / self.std
|
158 |
+
|
159 |
+
output = {}
|
160 |
+
for key, layer in self.vgg_net._modules.items():
|
161 |
+
x = layer(x)
|
162 |
+
if key in self.layer_name_list:
|
163 |
+
output[key] = x.clone()
|
164 |
+
|
165 |
+
return output
|
basicsr/data/__init__.py
ADDED
@@ -0,0 +1,101 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import importlib
|
2 |
+
import numpy as np
|
3 |
+
import random
|
4 |
+
import torch
|
5 |
+
import torch.utils.data
|
6 |
+
from copy import deepcopy
|
7 |
+
from functools import partial
|
8 |
+
from os import path as osp
|
9 |
+
|
10 |
+
from basicsr.data.prefetch_dataloader import PrefetchDataLoader
|
11 |
+
from basicsr.utils import get_root_logger, scandir
|
12 |
+
from basicsr.utils.dist_util import get_dist_info
|
13 |
+
from basicsr.utils.registry import DATASET_REGISTRY
|
14 |
+
|
15 |
+
__all__ = ['build_dataset', 'build_dataloader']
|
16 |
+
|
17 |
+
# automatically scan and import dataset modules for registry
|
18 |
+
# scan all the files under the data folder with '_dataset' in file names
|
19 |
+
data_folder = osp.dirname(osp.abspath(__file__))
|
20 |
+
dataset_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(data_folder) if v.endswith('_dataset.py')]
|
21 |
+
# import all the dataset modules
|
22 |
+
_dataset_modules = [importlib.import_module(f'basicsr.data.{file_name}') for file_name in dataset_filenames]
|
23 |
+
|
24 |
+
|
25 |
+
def build_dataset(dataset_opt):
|
26 |
+
"""Build dataset from options.
|
27 |
+
|
28 |
+
Args:
|
29 |
+
dataset_opt (dict): Configuration for dataset. It must contain:
|
30 |
+
name (str): Dataset name.
|
31 |
+
type (str): Dataset type.
|
32 |
+
"""
|
33 |
+
dataset_opt = deepcopy(dataset_opt)
|
34 |
+
dataset = DATASET_REGISTRY.get(dataset_opt['type'])(dataset_opt)
|
35 |
+
logger = get_root_logger()
|
36 |
+
logger.info(f'Dataset [{dataset.__class__.__name__}] - {dataset_opt["name"]} ' 'is built.')
|
37 |
+
return dataset
|
38 |
+
|
39 |
+
|
40 |
+
def build_dataloader(dataset, dataset_opt, num_gpu=1, dist=False, sampler=None, seed=None):
|
41 |
+
"""Build dataloader.
|
42 |
+
|
43 |
+
Args:
|
44 |
+
dataset (torch.utils.data.Dataset): Dataset.
|
45 |
+
dataset_opt (dict): Dataset options. It contains the following keys:
|
46 |
+
phase (str): 'train' or 'val'.
|
47 |
+
num_worker_per_gpu (int): Number of workers for each GPU.
|
48 |
+
batch_size_per_gpu (int): Training batch size for each GPU.
|
49 |
+
num_gpu (int): Number of GPUs. Used only in the train phase.
|
50 |
+
Default: 1.
|
51 |
+
dist (bool): Whether in distributed training. Used only in the train
|
52 |
+
phase. Default: False.
|
53 |
+
sampler (torch.utils.data.sampler): Data sampler. Default: None.
|
54 |
+
seed (int | None): Seed. Default: None
|
55 |
+
"""
|
56 |
+
phase = dataset_opt['phase']
|
57 |
+
rank, _ = get_dist_info()
|
58 |
+
if phase == 'train':
|
59 |
+
if dist: # distributed training
|
60 |
+
batch_size = dataset_opt['batch_size_per_gpu']
|
61 |
+
num_workers = dataset_opt['num_worker_per_gpu']
|
62 |
+
else: # non-distributed training
|
63 |
+
multiplier = 1 if num_gpu == 0 else num_gpu
|
64 |
+
batch_size = dataset_opt['batch_size_per_gpu'] * multiplier
|
65 |
+
num_workers = dataset_opt['num_worker_per_gpu'] * multiplier
|
66 |
+
dataloader_args = dict(
|
67 |
+
dataset=dataset,
|
68 |
+
batch_size=batch_size,
|
69 |
+
shuffle=False,
|
70 |
+
num_workers=num_workers,
|
71 |
+
sampler=sampler,
|
72 |
+
drop_last=True)
|
73 |
+
if sampler is None:
|
74 |
+
dataloader_args['shuffle'] = True
|
75 |
+
dataloader_args['worker_init_fn'] = partial(
|
76 |
+
worker_init_fn, num_workers=num_workers, rank=rank, seed=seed) if seed is not None else None
|
77 |
+
elif phase in ['val', 'test']: # validation
|
78 |
+
dataloader_args = dict(dataset=dataset, batch_size=1, shuffle=False, num_workers=0)
|
79 |
+
else:
|
80 |
+
raise ValueError(f'Wrong dataset phase: {phase}. ' "Supported ones are 'train', 'val' and 'test'.")
|
81 |
+
|
82 |
+
dataloader_args['pin_memory'] = dataset_opt.get('pin_memory', False)
|
83 |
+
dataloader_args['persistent_workers'] = dataset_opt.get('persistent_workers', False)
|
84 |
+
|
85 |
+
prefetch_mode = dataset_opt.get('prefetch_mode')
|
86 |
+
if prefetch_mode == 'cpu': # CPUPrefetcher
|
87 |
+
num_prefetch_queue = dataset_opt.get('num_prefetch_queue', 1)
|
88 |
+
logger = get_root_logger()
|
89 |
+
logger.info(f'Use {prefetch_mode} prefetch dataloader: num_prefetch_queue = {num_prefetch_queue}')
|
90 |
+
return PrefetchDataLoader(num_prefetch_queue=num_prefetch_queue, **dataloader_args)
|
91 |
+
else:
|
92 |
+
# prefetch_mode=None: Normal dataloader
|
93 |
+
# prefetch_mode='cuda': dataloader for CUDAPrefetcher
|
94 |
+
return torch.utils.data.DataLoader(**dataloader_args)
|
95 |
+
|
96 |
+
|
97 |
+
def worker_init_fn(worker_id, num_workers, rank, seed):
|
98 |
+
# Set the worker seed to num_workers * rank + worker_id + seed
|
99 |
+
worker_seed = num_workers * rank + worker_id + seed
|
100 |
+
np.random.seed(worker_seed)
|
101 |
+
random.seed(worker_seed)
|
basicsr/data/data_sampler.py
ADDED
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
import torch
|
3 |
+
from torch.utils.data.sampler import Sampler
|
4 |
+
|
5 |
+
|
6 |
+
class EnlargedSampler(Sampler):
|
7 |
+
"""Sampler that restricts data loading to a subset of the dataset.
|
8 |
+
|
9 |
+
Modified from torch.utils.data.distributed.DistributedSampler
|
10 |
+
Support enlarging the dataset for iteration-based training, for saving
|
11 |
+
time when restart the dataloader after each epoch
|
12 |
+
|
13 |
+
Args:
|
14 |
+
dataset (torch.utils.data.Dataset): Dataset used for sampling.
|
15 |
+
num_replicas (int | None): Number of processes participating in
|
16 |
+
the training. It is usually the world_size.
|
17 |
+
rank (int | None): Rank of the current process within num_replicas.
|
18 |
+
ratio (int): Enlarging ratio. Default: 1.
|
19 |
+
"""
|
20 |
+
|
21 |
+
def __init__(self, dataset, num_replicas, rank, ratio=1):
|
22 |
+
self.dataset = dataset
|
23 |
+
self.num_replicas = num_replicas
|
24 |
+
self.rank = rank
|
25 |
+
self.epoch = 0
|
26 |
+
self.num_samples = math.ceil(len(self.dataset) * ratio / self.num_replicas)
|
27 |
+
self.total_size = self.num_samples * self.num_replicas
|
28 |
+
|
29 |
+
def __iter__(self):
|
30 |
+
# deterministically shuffle based on epoch
|
31 |
+
g = torch.Generator()
|
32 |
+
g.manual_seed(self.epoch)
|
33 |
+
indices = torch.randperm(self.total_size, generator=g).tolist()
|
34 |
+
|
35 |
+
dataset_size = len(self.dataset)
|
36 |
+
indices = [v % dataset_size for v in indices]
|
37 |
+
|
38 |
+
# subsample
|
39 |
+
indices = indices[self.rank:self.total_size:self.num_replicas]
|
40 |
+
assert len(indices) == self.num_samples
|
41 |
+
|
42 |
+
return iter(indices)
|
43 |
+
|
44 |
+
def __len__(self):
|
45 |
+
return self.num_samples
|
46 |
+
|
47 |
+
def set_epoch(self, epoch):
|
48 |
+
self.epoch = epoch
|
basicsr/data/data_util.py
ADDED
@@ -0,0 +1,313 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import cv2
|
2 |
+
import numpy as np
|
3 |
+
import torch
|
4 |
+
from os import path as osp
|
5 |
+
from torch.nn import functional as F
|
6 |
+
|
7 |
+
from basicsr.data.transforms import mod_crop
|
8 |
+
from basicsr.utils import img2tensor, scandir
|
9 |
+
|
10 |
+
|
11 |
+
def read_img_seq(path, require_mod_crop=False, scale=1, return_imgname=False):
|
12 |
+
"""Read a sequence of images from a given folder path.
|
13 |
+
|
14 |
+
Args:
|
15 |
+
path (list[str] | str): List of image paths or image folder path.
|
16 |
+
require_mod_crop (bool): Require mod crop for each image.
|
17 |
+
Default: False.
|
18 |
+
scale (int): Scale factor for mod_crop. Default: 1.
|
19 |
+
return_imgname(bool): Whether return image names. Default False.
|
20 |
+
|
21 |
+
Returns:
|
22 |
+
Tensor: size (t, c, h, w), RGB, [0, 1].
|
23 |
+
list[str]: Returned image name list.
|
24 |
+
"""
|
25 |
+
if isinstance(path, list):
|
26 |
+
img_paths = path
|
27 |
+
else:
|
28 |
+
img_paths = sorted(list(scandir(path, full_path=True)))
|
29 |
+
imgs = [cv2.imread(v).astype(np.float32) / 255. for v in img_paths]
|
30 |
+
|
31 |
+
if require_mod_crop:
|
32 |
+
imgs = [mod_crop(img, scale) for img in imgs]
|
33 |
+
imgs = img2tensor(imgs, bgr2rgb=True, float32=True)
|
34 |
+
imgs = torch.stack(imgs, dim=0)
|
35 |
+
|
36 |
+
if return_imgname:
|
37 |
+
imgnames = [osp.splitext(osp.basename(path))[0] for path in img_paths]
|
38 |
+
return imgs, imgnames
|
39 |
+
else:
|
40 |
+
return imgs
|
41 |
+
|
42 |
+
|
43 |
+
def generate_frame_indices(crt_idx, max_frame_num, num_frames, padding='reflection'):
|
44 |
+
"""Generate an index list for reading `num_frames` frames from a sequence
|
45 |
+
of images.
|
46 |
+
|
47 |
+
Args:
|
48 |
+
crt_idx (int): Current center index.
|
49 |
+
max_frame_num (int): Max number of the sequence of images (from 1).
|
50 |
+
num_frames (int): Reading num_frames frames.
|
51 |
+
padding (str): Padding mode, one of
|
52 |
+
'replicate' | 'reflection' | 'reflection_circle' | 'circle'
|
53 |
+
Examples: current_idx = 0, num_frames = 5
|
54 |
+
The generated frame indices under different padding mode:
|
55 |
+
replicate: [0, 0, 0, 1, 2]
|
56 |
+
reflection: [2, 1, 0, 1, 2]
|
57 |
+
reflection_circle: [4, 3, 0, 1, 2]
|
58 |
+
circle: [3, 4, 0, 1, 2]
|
59 |
+
|
60 |
+
Returns:
|
61 |
+
list[int]: A list of indices.
|
62 |
+
"""
|
63 |
+
assert num_frames % 2 == 1, 'num_frames should be an odd number.'
|
64 |
+
assert padding in ('replicate', 'reflection', 'reflection_circle', 'circle'), f'Wrong padding mode: {padding}.'
|
65 |
+
|
66 |
+
max_frame_num = max_frame_num - 1 # start from 0
|
67 |
+
num_pad = num_frames // 2
|
68 |
+
|
69 |
+
indices = []
|
70 |
+
for i in range(crt_idx - num_pad, crt_idx + num_pad + 1):
|
71 |
+
if i < 0:
|
72 |
+
if padding == 'replicate':
|
73 |
+
pad_idx = 0
|
74 |
+
elif padding == 'reflection':
|
75 |
+
pad_idx = -i
|
76 |
+
elif padding == 'reflection_circle':
|
77 |
+
pad_idx = crt_idx + num_pad - i
|
78 |
+
else:
|
79 |
+
pad_idx = num_frames + i
|
80 |
+
elif i > max_frame_num:
|
81 |
+
if padding == 'replicate':
|
82 |
+
pad_idx = max_frame_num
|
83 |
+
elif padding == 'reflection':
|
84 |
+
pad_idx = max_frame_num * 2 - i
|
85 |
+
elif padding == 'reflection_circle':
|
86 |
+
pad_idx = (crt_idx - num_pad) - (i - max_frame_num)
|
87 |
+
else:
|
88 |
+
pad_idx = i - num_frames
|
89 |
+
else:
|
90 |
+
pad_idx = i
|
91 |
+
indices.append(pad_idx)
|
92 |
+
return indices
|
93 |
+
|
94 |
+
|
95 |
+
def paired_paths_from_lmdb(folders, keys):
|
96 |
+
"""Generate paired paths from lmdb files.
|
97 |
+
|
98 |
+
Contents of lmdb. Taking the `lq.lmdb` for example, the file structure is:
|
99 |
+
|
100 |
+
lq.lmdb
|
101 |
+
├── data.mdb
|
102 |
+
├── lock.mdb
|
103 |
+
├── meta_info.txt
|
104 |
+
|
105 |
+
The data.mdb and lock.mdb are standard lmdb files and you can refer to
|
106 |
+
https://lmdb.readthedocs.io/en/release/ for more details.
|
107 |
+
|
108 |
+
The meta_info.txt is a specified txt file to record the meta information
|
109 |
+
of our datasets. It will be automatically created when preparing
|
110 |
+
datasets by our provided dataset tools.
|
111 |
+
Each line in the txt file records
|
112 |
+
1)image name (with extension),
|
113 |
+
2)image shape,
|
114 |
+
3)compression level, separated by a white space.
|
115 |
+
Example: `baboon.png (120,125,3) 1`
|
116 |
+
|
117 |
+
We use the image name without extension as the lmdb key.
|
118 |
+
Note that we use the same key for the corresponding lq and gt images.
|
119 |
+
|
120 |
+
Args:
|
121 |
+
folders (list[str]): A list of folder path. The order of list should
|
122 |
+
be [input_folder, gt_folder].
|
123 |
+
keys (list[str]): A list of keys identifying folders. The order should
|
124 |
+
be in consistent with folders, e.g., ['lq', 'gt'].
|
125 |
+
Note that this key is different from lmdb keys.
|
126 |
+
|
127 |
+
Returns:
|
128 |
+
list[str]: Returned path list.
|
129 |
+
"""
|
130 |
+
assert len(folders) == 2, ('The len of folders should be 2 with [input_folder, gt_folder]. '
|
131 |
+
f'But got {len(folders)}')
|
132 |
+
assert len(keys) == 2, f'The len of keys should be 2 with [input_key, gt_key]. But got {len(keys)}'
|
133 |
+
input_folder, gt_folder = folders
|
134 |
+
input_key, gt_key = keys
|
135 |
+
|
136 |
+
if not (input_folder.endswith('.lmdb') and gt_folder.endswith('.lmdb')):
|
137 |
+
raise ValueError(f'{input_key} folder and {gt_key} folder should both in lmdb '
|
138 |
+
f'formats. But received {input_key}: {input_folder}; '
|
139 |
+
f'{gt_key}: {gt_folder}')
|
140 |
+
# ensure that the two meta_info files are the same
|
141 |
+
with open(osp.join(input_folder, 'meta_info.txt')) as fin:
|
142 |
+
input_lmdb_keys = [line.split('.')[0] for line in fin]
|
143 |
+
with open(osp.join(gt_folder, 'meta_info.txt')) as fin:
|
144 |
+
gt_lmdb_keys = [line.split('.')[0] for line in fin]
|
145 |
+
if set(input_lmdb_keys) != set(gt_lmdb_keys):
|
146 |
+
raise ValueError(f'Keys in {input_key}_folder and {gt_key}_folder are different.')
|
147 |
+
else:
|
148 |
+
paths = []
|
149 |
+
for lmdb_key in sorted(input_lmdb_keys):
|
150 |
+
paths.append(dict([(f'{input_key}_path', lmdb_key), (f'{gt_key}_path', lmdb_key)]))
|
151 |
+
return paths
|
152 |
+
|
153 |
+
|
154 |
+
def paired_paths_from_meta_info_file(folders, keys, meta_info_file, filename_tmpl):
|
155 |
+
"""Generate paired paths from an meta information file.
|
156 |
+
|
157 |
+
Each line in the meta information file contains the image names and
|
158 |
+
image shape (usually for gt), separated by a white space.
|
159 |
+
|
160 |
+
Example of an meta information file:
|
161 |
+
```
|
162 |
+
0001_s001.png (480,480,3)
|
163 |
+
0001_s002.png (480,480,3)
|
164 |
+
```
|
165 |
+
|
166 |
+
Args:
|
167 |
+
folders (list[str]): A list of folder path. The order of list should
|
168 |
+
be [input_folder, gt_folder].
|
169 |
+
keys (list[str]): A list of keys identifying folders. The order should
|
170 |
+
be in consistent with folders, e.g., ['lq', 'gt'].
|
171 |
+
meta_info_file (str): Path to the meta information file.
|
172 |
+
filename_tmpl (str): Template for each filename. Note that the
|
173 |
+
template excludes the file extension. Usually the filename_tmpl is
|
174 |
+
for files in the input folder.
|
175 |
+
|
176 |
+
Returns:
|
177 |
+
list[str]: Returned path list.
|
178 |
+
"""
|
179 |
+
assert len(folders) == 2, ('The len of folders should be 2 with [input_folder, gt_folder]. '
|
180 |
+
f'But got {len(folders)}')
|
181 |
+
assert len(keys) == 2, f'The len of keys should be 2 with [input_key, gt_key]. But got {len(keys)}'
|
182 |
+
input_folder, gt_folder = folders
|
183 |
+
input_key, gt_key = keys
|
184 |
+
|
185 |
+
with open(meta_info_file, 'r') as fin:
|
186 |
+
gt_names = [line.split(' ')[0] for line in fin]
|
187 |
+
|
188 |
+
paths = []
|
189 |
+
for gt_name in gt_names:
|
190 |
+
basename, ext = osp.splitext(osp.basename(gt_name))
|
191 |
+
input_name = f'{filename_tmpl.format(basename)}{ext}'
|
192 |
+
input_path = osp.join(input_folder, input_name)
|
193 |
+
gt_path = osp.join(gt_folder, gt_name)
|
194 |
+
paths.append(dict([(f'{input_key}_path', input_path), (f'{gt_key}_path', gt_path)]))
|
195 |
+
return paths
|
196 |
+
|
197 |
+
|
198 |
+
def paired_paths_from_folder(folders, keys, filename_tmpl):
|
199 |
+
"""Generate paired paths from folders.
|
200 |
+
|
201 |
+
Args:
|
202 |
+
folders (list[str]): A list of folder path. The order of list should
|
203 |
+
be [input_folder, gt_folder].
|
204 |
+
keys (list[str]): A list of keys identifying folders. The order should
|
205 |
+
be in consistent with folders, e.g., ['lq', 'gt'].
|
206 |
+
filename_tmpl (str): Template for each filename. Note that the
|
207 |
+
template excludes the file extension. Usually the filename_tmpl is
|
208 |
+
for files in the input folder.
|
209 |
+
|
210 |
+
Returns:
|
211 |
+
list[str]: Returned path list.
|
212 |
+
"""
|
213 |
+
assert len(folders) == 2, ('The len of folders should be 2 with [input_folder, gt_folder]. '
|
214 |
+
f'But got {len(folders)}')
|
215 |
+
assert len(keys) == 2, f'The len of keys should be 2 with [input_key, gt_key]. But got {len(keys)}'
|
216 |
+
input_folder, gt_folder = folders
|
217 |
+
input_key, gt_key = keys
|
218 |
+
|
219 |
+
input_paths = list(scandir(input_folder))
|
220 |
+
gt_paths = list(scandir(gt_folder))
|
221 |
+
assert len(input_paths) == len(gt_paths), (f'{input_key} and {gt_key} datasets have different number of images: '
|
222 |
+
f'{len(input_paths)}, {len(gt_paths)}.')
|
223 |
+
paths = []
|
224 |
+
for gt_path in gt_paths:
|
225 |
+
basename, ext = osp.splitext(osp.basename(gt_path))
|
226 |
+
input_name = f'{filename_tmpl.format(basename)}{ext}'
|
227 |
+
input_path = osp.join(input_folder, input_name)
|
228 |
+
assert input_name in input_paths, f'{input_name} is not in {input_key}_paths.'
|
229 |
+
gt_path = osp.join(gt_folder, gt_path)
|
230 |
+
paths.append(dict([(f'{input_key}_path', input_path), (f'{gt_key}_path', gt_path)]))
|
231 |
+
return paths
|
232 |
+
|
233 |
+
|
234 |
+
def paths_from_folder(folder):
|
235 |
+
"""Generate paths from folder.
|
236 |
+
|
237 |
+
Args:
|
238 |
+
folder (str): Folder path.
|
239 |
+
|
240 |
+
Returns:
|
241 |
+
list[str]: Returned path list.
|
242 |
+
"""
|
243 |
+
|
244 |
+
paths = list(scandir(folder))
|
245 |
+
paths = [osp.join(folder, path) for path in paths]
|
246 |
+
return paths
|
247 |
+
|
248 |
+
|
249 |
+
def paths_from_lmdb(folder):
|
250 |
+
"""Generate paths from lmdb.
|
251 |
+
|
252 |
+
Args:
|
253 |
+
folder (str): Folder path.
|
254 |
+
|
255 |
+
Returns:
|
256 |
+
list[str]: Returned path list.
|
257 |
+
"""
|
258 |
+
if not folder.endswith('.lmdb'):
|
259 |
+
raise ValueError(f'Folder {folder}folder should in lmdb format.')
|
260 |
+
with open(osp.join(folder, 'meta_info.txt')) as fin:
|
261 |
+
paths = [line.split('.')[0] for line in fin]
|
262 |
+
return paths
|
263 |
+
|
264 |
+
|
265 |
+
def generate_gaussian_kernel(kernel_size=13, sigma=1.6):
|
266 |
+
"""Generate Gaussian kernel used in `duf_downsample`.
|
267 |
+
|
268 |
+
Args:
|
269 |
+
kernel_size (int): Kernel size. Default: 13.
|
270 |
+
sigma (float): Sigma of the Gaussian kernel. Default: 1.6.
|
271 |
+
|
272 |
+
Returns:
|
273 |
+
np.array: The Gaussian kernel.
|
274 |
+
"""
|
275 |
+
from scipy.ndimage import filters as filters
|
276 |
+
kernel = np.zeros((kernel_size, kernel_size))
|
277 |
+
# set element at the middle to one, a dirac delta
|
278 |
+
kernel[kernel_size // 2, kernel_size // 2] = 1
|
279 |
+
# gaussian-smooth the dirac, resulting in a gaussian filter
|
280 |
+
return filters.gaussian_filter(kernel, sigma)
|
281 |
+
|
282 |
+
|
283 |
+
def duf_downsample(x, kernel_size=13, scale=4):
|
284 |
+
"""Downsamping with Gaussian kernel used in the DUF official code.
|
285 |
+
|
286 |
+
Args:
|
287 |
+
x (Tensor): Frames to be downsampled, with shape (b, t, c, h, w).
|
288 |
+
kernel_size (int): Kernel size. Default: 13.
|
289 |
+
scale (int): Downsampling factor. Supported scale: (2, 3, 4).
|
290 |
+
Default: 4.
|
291 |
+
|
292 |
+
Returns:
|
293 |
+
Tensor: DUF downsampled frames.
|
294 |
+
"""
|
295 |
+
assert scale in (2, 3, 4), f'Only support scale (2, 3, 4), but got {scale}.'
|
296 |
+
|
297 |
+
squeeze_flag = False
|
298 |
+
if x.ndim == 4:
|
299 |
+
squeeze_flag = True
|
300 |
+
x = x.unsqueeze(0)
|
301 |
+
b, t, c, h, w = x.size()
|
302 |
+
x = x.view(-1, 1, h, w)
|
303 |
+
pad_w, pad_h = kernel_size // 2 + scale * 2, kernel_size // 2 + scale * 2
|
304 |
+
x = F.pad(x, (pad_w, pad_w, pad_h, pad_h), 'reflect')
|
305 |
+
|
306 |
+
gaussian_filter = generate_gaussian_kernel(kernel_size, 0.4 * scale)
|
307 |
+
gaussian_filter = torch.from_numpy(gaussian_filter).type_as(x).unsqueeze(0).unsqueeze(0)
|
308 |
+
x = F.conv2d(x, gaussian_filter, stride=scale)
|
309 |
+
x = x[:, :, 2:-2, 2:-2]
|
310 |
+
x = x.view(b, t, c, x.size(2), x.size(3))
|
311 |
+
if squeeze_flag:
|
312 |
+
x = x.squeeze(0)
|
313 |
+
return x
|
basicsr/data/fmix.py
ADDED
@@ -0,0 +1,206 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
'''
|
2 |
+
Fmix paper from arxiv: https://arxiv.org/abs/2002.12047
|
3 |
+
Fmix code from github : https://github.com/ecs-vlc/FMix
|
4 |
+
'''
|
5 |
+
import math
|
6 |
+
import random
|
7 |
+
import numpy as np
|
8 |
+
from scipy.stats import beta
|
9 |
+
|
10 |
+
|
11 |
+
def fftfreqnd(h, w=None, z=None):
|
12 |
+
""" Get bin values for discrete fourier transform of size (h, w, z)
|
13 |
+
:param h: Required, first dimension size
|
14 |
+
:param w: Optional, second dimension size
|
15 |
+
:param z: Optional, third dimension size
|
16 |
+
"""
|
17 |
+
fz = fx = 0
|
18 |
+
fy = np.fft.fftfreq(h)
|
19 |
+
|
20 |
+
if w is not None:
|
21 |
+
fy = np.expand_dims(fy, -1)
|
22 |
+
|
23 |
+
if w % 2 == 1:
|
24 |
+
fx = np.fft.fftfreq(w)[: w // 2 + 2]
|
25 |
+
else:
|
26 |
+
fx = np.fft.fftfreq(w)[: w // 2 + 1]
|
27 |
+
|
28 |
+
if z is not None:
|
29 |
+
fy = np.expand_dims(fy, -1)
|
30 |
+
if z % 2 == 1:
|
31 |
+
fz = np.fft.fftfreq(z)[:, None]
|
32 |
+
else:
|
33 |
+
fz = np.fft.fftfreq(z)[:, None]
|
34 |
+
|
35 |
+
return np.sqrt(fx * fx + fy * fy + fz * fz)
|
36 |
+
|
37 |
+
|
38 |
+
def get_spectrum(freqs, decay_power, ch, h, w=0, z=0):
|
39 |
+
""" Samples a fourier image with given size and frequencies decayed by decay power
|
40 |
+
:param freqs: Bin values for the discrete fourier transform
|
41 |
+
:param decay_power: Decay power for frequency decay prop 1/f**d
|
42 |
+
:param ch: Number of channels for the resulting mask
|
43 |
+
:param h: Required, first dimension size
|
44 |
+
:param w: Optional, second dimension size
|
45 |
+
:param z: Optional, third dimension size
|
46 |
+
"""
|
47 |
+
scale = np.ones(1) / (np.maximum(freqs, np.array([1. / max(w, h, z)])) ** decay_power)
|
48 |
+
|
49 |
+
param_size = [ch] + list(freqs.shape) + [2]
|
50 |
+
param = np.random.randn(*param_size)
|
51 |
+
|
52 |
+
scale = np.expand_dims(scale, -1)[None, :]
|
53 |
+
|
54 |
+
return scale * param
|
55 |
+
|
56 |
+
|
57 |
+
def make_low_freq_image(decay, shape, ch=1):
|
58 |
+
""" Sample a low frequency image from fourier space
|
59 |
+
:param decay_power: Decay power for frequency decay prop 1/f**d
|
60 |
+
:param shape: Shape of desired mask, list up to 3 dims
|
61 |
+
:param ch: Number of channels for desired mask
|
62 |
+
"""
|
63 |
+
freqs = fftfreqnd(*shape)
|
64 |
+
spectrum = get_spectrum(freqs, decay, ch, *shape)#.reshape((1, *shape[:-1], -1))
|
65 |
+
spectrum = spectrum[:, 0] + 1j * spectrum[:, 1]
|
66 |
+
mask = np.real(np.fft.irfftn(spectrum, shape))
|
67 |
+
|
68 |
+
if len(shape) == 1:
|
69 |
+
mask = mask[:1, :shape[0]]
|
70 |
+
if len(shape) == 2:
|
71 |
+
mask = mask[:1, :shape[0], :shape[1]]
|
72 |
+
if len(shape) == 3:
|
73 |
+
mask = mask[:1, :shape[0], :shape[1], :shape[2]]
|
74 |
+
|
75 |
+
mask = mask
|
76 |
+
mask = (mask - mask.min())
|
77 |
+
mask = mask / mask.max()
|
78 |
+
return mask
|
79 |
+
|
80 |
+
|
81 |
+
def sample_lam(alpha, reformulate=False):
|
82 |
+
""" Sample a lambda from symmetric beta distribution with given alpha
|
83 |
+
:param alpha: Alpha value for beta distribution
|
84 |
+
:param reformulate: If True, uses the reformulation of [1].
|
85 |
+
"""
|
86 |
+
if reformulate:
|
87 |
+
lam = beta.rvs(alpha+1, alpha) # rvs(arg1,arg2,loc=期望, scale=标准差, size=生成随机数的个数) 从分布中生成指定个数的随机数
|
88 |
+
else:
|
89 |
+
lam = beta.rvs(alpha, alpha) # rvs(arg1,arg2,loc=期望, scale=标准差, size=生成随机数的个数) 从分布中生成指定个数的随机数
|
90 |
+
|
91 |
+
return lam
|
92 |
+
|
93 |
+
|
94 |
+
def binarise_mask(mask, lam, in_shape, max_soft=0.0):
|
95 |
+
""" Binarises a given low frequency image such that it has mean lambda.
|
96 |
+
:param mask: Low frequency image, usually the result of `make_low_freq_image`
|
97 |
+
:param lam: Mean value of final mask
|
98 |
+
:param in_shape: Shape of inputs
|
99 |
+
:param max_soft: Softening value between 0 and 0.5 which smooths hard edges in the mask.
|
100 |
+
:return:
|
101 |
+
"""
|
102 |
+
idx = mask.reshape(-1).argsort()[::-1]
|
103 |
+
mask = mask.reshape(-1)
|
104 |
+
num = math.ceil(lam * mask.size) if random.random() > 0.5 else math.floor(lam * mask.size)
|
105 |
+
|
106 |
+
eff_soft = max_soft
|
107 |
+
if max_soft > lam or max_soft > (1-lam):
|
108 |
+
eff_soft = min(lam, 1-lam)
|
109 |
+
|
110 |
+
soft = int(mask.size * eff_soft)
|
111 |
+
num_low = num - soft
|
112 |
+
num_high = num + soft
|
113 |
+
|
114 |
+
mask[idx[:num_high]] = 1
|
115 |
+
mask[idx[num_low:]] = 0
|
116 |
+
mask[idx[num_low:num_high]] = np.linspace(1, 0, (num_high - num_low))
|
117 |
+
|
118 |
+
mask = mask.reshape((1, *in_shape))
|
119 |
+
return mask
|
120 |
+
|
121 |
+
|
122 |
+
def sample_mask(alpha, decay_power, shape, max_soft=0.0, reformulate=False):
|
123 |
+
""" Samples a mean lambda from beta distribution parametrised by alpha, creates a low frequency image and binarises
|
124 |
+
it based on this lambda
|
125 |
+
:param alpha: Alpha value for beta distribution from which to sample mean of mask
|
126 |
+
:param decay_power: Decay power for frequency decay prop 1/f**d
|
127 |
+
:param shape: Shape of desired mask, list up to 3 dims
|
128 |
+
:param max_soft: Softening value between 0 and 0.5 which smooths hard edges in the mask.
|
129 |
+
:param reformulate: If True, uses the reformulation of [1].
|
130 |
+
"""
|
131 |
+
if isinstance(shape, int):
|
132 |
+
shape = (shape,)
|
133 |
+
|
134 |
+
# Choose lambda
|
135 |
+
lam = sample_lam(alpha, reformulate)
|
136 |
+
|
137 |
+
# Make mask, get mean / std
|
138 |
+
mask = make_low_freq_image(decay_power, shape)
|
139 |
+
mask = binarise_mask(mask, lam, shape, max_soft)
|
140 |
+
|
141 |
+
return lam, mask
|
142 |
+
|
143 |
+
|
144 |
+
def sample_and_apply(x, alpha, decay_power, shape, max_soft=0.0, reformulate=False):
|
145 |
+
"""
|
146 |
+
:param x: Image batch on which to apply fmix of shape [b, c, shape*]
|
147 |
+
:param alpha: Alpha value for beta distribution from which to sample mean of mask
|
148 |
+
:param decay_power: Decay power for frequency decay prop 1/f**d
|
149 |
+
:param shape: Shape of desired mask, list up to 3 dims
|
150 |
+
:param max_soft: Softening value between 0 and 0.5 which smooths hard edges in the mask.
|
151 |
+
:param reformulate: If True, uses the reformulation of [1].
|
152 |
+
:return: mixed input, permutation indices, lambda value of mix,
|
153 |
+
"""
|
154 |
+
lam, mask = sample_mask(alpha, decay_power, shape, max_soft, reformulate)
|
155 |
+
index = np.random.permutation(x.shape[0])
|
156 |
+
|
157 |
+
x1, x2 = x * mask, x[index] * (1-mask)
|
158 |
+
return x1+x2, index, lam
|
159 |
+
|
160 |
+
|
161 |
+
class FMixBase:
|
162 |
+
""" FMix augmentation
|
163 |
+
Args:
|
164 |
+
decay_power (float): Decay power for frequency decay prop 1/f**d
|
165 |
+
alpha (float): Alpha value for beta distribution from which to sample mean of mask
|
166 |
+
size ([int] | [int, int] | [int, int, int]): Shape of desired mask, list up to 3 dims
|
167 |
+
max_soft (float): Softening value between 0 and 0.5 which smooths hard edges in the mask.
|
168 |
+
reformulate (bool): If True, uses the reformulation of [1].
|
169 |
+
"""
|
170 |
+
|
171 |
+
def __init__(self, decay_power=3, alpha=1, size=(32, 32), max_soft=0.0, reformulate=False):
|
172 |
+
super().__init__()
|
173 |
+
self.decay_power = decay_power
|
174 |
+
self.reformulate = reformulate
|
175 |
+
self.size = size
|
176 |
+
self.alpha = alpha
|
177 |
+
self.max_soft = max_soft
|
178 |
+
self.index = None
|
179 |
+
self.lam = None
|
180 |
+
|
181 |
+
def __call__(self, x):
|
182 |
+
raise NotImplementedError
|
183 |
+
|
184 |
+
def loss(self, *args, **kwargs):
|
185 |
+
raise NotImplementedError
|
186 |
+
|
187 |
+
|
188 |
+
if __name__ == '__main__':
|
189 |
+
# para = {'alpha':1.,'decay_power':3.,'shape':(10,10),'max_soft':0.0,'reformulate':False}
|
190 |
+
# lam, mask = sample_mask(**para)
|
191 |
+
# mask = mask.transpose(1, 2, 0)
|
192 |
+
# img1 = np.zeros((10, 10, 3))
|
193 |
+
# img2 = np.ones((10, 10, 3))
|
194 |
+
# img_gt = mask * img1 + (1. - mask) * img2
|
195 |
+
# import ipdb; ipdb.set_trace()
|
196 |
+
|
197 |
+
# test
|
198 |
+
import cv2
|
199 |
+
i1 = cv2.imread('output/ILSVRC2012_val_00000001.JPEG')
|
200 |
+
i2 = cv2.imread('output/ILSVRC2012_val_00000002.JPEG')
|
201 |
+
para = {'alpha':1.,'decay_power':3.,'shape':(256, 256),'max_soft':0.0,'reformulate':False}
|
202 |
+
lam, mask = sample_mask(**para)
|
203 |
+
mask = mask.transpose(1, 2, 0)
|
204 |
+
i = mask * i1 + (1. - mask) * i2
|
205 |
+
#i = i.astype(np.uint8)
|
206 |
+
cv2.imwrite('fmix.jpg', i)
|
basicsr/data/lab_dataset.py
ADDED
@@ -0,0 +1,159 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import cv2
|
2 |
+
import random
|
3 |
+
import time
|
4 |
+
import numpy as np
|
5 |
+
import torch
|
6 |
+
from torch.utils import data as data
|
7 |
+
|
8 |
+
from basicsr.data.transforms import rgb2lab
|
9 |
+
from basicsr.utils import FileClient, get_root_logger, imfrombytes, img2tensor
|
10 |
+
from basicsr.utils.registry import DATASET_REGISTRY
|
11 |
+
from basicsr.data.fmix import sample_mask
|
12 |
+
|
13 |
+
|
14 |
+
@DATASET_REGISTRY.register()
|
15 |
+
class LabDataset(data.Dataset):
|
16 |
+
"""
|
17 |
+
Dataset used for Lab colorizaion
|
18 |
+
"""
|
19 |
+
|
20 |
+
def __init__(self, opt):
|
21 |
+
super(LabDataset, self).__init__()
|
22 |
+
self.opt = opt
|
23 |
+
# file client (io backend)
|
24 |
+
self.file_client = None
|
25 |
+
self.io_backend_opt = opt['io_backend']
|
26 |
+
self.gt_folder = opt['dataroot_gt']
|
27 |
+
|
28 |
+
meta_info_file = self.opt['meta_info_file']
|
29 |
+
assert meta_info_file is not None
|
30 |
+
if not isinstance(meta_info_file, list):
|
31 |
+
meta_info_file = [meta_info_file]
|
32 |
+
self.paths = []
|
33 |
+
for meta_info in meta_info_file:
|
34 |
+
with open(meta_info, 'r') as fin:
|
35 |
+
self.paths.extend([line.strip() for line in fin])
|
36 |
+
|
37 |
+
self.min_ab, self.max_ab = -128, 128
|
38 |
+
self.interval_ab = 4
|
39 |
+
self.ab_palette = [i for i in range(self.min_ab, self.max_ab + self.interval_ab, self.interval_ab)]
|
40 |
+
# print(self.ab_palette)
|
41 |
+
|
42 |
+
self.do_fmix = opt['do_fmix']
|
43 |
+
self.fmix_params = {'alpha':1.,'decay_power':3.,'shape':(256,256),'max_soft':0.0,'reformulate':False}
|
44 |
+
self.fmix_p = opt['fmix_p']
|
45 |
+
self.do_cutmix = opt['do_cutmix']
|
46 |
+
self.cutmix_params = {'alpha':1.}
|
47 |
+
self.cutmix_p = opt['cutmix_p']
|
48 |
+
|
49 |
+
|
50 |
+
def __getitem__(self, index):
|
51 |
+
if self.file_client is None:
|
52 |
+
self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt)
|
53 |
+
|
54 |
+
# -------------------------------- Load gt images -------------------------------- #
|
55 |
+
# Shape: (h, w, c); channel order: BGR; image range: [0, 1], float32.
|
56 |
+
gt_path = self.paths[index]
|
57 |
+
gt_size = self.opt['gt_size']
|
58 |
+
# avoid errors caused by high latency in reading files
|
59 |
+
retry = 3
|
60 |
+
while retry > 0:
|
61 |
+
try:
|
62 |
+
img_bytes = self.file_client.get(gt_path, 'gt')
|
63 |
+
except Exception as e:
|
64 |
+
logger = get_root_logger()
|
65 |
+
logger.warn(f'File client error: {e}, remaining retry times: {retry - 1}')
|
66 |
+
# change another file to read
|
67 |
+
index = random.randint(0, self.__len__())
|
68 |
+
gt_path = self.paths[index]
|
69 |
+
time.sleep(1) # sleep 1s for occasional server congestion
|
70 |
+
else:
|
71 |
+
break
|
72 |
+
finally:
|
73 |
+
retry -= 1
|
74 |
+
img_gt = imfrombytes(img_bytes, float32=True)
|
75 |
+
img_gt = cv2.resize(img_gt, (gt_size, gt_size)) # TODO: 直接resize是否是最佳方案?
|
76 |
+
|
77 |
+
# -------------------------------- (Optional) CutMix & FMix -------------------------------- #
|
78 |
+
if self.do_fmix and np.random.uniform(0., 1., size=1)[0] > self.fmix_p:
|
79 |
+
with torch.no_grad():
|
80 |
+
lam, mask = sample_mask(**self.fmix_params)
|
81 |
+
|
82 |
+
fmix_index = random.randint(0, self.__len__())
|
83 |
+
fmix_img_path = self.paths[fmix_index]
|
84 |
+
fmix_img_bytes = self.file_client.get(fmix_img_path, 'gt')
|
85 |
+
fmix_img = imfrombytes(fmix_img_bytes, float32=True)
|
86 |
+
fmix_img = cv2.resize(fmix_img, (gt_size, gt_size))
|
87 |
+
|
88 |
+
mask = mask.transpose(1, 2, 0) # (1, 256, 256) -> # (256, 256, 1)
|
89 |
+
img_gt = mask * img_gt + (1. - mask) * fmix_img
|
90 |
+
img_gt = img_gt.astype(np.float32)
|
91 |
+
|
92 |
+
if self.do_cutmix and np.random.uniform(0., 1., size=1)[0] > self.cutmix_p:
|
93 |
+
with torch.no_grad():
|
94 |
+
cmix_index = random.randint(0, self.__len__())
|
95 |
+
cmix_img_path = self.paths[cmix_index]
|
96 |
+
cmix_img_bytes = self.file_client.get(cmix_img_path, 'gt')
|
97 |
+
cmix_img = imfrombytes(cmix_img_bytes, float32=True)
|
98 |
+
cmix_img = cv2.resize(cmix_img, (gt_size, gt_size))
|
99 |
+
|
100 |
+
lam = np.clip(np.random.beta(self.cutmix_params['alpha'], self.cutmix_params['alpha']), 0.3, 0.4)
|
101 |
+
bbx1, bby1, bbx2, bby2 = rand_bbox(cmix_img.shape[:2], lam)
|
102 |
+
|
103 |
+
img_gt[:, bbx1:bbx2, bby1:bby2] = cmix_img[:, bbx1:bbx2, bby1:bby2]
|
104 |
+
|
105 |
+
|
106 |
+
# ----------------------------- Get gray lq, to tentor ----------------------------- #
|
107 |
+
# convert to gray
|
108 |
+
img_gt = cv2.cvtColor(img_gt, cv2.COLOR_BGR2RGB)
|
109 |
+
img_l, img_ab = rgb2lab(img_gt)
|
110 |
+
|
111 |
+
target_a, target_b = self.ab2int(img_ab)
|
112 |
+
|
113 |
+
# numpy to tensor
|
114 |
+
img_l, img_ab = img2tensor([img_l, img_ab], bgr2rgb=False, float32=True)
|
115 |
+
target_a, target_b = torch.LongTensor(target_a), torch.LongTensor(target_b)
|
116 |
+
return_d = {
|
117 |
+
'lq': img_l,
|
118 |
+
'gt': img_ab,
|
119 |
+
'target_a': target_a,
|
120 |
+
'target_b': target_b,
|
121 |
+
'lq_path': gt_path,
|
122 |
+
'gt_path': gt_path
|
123 |
+
}
|
124 |
+
return return_d
|
125 |
+
|
126 |
+
def ab2int(self, img_ab):
|
127 |
+
img_a, img_b = img_ab[:, :, 0], img_ab[:, :, 1]
|
128 |
+
int_a = (img_a - self.min_ab) / self.interval_ab
|
129 |
+
int_b = (img_b - self.min_ab) / self.interval_ab
|
130 |
+
|
131 |
+
return np.round(int_a), np.round(int_b)
|
132 |
+
|
133 |
+
def __len__(self):
|
134 |
+
return len(self.paths)
|
135 |
+
|
136 |
+
|
137 |
+
def rand_bbox(size, lam):
|
138 |
+
'''cutmix 的 bbox 截取函数
|
139 |
+
Args:
|
140 |
+
size : tuple 图片尺寸 e.g (256,256)
|
141 |
+
lam : float 截取比例
|
142 |
+
Returns:
|
143 |
+
bbox 的左上角和右下角坐标
|
144 |
+
int,int,int,int
|
145 |
+
'''
|
146 |
+
W = size[0] # 截取图片的宽度
|
147 |
+
H = size[1] # 截取图片的高度
|
148 |
+
cut_rat = np.sqrt(1. - lam) # 需要截取的 bbox 比例
|
149 |
+
cut_w = np.int(W * cut_rat) # 需要截取的 bbox 宽度
|
150 |
+
cut_h = np.int(H * cut_rat) # 需要截取的 bbox 高度
|
151 |
+
|
152 |
+
cx = np.random.randint(W) # 均匀分布采样,随机选择截取的 bbox 的中心点 x 坐标
|
153 |
+
cy = np.random.randint(H) # 均匀分布采样,随机选择截取的 bbox 的中心点 y 坐标
|
154 |
+
|
155 |
+
bbx1 = np.clip(cx - cut_w // 2, 0, W) # 左上角 x 坐标
|
156 |
+
bby1 = np.clip(cy - cut_h // 2, 0, H) # 左上角 y 坐标
|
157 |
+
bbx2 = np.clip(cx + cut_w // 2, 0, W) # 右下角 x 坐标
|
158 |
+
bby2 = np.clip(cy + cut_h // 2, 0, H) # 右下角 y 坐标
|
159 |
+
return bbx1, bby1, bbx2, bby2
|
basicsr/data/prefetch_dataloader.py
ADDED
@@ -0,0 +1,125 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import queue as Queue
|
2 |
+
import threading
|
3 |
+
import torch
|
4 |
+
from torch.utils.data import DataLoader
|
5 |
+
|
6 |
+
|
7 |
+
class PrefetchGenerator(threading.Thread):
|
8 |
+
"""A general prefetch generator.
|
9 |
+
|
10 |
+
Ref:
|
11 |
+
https://stackoverflow.com/questions/7323664/python-generator-pre-fetch
|
12 |
+
|
13 |
+
Args:
|
14 |
+
generator: Python generator.
|
15 |
+
num_prefetch_queue (int): Number of prefetch queue.
|
16 |
+
"""
|
17 |
+
|
18 |
+
def __init__(self, generator, num_prefetch_queue):
|
19 |
+
threading.Thread.__init__(self)
|
20 |
+
self.queue = Queue.Queue(num_prefetch_queue)
|
21 |
+
self.generator = generator
|
22 |
+
self.daemon = True
|
23 |
+
self.start()
|
24 |
+
|
25 |
+
def run(self):
|
26 |
+
for item in self.generator:
|
27 |
+
self.queue.put(item)
|
28 |
+
self.queue.put(None)
|
29 |
+
|
30 |
+
def __next__(self):
|
31 |
+
next_item = self.queue.get()
|
32 |
+
if next_item is None:
|
33 |
+
raise StopIteration
|
34 |
+
return next_item
|
35 |
+
|
36 |
+
def __iter__(self):
|
37 |
+
return self
|
38 |
+
|
39 |
+
|
40 |
+
class PrefetchDataLoader(DataLoader):
|
41 |
+
"""Prefetch version of dataloader.
|
42 |
+
|
43 |
+
Ref:
|
44 |
+
https://github.com/IgorSusmelj/pytorch-styleguide/issues/5#
|
45 |
+
|
46 |
+
TODO:
|
47 |
+
Need to test on single gpu and ddp (multi-gpu). There is a known issue in
|
48 |
+
ddp.
|
49 |
+
|
50 |
+
Args:
|
51 |
+
num_prefetch_queue (int): Number of prefetch queue.
|
52 |
+
kwargs (dict): Other arguments for dataloader.
|
53 |
+
"""
|
54 |
+
|
55 |
+
def __init__(self, num_prefetch_queue, **kwargs):
|
56 |
+
self.num_prefetch_queue = num_prefetch_queue
|
57 |
+
super(PrefetchDataLoader, self).__init__(**kwargs)
|
58 |
+
|
59 |
+
def __iter__(self):
|
60 |
+
return PrefetchGenerator(super().__iter__(), self.num_prefetch_queue)
|
61 |
+
|
62 |
+
|
63 |
+
class CPUPrefetcher():
|
64 |
+
"""CPU prefetcher.
|
65 |
+
|
66 |
+
Args:
|
67 |
+
loader: Dataloader.
|
68 |
+
"""
|
69 |
+
|
70 |
+
def __init__(self, loader):
|
71 |
+
self.ori_loader = loader
|
72 |
+
self.loader = iter(loader)
|
73 |
+
|
74 |
+
def next(self):
|
75 |
+
try:
|
76 |
+
return next(self.loader)
|
77 |
+
except StopIteration:
|
78 |
+
return None
|
79 |
+
|
80 |
+
def reset(self):
|
81 |
+
self.loader = iter(self.ori_loader)
|
82 |
+
|
83 |
+
|
84 |
+
class CUDAPrefetcher():
|
85 |
+
"""CUDA prefetcher.
|
86 |
+
|
87 |
+
Ref:
|
88 |
+
https://github.com/NVIDIA/apex/issues/304#
|
89 |
+
|
90 |
+
It may consums more GPU memory.
|
91 |
+
|
92 |
+
Args:
|
93 |
+
loader: Dataloader.
|
94 |
+
opt (dict): Options.
|
95 |
+
"""
|
96 |
+
|
97 |
+
def __init__(self, loader, opt):
|
98 |
+
self.ori_loader = loader
|
99 |
+
self.loader = iter(loader)
|
100 |
+
self.opt = opt
|
101 |
+
self.stream = torch.cuda.Stream()
|
102 |
+
self.device = torch.device('cuda' if opt['num_gpu'] != 0 else 'cpu')
|
103 |
+
self.preload()
|
104 |
+
|
105 |
+
def preload(self):
|
106 |
+
try:
|
107 |
+
self.batch = next(self.loader) # self.batch is a dict
|
108 |
+
except StopIteration:
|
109 |
+
self.batch = None
|
110 |
+
return None
|
111 |
+
# put tensors to gpu
|
112 |
+
with torch.cuda.stream(self.stream):
|
113 |
+
for k, v in self.batch.items():
|
114 |
+
if torch.is_tensor(v):
|
115 |
+
self.batch[k] = self.batch[k].to(device=self.device, non_blocking=True)
|
116 |
+
|
117 |
+
def next(self):
|
118 |
+
torch.cuda.current_stream().wait_stream(self.stream)
|
119 |
+
batch = self.batch
|
120 |
+
self.preload()
|
121 |
+
return batch
|
122 |
+
|
123 |
+
def reset(self):
|
124 |
+
self.loader = iter(self.ori_loader)
|
125 |
+
self.preload()
|
basicsr/data/transforms.py
ADDED
@@ -0,0 +1,192 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import cv2
|
2 |
+
import numpy as np
|
3 |
+
import torch
|
4 |
+
import random
|
5 |
+
from scipy import special
|
6 |
+
from skimage import color
|
7 |
+
|
8 |
+
|
9 |
+
def mod_crop(img, scale):
|
10 |
+
"""Mod crop images, used during testing.
|
11 |
+
|
12 |
+
Args:
|
13 |
+
img (ndarray): Input image.
|
14 |
+
scale (int): Scale factor.
|
15 |
+
|
16 |
+
Returns:
|
17 |
+
ndarray: Result image.
|
18 |
+
"""
|
19 |
+
img = img.copy()
|
20 |
+
if img.ndim in (2, 3):
|
21 |
+
h, w = img.shape[0], img.shape[1]
|
22 |
+
h_remainder, w_remainder = h % scale, w % scale
|
23 |
+
img = img[:h - h_remainder, :w - w_remainder, ...]
|
24 |
+
else:
|
25 |
+
raise ValueError(f'Wrong img ndim: {img.ndim}.')
|
26 |
+
return img
|
27 |
+
|
28 |
+
|
29 |
+
def paired_random_crop(img_gts, img_lqs, gt_patch_size, scale, gt_path=None):
|
30 |
+
"""Paired random crop. Support Numpy array and Tensor inputs.
|
31 |
+
|
32 |
+
It crops lists of lq and gt images with corresponding locations.
|
33 |
+
|
34 |
+
Args:
|
35 |
+
img_gts (list[ndarray] | ndarray | list[Tensor] | Tensor): GT images. Note that all images
|
36 |
+
should have the same shape. If the input is an ndarray, it will
|
37 |
+
be transformed to a list containing itself.
|
38 |
+
img_lqs (list[ndarray] | ndarray): LQ images. Note that all images
|
39 |
+
should have the same shape. If the input is an ndarray, it will
|
40 |
+
be transformed to a list containing itself.
|
41 |
+
gt_patch_size (int): GT patch size.
|
42 |
+
scale (int): Scale factor.
|
43 |
+
gt_path (str): Path to ground-truth. Default: None.
|
44 |
+
|
45 |
+
Returns:
|
46 |
+
list[ndarray] | ndarray: GT images and LQ images. If returned results
|
47 |
+
only have one element, just return ndarray.
|
48 |
+
"""
|
49 |
+
|
50 |
+
if not isinstance(img_gts, list):
|
51 |
+
img_gts = [img_gts]
|
52 |
+
if not isinstance(img_lqs, list):
|
53 |
+
img_lqs = [img_lqs]
|
54 |
+
|
55 |
+
# determine input type: Numpy array or Tensor
|
56 |
+
input_type = 'Tensor' if torch.is_tensor(img_gts[0]) else 'Numpy'
|
57 |
+
|
58 |
+
if input_type == 'Tensor':
|
59 |
+
h_lq, w_lq = img_lqs[0].size()[-2:]
|
60 |
+
h_gt, w_gt = img_gts[0].size()[-2:]
|
61 |
+
else:
|
62 |
+
h_lq, w_lq = img_lqs[0].shape[0:2]
|
63 |
+
h_gt, w_gt = img_gts[0].shape[0:2]
|
64 |
+
lq_patch_size = gt_patch_size // scale
|
65 |
+
|
66 |
+
if h_gt != h_lq * scale or w_gt != w_lq * scale:
|
67 |
+
raise ValueError(f'Scale mismatches. GT ({h_gt}, {w_gt}) is not {scale}x ',
|
68 |
+
f'multiplication of LQ ({h_lq}, {w_lq}).')
|
69 |
+
if h_lq < lq_patch_size or w_lq < lq_patch_size:
|
70 |
+
raise ValueError(f'LQ ({h_lq}, {w_lq}) is smaller than patch size '
|
71 |
+
f'({lq_patch_size}, {lq_patch_size}). '
|
72 |
+
f'Please remove {gt_path}.')
|
73 |
+
|
74 |
+
# randomly choose top and left coordinates for lq patch
|
75 |
+
top = random.randint(0, h_lq - lq_patch_size)
|
76 |
+
left = random.randint(0, w_lq - lq_patch_size)
|
77 |
+
|
78 |
+
# crop lq patch
|
79 |
+
if input_type == 'Tensor':
|
80 |
+
img_lqs = [v[:, :, top:top + lq_patch_size, left:left + lq_patch_size] for v in img_lqs]
|
81 |
+
else:
|
82 |
+
img_lqs = [v[top:top + lq_patch_size, left:left + lq_patch_size, ...] for v in img_lqs]
|
83 |
+
|
84 |
+
# crop corresponding gt patch
|
85 |
+
top_gt, left_gt = int(top * scale), int(left * scale)
|
86 |
+
if input_type == 'Tensor':
|
87 |
+
img_gts = [v[:, :, top_gt:top_gt + gt_patch_size, left_gt:left_gt + gt_patch_size] for v in img_gts]
|
88 |
+
else:
|
89 |
+
img_gts = [v[top_gt:top_gt + gt_patch_size, left_gt:left_gt + gt_patch_size, ...] for v in img_gts]
|
90 |
+
if len(img_gts) == 1:
|
91 |
+
img_gts = img_gts[0]
|
92 |
+
if len(img_lqs) == 1:
|
93 |
+
img_lqs = img_lqs[0]
|
94 |
+
return img_gts, img_lqs
|
95 |
+
|
96 |
+
|
97 |
+
def augment(imgs, hflip=True, rotation=True, flows=None, return_status=False):
|
98 |
+
"""Augment: horizontal flips OR rotate (0, 90, 180, 270 degrees).
|
99 |
+
|
100 |
+
We use vertical flip and transpose for rotation implementation.
|
101 |
+
All the images in the list use the same augmentation.
|
102 |
+
|
103 |
+
Args:
|
104 |
+
imgs (list[ndarray] | ndarray): Images to be augmented. If the input
|
105 |
+
is an ndarray, it will be transformed to a list.
|
106 |
+
hflip (bool): Horizontal flip. Default: True.
|
107 |
+
rotation (bool): Ratotation. Default: True.
|
108 |
+
flows (list[ndarray]: Flows to be augmented. If the input is an
|
109 |
+
ndarray, it will be transformed to a list.
|
110 |
+
Dimension is (h, w, 2). Default: None.
|
111 |
+
return_status (bool): Return the status of flip and rotation.
|
112 |
+
Default: False.
|
113 |
+
|
114 |
+
Returns:
|
115 |
+
list[ndarray] | ndarray: Augmented images and flows. If returned
|
116 |
+
results only have one element, just return ndarray.
|
117 |
+
|
118 |
+
"""
|
119 |
+
hflip = hflip and random.random() < 0.5
|
120 |
+
vflip = rotation and random.random() < 0.5
|
121 |
+
rot90 = rotation and random.random() < 0.5
|
122 |
+
|
123 |
+
def _augment(img):
|
124 |
+
if hflip: # horizontal
|
125 |
+
cv2.flip(img, 1, img)
|
126 |
+
if vflip: # vertical
|
127 |
+
cv2.flip(img, 0, img)
|
128 |
+
if rot90:
|
129 |
+
img = img.transpose(1, 0, 2)
|
130 |
+
return img
|
131 |
+
|
132 |
+
def _augment_flow(flow):
|
133 |
+
if hflip: # horizontal
|
134 |
+
cv2.flip(flow, 1, flow)
|
135 |
+
flow[:, :, 0] *= -1
|
136 |
+
if vflip: # vertical
|
137 |
+
cv2.flip(flow, 0, flow)
|
138 |
+
flow[:, :, 1] *= -1
|
139 |
+
if rot90:
|
140 |
+
flow = flow.transpose(1, 0, 2)
|
141 |
+
flow = flow[:, :, [1, 0]]
|
142 |
+
return flow
|
143 |
+
|
144 |
+
if not isinstance(imgs, list):
|
145 |
+
imgs = [imgs]
|
146 |
+
imgs = [_augment(img) for img in imgs]
|
147 |
+
if len(imgs) == 1:
|
148 |
+
imgs = imgs[0]
|
149 |
+
|
150 |
+
if flows is not None:
|
151 |
+
if not isinstance(flows, list):
|
152 |
+
flows = [flows]
|
153 |
+
flows = [_augment_flow(flow) for flow in flows]
|
154 |
+
if len(flows) == 1:
|
155 |
+
flows = flows[0]
|
156 |
+
return imgs, flows
|
157 |
+
else:
|
158 |
+
if return_status:
|
159 |
+
return imgs, (hflip, vflip, rot90)
|
160 |
+
else:
|
161 |
+
return imgs
|
162 |
+
|
163 |
+
|
164 |
+
def img_rotate(img, angle, center=None, scale=1.0, borderMode=cv2.BORDER_CONSTANT, borderValue=0.):
|
165 |
+
"""Rotate image.
|
166 |
+
|
167 |
+
Args:
|
168 |
+
img (ndarray): Image to be rotated.
|
169 |
+
angle (float): Rotation angle in degrees. Positive values mean
|
170 |
+
counter-clockwise rotation.
|
171 |
+
center (tuple[int]): Rotation center. If the center is None,
|
172 |
+
initialize it as the center of the image. Default: None.
|
173 |
+
scale (float): Isotropic scale factor. Default: 1.0.
|
174 |
+
"""
|
175 |
+
(h, w) = img.shape[:2]
|
176 |
+
|
177 |
+
if center is None:
|
178 |
+
center = (w // 2, h // 2)
|
179 |
+
|
180 |
+
matrix = cv2.getRotationMatrix2D(center, angle, scale)
|
181 |
+
rotated_img = cv2.warpAffine(img, matrix, (w, h), borderMode=borderMode, borderValue=borderValue)
|
182 |
+
return rotated_img
|
183 |
+
|
184 |
+
|
185 |
+
def rgb2lab(img_rgb):
|
186 |
+
img_lab = color.rgb2lab(img_rgb)
|
187 |
+
img_l = img_lab[:, :, :1]
|
188 |
+
img_ab = img_lab[:, :, 1:]
|
189 |
+
return img_l, img_ab
|
190 |
+
|
191 |
+
|
192 |
+
|
basicsr/losses/__init__.py
ADDED
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from copy import deepcopy
|
2 |
+
|
3 |
+
from basicsr.utils import get_root_logger
|
4 |
+
from basicsr.utils.registry import LOSS_REGISTRY
|
5 |
+
from .losses import (CharbonnierLoss, GANLoss, L1Loss, MSELoss, PerceptualLoss, WeightedTVLoss, g_path_regularize,
|
6 |
+
gradient_penalty_loss, r1_penalty)
|
7 |
+
|
8 |
+
__all__ = [
|
9 |
+
'L1Loss', 'MSELoss', 'CharbonnierLoss', 'WeightedTVLoss', 'PerceptualLoss', 'GANLoss', 'gradient_penalty_loss',
|
10 |
+
'r1_penalty', 'g_path_regularize'
|
11 |
+
]
|
12 |
+
|
13 |
+
|
14 |
+
def build_loss(opt):
|
15 |
+
"""Build loss from options.
|
16 |
+
|
17 |
+
Args:
|
18 |
+
opt (dict): Configuration. It must contain:
|
19 |
+
type (str): Model type.
|
20 |
+
"""
|
21 |
+
opt = deepcopy(opt)
|
22 |
+
loss_type = opt.pop('type')
|
23 |
+
loss = LOSS_REGISTRY.get(loss_type)(**opt)
|
24 |
+
logger = get_root_logger()
|
25 |
+
logger.info(f'Loss [{loss.__class__.__name__}] is created.')
|
26 |
+
return loss
|
basicsr/losses/loss_util.py
ADDED
@@ -0,0 +1,95 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import functools
|
2 |
+
from torch.nn import functional as F
|
3 |
+
|
4 |
+
|
5 |
+
def reduce_loss(loss, reduction):
|
6 |
+
"""Reduce loss as specified.
|
7 |
+
|
8 |
+
Args:
|
9 |
+
loss (Tensor): Elementwise loss tensor.
|
10 |
+
reduction (str): Options are 'none', 'mean' and 'sum'.
|
11 |
+
|
12 |
+
Returns:
|
13 |
+
Tensor: Reduced loss tensor.
|
14 |
+
"""
|
15 |
+
reduction_enum = F._Reduction.get_enum(reduction)
|
16 |
+
# none: 0, elementwise_mean:1, sum: 2
|
17 |
+
if reduction_enum == 0:
|
18 |
+
return loss
|
19 |
+
elif reduction_enum == 1:
|
20 |
+
return loss.mean()
|
21 |
+
else:
|
22 |
+
return loss.sum()
|
23 |
+
|
24 |
+
|
25 |
+
def weight_reduce_loss(loss, weight=None, reduction='mean'):
|
26 |
+
"""Apply element-wise weight and reduce loss.
|
27 |
+
|
28 |
+
Args:
|
29 |
+
loss (Tensor): Element-wise loss.
|
30 |
+
weight (Tensor): Element-wise weights. Default: None.
|
31 |
+
reduction (str): Same as built-in losses of PyTorch. Options are
|
32 |
+
'none', 'mean' and 'sum'. Default: 'mean'.
|
33 |
+
|
34 |
+
Returns:
|
35 |
+
Tensor: Loss values.
|
36 |
+
"""
|
37 |
+
# if weight is specified, apply element-wise weight
|
38 |
+
if weight is not None:
|
39 |
+
assert weight.dim() == loss.dim()
|
40 |
+
assert weight.size(1) == 1 or weight.size(1) == loss.size(1)
|
41 |
+
loss = loss * weight
|
42 |
+
|
43 |
+
# if weight is not specified or reduction is sum, just reduce the loss
|
44 |
+
if weight is None or reduction == 'sum':
|
45 |
+
loss = reduce_loss(loss, reduction)
|
46 |
+
# if reduction is mean, then compute mean over weight region
|
47 |
+
elif reduction == 'mean':
|
48 |
+
if weight.size(1) > 1:
|
49 |
+
weight = weight.sum()
|
50 |
+
else:
|
51 |
+
weight = weight.sum() * loss.size(1)
|
52 |
+
loss = loss.sum() / weight
|
53 |
+
|
54 |
+
return loss
|
55 |
+
|
56 |
+
|
57 |
+
def weighted_loss(loss_func):
|
58 |
+
"""Create a weighted version of a given loss function.
|
59 |
+
|
60 |
+
To use this decorator, the loss function must have the signature like
|
61 |
+
`loss_func(pred, target, **kwargs)`. The function only needs to compute
|
62 |
+
element-wise loss without any reduction. This decorator will add weight
|
63 |
+
and reduction arguments to the function. The decorated function will have
|
64 |
+
the signature like `loss_func(pred, target, weight=None, reduction='mean',
|
65 |
+
**kwargs)`.
|
66 |
+
|
67 |
+
:Example:
|
68 |
+
|
69 |
+
>>> import torch
|
70 |
+
>>> @weighted_loss
|
71 |
+
>>> def l1_loss(pred, target):
|
72 |
+
>>> return (pred - target).abs()
|
73 |
+
|
74 |
+
>>> pred = torch.Tensor([0, 2, 3])
|
75 |
+
>>> target = torch.Tensor([1, 1, 1])
|
76 |
+
>>> weight = torch.Tensor([1, 0, 1])
|
77 |
+
|
78 |
+
>>> l1_loss(pred, target)
|
79 |
+
tensor(1.3333)
|
80 |
+
>>> l1_loss(pred, target, weight)
|
81 |
+
tensor(1.5000)
|
82 |
+
>>> l1_loss(pred, target, reduction='none')
|
83 |
+
tensor([1., 1., 2.])
|
84 |
+
>>> l1_loss(pred, target, weight, reduction='sum')
|
85 |
+
tensor(3.)
|
86 |
+
"""
|
87 |
+
|
88 |
+
@functools.wraps(loss_func)
|
89 |
+
def wrapper(pred, target, weight=None, reduction='mean', **kwargs):
|
90 |
+
# get element-wise loss
|
91 |
+
loss = loss_func(pred, target, **kwargs)
|
92 |
+
loss = weight_reduce_loss(loss, weight, reduction)
|
93 |
+
return loss
|
94 |
+
|
95 |
+
return wrapper
|
basicsr/losses/losses.py
ADDED
@@ -0,0 +1,551 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
import torch
|
3 |
+
from torch import autograd as autograd
|
4 |
+
from torch import nn as nn
|
5 |
+
from torch.nn import functional as F
|
6 |
+
|
7 |
+
from basicsr.archs.vgg_arch import VGGFeatureExtractor
|
8 |
+
from basicsr.utils.registry import LOSS_REGISTRY
|
9 |
+
from .loss_util import weighted_loss
|
10 |
+
|
11 |
+
_reduction_modes = ['none', 'mean', 'sum']
|
12 |
+
|
13 |
+
|
14 |
+
@weighted_loss
|
15 |
+
def l1_loss(pred, target):
|
16 |
+
return F.l1_loss(pred, target, reduction='none')
|
17 |
+
|
18 |
+
|
19 |
+
@weighted_loss
|
20 |
+
def mse_loss(pred, target):
|
21 |
+
return F.mse_loss(pred, target, reduction='none')
|
22 |
+
|
23 |
+
|
24 |
+
@weighted_loss
|
25 |
+
def charbonnier_loss(pred, target, eps=1e-12):
|
26 |
+
return torch.sqrt((pred - target)**2 + eps)
|
27 |
+
|
28 |
+
|
29 |
+
@LOSS_REGISTRY.register()
|
30 |
+
class L1Loss(nn.Module):
|
31 |
+
"""L1 (mean absolute error, MAE) loss.
|
32 |
+
|
33 |
+
Args:
|
34 |
+
loss_weight (float): Loss weight for L1 loss. Default: 1.0.
|
35 |
+
reduction (str): Specifies the reduction to apply to the output.
|
36 |
+
Supported choices are 'none' | 'mean' | 'sum'. Default: 'mean'.
|
37 |
+
"""
|
38 |
+
|
39 |
+
def __init__(self, loss_weight=1.0, reduction='mean'):
|
40 |
+
super(L1Loss, self).__init__()
|
41 |
+
if reduction not in ['none', 'mean', 'sum']:
|
42 |
+
raise ValueError(f'Unsupported reduction mode: {reduction}. Supported ones are: {_reduction_modes}')
|
43 |
+
|
44 |
+
self.loss_weight = loss_weight
|
45 |
+
self.reduction = reduction
|
46 |
+
|
47 |
+
def forward(self, pred, target, weight=None, **kwargs):
|
48 |
+
"""
|
49 |
+
Args:
|
50 |
+
pred (Tensor): of shape (N, C, H, W). Predicted tensor.
|
51 |
+
target (Tensor): of shape (N, C, H, W). Ground truth tensor.
|
52 |
+
weight (Tensor, optional): of shape (N, C, H, W). Element-wise
|
53 |
+
weights. Default: None.
|
54 |
+
"""
|
55 |
+
return self.loss_weight * l1_loss(pred, target, weight, reduction=self.reduction)
|
56 |
+
|
57 |
+
|
58 |
+
@LOSS_REGISTRY.register()
|
59 |
+
class MSELoss(nn.Module):
|
60 |
+
"""MSE (L2) loss.
|
61 |
+
|
62 |
+
Args:
|
63 |
+
loss_weight (float): Loss weight for MSE loss. Default: 1.0.
|
64 |
+
reduction (str): Specifies the reduction to apply to the output.
|
65 |
+
Supported choices are 'none' | 'mean' | 'sum'. Default: 'mean'.
|
66 |
+
"""
|
67 |
+
|
68 |
+
def __init__(self, loss_weight=1.0, reduction='mean'):
|
69 |
+
super(MSELoss, self).__init__()
|
70 |
+
if reduction not in ['none', 'mean', 'sum']:
|
71 |
+
raise ValueError(f'Unsupported reduction mode: {reduction}. Supported ones are: {_reduction_modes}')
|
72 |
+
|
73 |
+
self.loss_weight = loss_weight
|
74 |
+
self.reduction = reduction
|
75 |
+
|
76 |
+
def forward(self, pred, target, weight=None, **kwargs):
|
77 |
+
"""
|
78 |
+
Args:
|
79 |
+
pred (Tensor): of shape (N, C, H, W). Predicted tensor.
|
80 |
+
target (Tensor): of shape (N, C, H, W). Ground truth tensor.
|
81 |
+
weight (Tensor, optional): of shape (N, C, H, W). Element-wise
|
82 |
+
weights. Default: None.
|
83 |
+
"""
|
84 |
+
return self.loss_weight * mse_loss(pred, target, weight, reduction=self.reduction)
|
85 |
+
|
86 |
+
|
87 |
+
@LOSS_REGISTRY.register()
|
88 |
+
class CharbonnierLoss(nn.Module):
|
89 |
+
"""Charbonnier loss (one variant of Robust L1Loss, a differentiable
|
90 |
+
variant of L1Loss).
|
91 |
+
|
92 |
+
Described in "Deep Laplacian Pyramid Networks for Fast and Accurate
|
93 |
+
Super-Resolution".
|
94 |
+
|
95 |
+
Args:
|
96 |
+
loss_weight (float): Loss weight for L1 loss. Default: 1.0.
|
97 |
+
reduction (str): Specifies the reduction to apply to the output.
|
98 |
+
Supported choices are 'none' | 'mean' | 'sum'. Default: 'mean'.
|
99 |
+
eps (float): A value used to control the curvature near zero.
|
100 |
+
Default: 1e-12.
|
101 |
+
"""
|
102 |
+
|
103 |
+
def __init__(self, loss_weight=1.0, reduction='mean', eps=1e-12):
|
104 |
+
super(CharbonnierLoss, self).__init__()
|
105 |
+
if reduction not in ['none', 'mean', 'sum']:
|
106 |
+
raise ValueError(f'Unsupported reduction mode: {reduction}. Supported ones are: {_reduction_modes}')
|
107 |
+
|
108 |
+
self.loss_weight = loss_weight
|
109 |
+
self.reduction = reduction
|
110 |
+
self.eps = eps
|
111 |
+
|
112 |
+
def forward(self, pred, target, weight=None, **kwargs):
|
113 |
+
"""
|
114 |
+
Args:
|
115 |
+
pred (Tensor): of shape (N, C, H, W). Predicted tensor.
|
116 |
+
target (Tensor): of shape (N, C, H, W). Ground truth tensor.
|
117 |
+
weight (Tensor, optional): of shape (N, C, H, W). Element-wise
|
118 |
+
weights. Default: None.
|
119 |
+
"""
|
120 |
+
return self.loss_weight * charbonnier_loss(pred, target, weight, eps=self.eps, reduction=self.reduction)
|
121 |
+
|
122 |
+
|
123 |
+
@LOSS_REGISTRY.register()
|
124 |
+
class WeightedTVLoss(L1Loss):
|
125 |
+
"""Weighted TV loss.
|
126 |
+
|
127 |
+
Args:
|
128 |
+
loss_weight (float): Loss weight. Default: 1.0.
|
129 |
+
"""
|
130 |
+
|
131 |
+
def __init__(self, loss_weight=1.0):
|
132 |
+
super(WeightedTVLoss, self).__init__(loss_weight=loss_weight)
|
133 |
+
|
134 |
+
def forward(self, pred, weight=None):
|
135 |
+
if weight is None:
|
136 |
+
y_weight = None
|
137 |
+
x_weight = None
|
138 |
+
else:
|
139 |
+
y_weight = weight[:, :, :-1, :]
|
140 |
+
x_weight = weight[:, :, :, :-1]
|
141 |
+
|
142 |
+
y_diff = super(WeightedTVLoss, self).forward(pred[:, :, :-1, :], pred[:, :, 1:, :], weight=y_weight)
|
143 |
+
x_diff = super(WeightedTVLoss, self).forward(pred[:, :, :, :-1], pred[:, :, :, 1:], weight=x_weight)
|
144 |
+
|
145 |
+
loss = x_diff + y_diff
|
146 |
+
|
147 |
+
return loss
|
148 |
+
|
149 |
+
|
150 |
+
@LOSS_REGISTRY.register()
|
151 |
+
class PerceptualLoss(nn.Module):
|
152 |
+
"""Perceptual loss with commonly used style loss.
|
153 |
+
|
154 |
+
Args:
|
155 |
+
layer_weights (dict): The weight for each layer of vgg feature.
|
156 |
+
Here is an example: {'conv5_4': 1.}, which means the conv5_4
|
157 |
+
feature layer (before relu5_4) will be extracted with weight
|
158 |
+
1.0 in calculating losses.
|
159 |
+
vgg_type (str): The type of vgg network used as feature extractor.
|
160 |
+
Default: 'vgg19'.
|
161 |
+
use_input_norm (bool): If True, normalize the input image in vgg.
|
162 |
+
Default: True.
|
163 |
+
range_norm (bool): If True, norm images with range [-1, 1] to [0, 1].
|
164 |
+
Default: False.
|
165 |
+
perceptual_weight (float): If `perceptual_weight > 0`, the perceptual
|
166 |
+
loss will be calculated and the loss will multiplied by the
|
167 |
+
weight. Default: 1.0.
|
168 |
+
style_weight (float): If `style_weight > 0`, the style loss will be
|
169 |
+
calculated and the loss will multiplied by the weight.
|
170 |
+
Default: 0.
|
171 |
+
criterion (str): Criterion used for perceptual loss. Default: 'l1'.
|
172 |
+
"""
|
173 |
+
|
174 |
+
def __init__(self,
|
175 |
+
layer_weights,
|
176 |
+
vgg_type='vgg19',
|
177 |
+
use_input_norm=True,
|
178 |
+
range_norm=False,
|
179 |
+
perceptual_weight=1.0,
|
180 |
+
style_weight=0.,
|
181 |
+
criterion='l1'):
|
182 |
+
super(PerceptualLoss, self).__init__()
|
183 |
+
self.perceptual_weight = perceptual_weight
|
184 |
+
self.style_weight = style_weight
|
185 |
+
self.layer_weights = layer_weights
|
186 |
+
self.vgg = VGGFeatureExtractor(
|
187 |
+
layer_name_list=list(layer_weights.keys()),
|
188 |
+
vgg_type=vgg_type,
|
189 |
+
use_input_norm=use_input_norm,
|
190 |
+
range_norm=range_norm)
|
191 |
+
|
192 |
+
self.criterion_type = criterion
|
193 |
+
if self.criterion_type == 'l1':
|
194 |
+
self.criterion = torch.nn.L1Loss()
|
195 |
+
elif self.criterion_type == 'l2':
|
196 |
+
self.criterion = torch.nn.L2loss()
|
197 |
+
elif self.criterion_type == 'fro':
|
198 |
+
self.criterion = None
|
199 |
+
else:
|
200 |
+
raise NotImplementedError(f'{criterion} criterion has not been supported.')
|
201 |
+
|
202 |
+
def forward(self, x, gt):
|
203 |
+
"""Forward function.
|
204 |
+
|
205 |
+
Args:
|
206 |
+
x (Tensor): Input tensor with shape (n, c, h, w).
|
207 |
+
gt (Tensor): Ground-truth tensor with shape (n, c, h, w).
|
208 |
+
|
209 |
+
Returns:
|
210 |
+
Tensor: Forward results.
|
211 |
+
"""
|
212 |
+
# extract vgg features
|
213 |
+
x_features = self.vgg(x)
|
214 |
+
gt_features = self.vgg(gt.detach())
|
215 |
+
|
216 |
+
# calculate perceptual loss
|
217 |
+
if self.perceptual_weight > 0:
|
218 |
+
percep_loss = 0
|
219 |
+
for k in x_features.keys():
|
220 |
+
if self.criterion_type == 'fro':
|
221 |
+
percep_loss += torch.norm(x_features[k] - gt_features[k], p='fro') * self.layer_weights[k]
|
222 |
+
else:
|
223 |
+
percep_loss += self.criterion(x_features[k], gt_features[k]) * self.layer_weights[k]
|
224 |
+
percep_loss *= self.perceptual_weight
|
225 |
+
else:
|
226 |
+
percep_loss = None
|
227 |
+
|
228 |
+
# calculate style loss
|
229 |
+
if self.style_weight > 0:
|
230 |
+
style_loss = 0
|
231 |
+
for k in x_features.keys():
|
232 |
+
if self.criterion_type == 'fro':
|
233 |
+
style_loss += torch.norm(
|
234 |
+
self._gram_mat(x_features[k]) - self._gram_mat(gt_features[k]), p='fro') * self.layer_weights[k]
|
235 |
+
else:
|
236 |
+
style_loss += self.criterion(self._gram_mat(x_features[k]), self._gram_mat(
|
237 |
+
gt_features[k])) * self.layer_weights[k]
|
238 |
+
style_loss *= self.style_weight
|
239 |
+
else:
|
240 |
+
style_loss = None
|
241 |
+
|
242 |
+
return percep_loss, style_loss
|
243 |
+
|
244 |
+
def _gram_mat(self, x):
|
245 |
+
"""Calculate Gram matrix.
|
246 |
+
|
247 |
+
Args:
|
248 |
+
x (torch.Tensor): Tensor with shape of (n, c, h, w).
|
249 |
+
|
250 |
+
Returns:
|
251 |
+
torch.Tensor: Gram matrix.
|
252 |
+
"""
|
253 |
+
n, c, h, w = x.size()
|
254 |
+
features = x.view(n, c, w * h)
|
255 |
+
features_t = features.transpose(1, 2)
|
256 |
+
gram = features.bmm(features_t) / (c * h * w)
|
257 |
+
return gram
|
258 |
+
|
259 |
+
|
260 |
+
@LOSS_REGISTRY.register()
|
261 |
+
class GANLoss(nn.Module):
|
262 |
+
"""Define GAN loss.
|
263 |
+
|
264 |
+
Args:
|
265 |
+
gan_type (str): Support 'vanilla', 'lsgan', 'wgan', 'hinge'.
|
266 |
+
real_label_val (float): The value for real label. Default: 1.0.
|
267 |
+
fake_label_val (float): The value for fake label. Default: 0.0.
|
268 |
+
loss_weight (float): Loss weight. Default: 1.0.
|
269 |
+
Note that loss_weight is only for generators; and it is always 1.0
|
270 |
+
for discriminators.
|
271 |
+
"""
|
272 |
+
|
273 |
+
def __init__(self, gan_type, real_label_val=1.0, fake_label_val=0.0, loss_weight=1.0):
|
274 |
+
super(GANLoss, self).__init__()
|
275 |
+
self.gan_type = gan_type
|
276 |
+
self.loss_weight = loss_weight
|
277 |
+
self.real_label_val = real_label_val
|
278 |
+
self.fake_label_val = fake_label_val
|
279 |
+
|
280 |
+
if self.gan_type == 'vanilla':
|
281 |
+
self.loss = nn.BCEWithLogitsLoss()
|
282 |
+
elif self.gan_type == 'lsgan':
|
283 |
+
self.loss = nn.MSELoss()
|
284 |
+
elif self.gan_type == 'wgan':
|
285 |
+
self.loss = self._wgan_loss
|
286 |
+
elif self.gan_type == 'wgan_softplus':
|
287 |
+
self.loss = self._wgan_softplus_loss
|
288 |
+
elif self.gan_type == 'hinge':
|
289 |
+
self.loss = nn.ReLU()
|
290 |
+
else:
|
291 |
+
raise NotImplementedError(f'GAN type {self.gan_type} is not implemented.')
|
292 |
+
|
293 |
+
def _wgan_loss(self, input, target):
|
294 |
+
"""wgan loss.
|
295 |
+
|
296 |
+
Args:
|
297 |
+
input (Tensor): Input tensor.
|
298 |
+
target (bool): Target label.
|
299 |
+
|
300 |
+
Returns:
|
301 |
+
Tensor: wgan loss.
|
302 |
+
"""
|
303 |
+
return -input.mean() if target else input.mean()
|
304 |
+
|
305 |
+
def _wgan_softplus_loss(self, input, target):
|
306 |
+
"""wgan loss with soft plus. softplus is a smooth approximation to the
|
307 |
+
ReLU function.
|
308 |
+
|
309 |
+
In StyleGAN2, it is called:
|
310 |
+
Logistic loss for discriminator;
|
311 |
+
Non-saturating loss for generator.
|
312 |
+
|
313 |
+
Args:
|
314 |
+
input (Tensor): Input tensor.
|
315 |
+
target (bool): Target label.
|
316 |
+
|
317 |
+
Returns:
|
318 |
+
Tensor: wgan loss.
|
319 |
+
"""
|
320 |
+
return F.softplus(-input).mean() if target else F.softplus(input).mean()
|
321 |
+
|
322 |
+
def get_target_label(self, input, target_is_real):
|
323 |
+
"""Get target label.
|
324 |
+
|
325 |
+
Args:
|
326 |
+
input (Tensor): Input tensor.
|
327 |
+
target_is_real (bool): Whether the target is real or fake.
|
328 |
+
|
329 |
+
Returns:
|
330 |
+
(bool | Tensor): Target tensor. Return bool for wgan, otherwise,
|
331 |
+
return Tensor.
|
332 |
+
"""
|
333 |
+
|
334 |
+
if self.gan_type in ['wgan', 'wgan_softplus']:
|
335 |
+
return target_is_real
|
336 |
+
target_val = (self.real_label_val if target_is_real else self.fake_label_val)
|
337 |
+
return input.new_ones(input.size()) * target_val
|
338 |
+
|
339 |
+
def forward(self, input, target_is_real, is_disc=False):
|
340 |
+
"""
|
341 |
+
Args:
|
342 |
+
input (Tensor): The input for the loss module, i.e., the network
|
343 |
+
prediction.
|
344 |
+
target_is_real (bool): Whether the targe is real or fake.
|
345 |
+
is_disc (bool): Whether the loss for discriminators or not.
|
346 |
+
Default: False.
|
347 |
+
|
348 |
+
Returns:
|
349 |
+
Tensor: GAN loss value.
|
350 |
+
"""
|
351 |
+
target_label = self.get_target_label(input, target_is_real)
|
352 |
+
if self.gan_type == 'hinge':
|
353 |
+
if is_disc: # for discriminators in hinge-gan
|
354 |
+
input = -input if target_is_real else input
|
355 |
+
loss = self.loss(1 + input).mean()
|
356 |
+
else: # for generators in hinge-gan
|
357 |
+
loss = -input.mean()
|
358 |
+
else: # other gan types
|
359 |
+
loss = self.loss(input, target_label)
|
360 |
+
|
361 |
+
# loss_weight is always 1.0 for discriminators
|
362 |
+
return loss if is_disc else loss * self.loss_weight
|
363 |
+
|
364 |
+
|
365 |
+
@LOSS_REGISTRY.register()
|
366 |
+
class MultiScaleGANLoss(GANLoss):
|
367 |
+
"""
|
368 |
+
MultiScaleGANLoss accepts a list of predictions
|
369 |
+
"""
|
370 |
+
|
371 |
+
def __init__(self, gan_type, real_label_val=1.0, fake_label_val=0.0, loss_weight=1.0):
|
372 |
+
super(MultiScaleGANLoss, self).__init__(gan_type, real_label_val, fake_label_val, loss_weight)
|
373 |
+
|
374 |
+
def forward(self, input, target_is_real, is_disc=False):
|
375 |
+
"""
|
376 |
+
The input is a list of tensors, or a list of (a list of tensors)
|
377 |
+
"""
|
378 |
+
if isinstance(input, list):
|
379 |
+
loss = 0
|
380 |
+
for pred_i in input:
|
381 |
+
if isinstance(pred_i, list):
|
382 |
+
# Only compute GAN loss for the last layer
|
383 |
+
# in case of multiscale feature matching
|
384 |
+
pred_i = pred_i[-1]
|
385 |
+
# Safe operation: 0-dim tensor calling self.mean() does nothing
|
386 |
+
loss_tensor = super().forward(pred_i, target_is_real, is_disc).mean()
|
387 |
+
loss += loss_tensor
|
388 |
+
return loss / len(input)
|
389 |
+
else:
|
390 |
+
return super().forward(input, target_is_real, is_disc)
|
391 |
+
|
392 |
+
|
393 |
+
def r1_penalty(real_pred, real_img):
|
394 |
+
"""R1 regularization for discriminator. The core idea is to
|
395 |
+
penalize the gradient on real data alone: when the
|
396 |
+
generator distribution produces the true data distribution
|
397 |
+
and the discriminator is equal to 0 on the data manifold, the
|
398 |
+
gradient penalty ensures that the discriminator cannot create
|
399 |
+
a non-zero gradient orthogonal to the data manifold without
|
400 |
+
suffering a loss in the GAN game.
|
401 |
+
|
402 |
+
Ref:
|
403 |
+
Eq. 9 in Which training methods for GANs do actually converge.
|
404 |
+
"""
|
405 |
+
grad_real = autograd.grad(outputs=real_pred.sum(), inputs=real_img, create_graph=True)[0]
|
406 |
+
grad_penalty = grad_real.pow(2).view(grad_real.shape[0], -1).sum(1).mean()
|
407 |
+
return grad_penalty
|
408 |
+
|
409 |
+
|
410 |
+
def g_path_regularize(fake_img, latents, mean_path_length, decay=0.01):
|
411 |
+
noise = torch.randn_like(fake_img) / math.sqrt(fake_img.shape[2] * fake_img.shape[3])
|
412 |
+
grad = autograd.grad(outputs=(fake_img * noise).sum(), inputs=latents, create_graph=True)[0]
|
413 |
+
path_lengths = torch.sqrt(grad.pow(2).sum(2).mean(1))
|
414 |
+
|
415 |
+
path_mean = mean_path_length + decay * (path_lengths.mean() - mean_path_length)
|
416 |
+
|
417 |
+
path_penalty = (path_lengths - path_mean).pow(2).mean()
|
418 |
+
|
419 |
+
return path_penalty, path_lengths.detach().mean(), path_mean.detach()
|
420 |
+
|
421 |
+
|
422 |
+
def gradient_penalty_loss(discriminator, real_data, fake_data, weight=None):
|
423 |
+
"""Calculate gradient penalty for wgan-gp.
|
424 |
+
|
425 |
+
Args:
|
426 |
+
discriminator (nn.Module): Network for the discriminator.
|
427 |
+
real_data (Tensor): Real input data.
|
428 |
+
fake_data (Tensor): Fake input data.
|
429 |
+
weight (Tensor): Weight tensor. Default: None.
|
430 |
+
|
431 |
+
Returns:
|
432 |
+
Tensor: A tensor for gradient penalty.
|
433 |
+
"""
|
434 |
+
|
435 |
+
batch_size = real_data.size(0)
|
436 |
+
alpha = real_data.new_tensor(torch.rand(batch_size, 1, 1, 1))
|
437 |
+
|
438 |
+
# interpolate between real_data and fake_data
|
439 |
+
interpolates = alpha * real_data + (1. - alpha) * fake_data
|
440 |
+
interpolates = autograd.Variable(interpolates, requires_grad=True)
|
441 |
+
|
442 |
+
disc_interpolates = discriminator(interpolates)
|
443 |
+
gradients = autograd.grad(
|
444 |
+
outputs=disc_interpolates,
|
445 |
+
inputs=interpolates,
|
446 |
+
grad_outputs=torch.ones_like(disc_interpolates),
|
447 |
+
create_graph=True,
|
448 |
+
retain_graph=True,
|
449 |
+
only_inputs=True)[0]
|
450 |
+
|
451 |
+
if weight is not None:
|
452 |
+
gradients = gradients * weight
|
453 |
+
|
454 |
+
gradients_penalty = ((gradients.norm(2, dim=1) - 1)**2).mean()
|
455 |
+
if weight is not None:
|
456 |
+
gradients_penalty /= torch.mean(weight)
|
457 |
+
|
458 |
+
return gradients_penalty
|
459 |
+
|
460 |
+
|
461 |
+
@LOSS_REGISTRY.register()
|
462 |
+
class GANFeatLoss(nn.Module):
|
463 |
+
"""Define feature matching loss for gans
|
464 |
+
|
465 |
+
Args:
|
466 |
+
criterion (str): Support 'l1', 'l2', 'charbonnier'.
|
467 |
+
loss_weight (float): Loss weight. Default: 1.0.
|
468 |
+
reduction (str): Specifies the reduction to apply to the output.
|
469 |
+
Supported choices are 'none' | 'mean' | 'sum'. Default: 'mean'.
|
470 |
+
"""
|
471 |
+
|
472 |
+
def __init__(self, criterion='l1', loss_weight=1.0, reduction='mean'):
|
473 |
+
super(GANFeatLoss, self).__init__()
|
474 |
+
if criterion == 'l1':
|
475 |
+
self.loss_op = L1Loss(loss_weight, reduction)
|
476 |
+
elif criterion == 'l2':
|
477 |
+
self.loss_op = MSELoss(loss_weight, reduction)
|
478 |
+
elif criterion == 'charbonnier':
|
479 |
+
self.loss_op = CharbonnierLoss(loss_weight, reduction)
|
480 |
+
else:
|
481 |
+
raise ValueError(f'Unsupported loss mode: {criterion}. Supported ones are: l1|l2|charbonnier')
|
482 |
+
|
483 |
+
self.loss_weight = loss_weight
|
484 |
+
|
485 |
+
def forward(self, pred_fake, pred_real):
|
486 |
+
num_d = len(pred_fake)
|
487 |
+
loss = 0
|
488 |
+
for i in range(num_d): # for each discriminator
|
489 |
+
# last output is the final prediction, exclude it
|
490 |
+
num_intermediate_outputs = len(pred_fake[i]) - 1
|
491 |
+
for j in range(num_intermediate_outputs): # for each layer output
|
492 |
+
unweighted_loss = self.loss_op(pred_fake[i][j], pred_real[i][j].detach())
|
493 |
+
loss += unweighted_loss / num_d
|
494 |
+
return loss * self.loss_weight
|
495 |
+
|
496 |
+
|
497 |
+
class sobel_loss(nn.Module):
|
498 |
+
def __init__(self, weight=1.0):
|
499 |
+
super().__init__()
|
500 |
+
kernel_x = torch.Tensor([[-1.0, 0.0, 1.0], [-2.0, 0.0, 2.0], [-1.0, 0.0, 1.0]])
|
501 |
+
kernel_y = torch.Tensor([[-1.0, -2.0, -1.0], [0.0, 0.0, 0.0], [1.0, 2.0, 1.0]])
|
502 |
+
kernel = torch.stack([kernel_x, kernel_y])
|
503 |
+
kernel.requires_grad = False
|
504 |
+
kernel = kernel.unsqueeze(1)
|
505 |
+
self.register_buffer('sobel_kernel', kernel)
|
506 |
+
self.weight = weight
|
507 |
+
|
508 |
+
def forward(self, input_tensor, target_tensor):
|
509 |
+
b, c, h, w = input_tensor.size()
|
510 |
+
input_tensor = input_tensor.view(b * c, 1, h, w)
|
511 |
+
input_edge = F.conv2d(input_tensor, self.sobel_kernel, padding=1)
|
512 |
+
input_edge = input_edge.view(b, 2*c, h, w)
|
513 |
+
|
514 |
+
target_tensor = target_tensor.view(-1, 1, h, w)
|
515 |
+
target_edge = F.conv2d(target_tensor, self.sobel_kernel, padding=1)
|
516 |
+
target_edge = target_edge.view(b, 2*c, h, w)
|
517 |
+
|
518 |
+
return self.weight * F.l1_loss(input_edge, target_edge)
|
519 |
+
|
520 |
+
|
521 |
+
@LOSS_REGISTRY.register()
|
522 |
+
class ColorfulnessLoss(nn.Module):
|
523 |
+
"""Colorfulness loss.
|
524 |
+
|
525 |
+
Args:
|
526 |
+
loss_weight (float): Loss weight for Colorfulness loss. Default: 1.0.
|
527 |
+
|
528 |
+
"""
|
529 |
+
|
530 |
+
def __init__(self, loss_weight=1.0):
|
531 |
+
super(ColorfulnessLoss, self).__init__()
|
532 |
+
|
533 |
+
self.loss_weight = loss_weight
|
534 |
+
|
535 |
+
def forward(self, pred, **kwargs):
|
536 |
+
"""
|
537 |
+
Args:
|
538 |
+
pred (Tensor): of shape (N, C, H, W). Predicted tensor.
|
539 |
+
"""
|
540 |
+
colorfulness_loss = 0
|
541 |
+
for i in range(pred.shape[0]):
|
542 |
+
(R, G, B) = pred[i][0], pred[i][1], pred[i][2]
|
543 |
+
rg = torch.abs(R - G)
|
544 |
+
yb = torch.abs(0.5 * (R+G) - B)
|
545 |
+
(rbMean, rbStd) = (torch.mean(rg), torch.std(rg))
|
546 |
+
(ybMean, ybStd) = (torch.mean(yb), torch.std(yb))
|
547 |
+
stdRoot = torch.sqrt((rbStd ** 2) + (ybStd ** 2))
|
548 |
+
meanRoot = torch.sqrt((rbMean ** 2) + (ybMean ** 2))
|
549 |
+
colorfulness = stdRoot + (0.3 * meanRoot)
|
550 |
+
colorfulness_loss += (1 - colorfulness)
|
551 |
+
return self.loss_weight * colorfulness_loss
|
basicsr/metrics/__init__.py
ADDED
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from copy import deepcopy
|
2 |
+
|
3 |
+
from basicsr.utils.registry import METRIC_REGISTRY
|
4 |
+
from .psnr_ssim import calculate_psnr, calculate_ssim
|
5 |
+
from .colorfulness import calculate_cf
|
6 |
+
|
7 |
+
__all__ = ['calculate_psnr', 'calculate_ssim', 'calculate_cf']
|
8 |
+
|
9 |
+
|
10 |
+
def calculate_metric(data, opt):
|
11 |
+
"""Calculate metric from data and options.
|
12 |
+
|
13 |
+
Args:
|
14 |
+
opt (dict): Configuration. It must contain:
|
15 |
+
type (str): Model type.
|
16 |
+
"""
|
17 |
+
opt = deepcopy(opt)
|
18 |
+
metric_type = opt.pop('type')
|
19 |
+
metric = METRIC_REGISTRY.get(metric_type)(**data, **opt)
|
20 |
+
return metric
|
basicsr/metrics/colorfulness.py
ADDED
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import cv2
|
2 |
+
import numpy as np
|
3 |
+
from basicsr.utils.registry import METRIC_REGISTRY
|
4 |
+
|
5 |
+
|
6 |
+
@METRIC_REGISTRY.register()
|
7 |
+
def calculate_cf(img, **kwargs):
|
8 |
+
"""Calculate Colorfulness.
|
9 |
+
"""
|
10 |
+
(B, G, R) = cv2.split(img.astype('float'))
|
11 |
+
rg = np.absolute(R - G)
|
12 |
+
yb = np.absolute(0.5 * (R+G) - B)
|
13 |
+
(rbMean, rbStd) = (np.mean(rg), np.std(rg))
|
14 |
+
(ybMean, ybStd) = (np.mean(yb), np.std(yb))
|
15 |
+
stdRoot = np.sqrt((rbStd ** 2) + (ybStd ** 2))
|
16 |
+
meanRoot = np.sqrt((rbMean ** 2) + (ybMean ** 2))
|
17 |
+
return stdRoot + (0.3 * meanRoot)
|
basicsr/metrics/custom_fid.py
ADDED
@@ -0,0 +1,260 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import torch.nn.functional as F
|
3 |
+
from scipy import linalg
|
4 |
+
import torch.nn as nn
|
5 |
+
from torchvision import models
|
6 |
+
|
7 |
+
|
8 |
+
class INCEPTION_V3_FID(nn.Module):
|
9 |
+
"""pretrained InceptionV3 network returning feature maps"""
|
10 |
+
# Index of default block of inception to return,
|
11 |
+
# corresponds to output of final average pooling
|
12 |
+
DEFAULT_BLOCK_INDEX = 3
|
13 |
+
|
14 |
+
# Maps feature dimensionality to their output blocks indices
|
15 |
+
BLOCK_INDEX_BY_DIM = {
|
16 |
+
64: 0, # First max pooling features
|
17 |
+
192: 1, # Second max pooling featurs
|
18 |
+
768: 2, # Pre-aux classifier features
|
19 |
+
2048: 3 # Final average pooling features
|
20 |
+
}
|
21 |
+
|
22 |
+
def __init__(self,
|
23 |
+
incep_state_dict,
|
24 |
+
output_blocks=[DEFAULT_BLOCK_INDEX],
|
25 |
+
resize_input=True):
|
26 |
+
"""Build pretrained InceptionV3
|
27 |
+
Parameters
|
28 |
+
----------
|
29 |
+
output_blocks : list of int
|
30 |
+
Indices of blocks to return features of. Possible values are:
|
31 |
+
- 0: corresponds to output of first max pooling
|
32 |
+
- 1: corresponds to output of second max pooling
|
33 |
+
- 2: corresponds to output which is fed to aux classifier
|
34 |
+
- 3: corresponds to output of final average pooling
|
35 |
+
resize_input : bool
|
36 |
+
If true, bilinearly resizes input to width and height 299 before
|
37 |
+
feeding input to model. As the network without fully connected
|
38 |
+
layers is fully convolutional, it should be able to handle inputs
|
39 |
+
of arbitrary size, so resizing might not be strictly needed
|
40 |
+
normalize_input : bool
|
41 |
+
If true, normalizes the input to the statistics the pretrained
|
42 |
+
Inception network expects
|
43 |
+
"""
|
44 |
+
super(INCEPTION_V3_FID, self).__init__()
|
45 |
+
|
46 |
+
self.resize_input = resize_input
|
47 |
+
self.output_blocks = sorted(output_blocks)
|
48 |
+
self.last_needed_block = max(output_blocks)
|
49 |
+
|
50 |
+
assert self.last_needed_block <= 3, \
|
51 |
+
'Last possible output block index is 3'
|
52 |
+
|
53 |
+
self.blocks = nn.ModuleList()
|
54 |
+
|
55 |
+
inception = models.inception_v3()
|
56 |
+
inception.load_state_dict(incep_state_dict)
|
57 |
+
for param in inception.parameters():
|
58 |
+
param.requires_grad = False
|
59 |
+
|
60 |
+
# Block 0: input to maxpool1
|
61 |
+
block0 = [
|
62 |
+
inception.Conv2d_1a_3x3,
|
63 |
+
inception.Conv2d_2a_3x3,
|
64 |
+
inception.Conv2d_2b_3x3,
|
65 |
+
nn.MaxPool2d(kernel_size=3, stride=2)
|
66 |
+
]
|
67 |
+
self.blocks.append(nn.Sequential(*block0))
|
68 |
+
|
69 |
+
# Block 1: maxpool1 to maxpool2
|
70 |
+
if self.last_needed_block >= 1:
|
71 |
+
block1 = [
|
72 |
+
inception.Conv2d_3b_1x1,
|
73 |
+
inception.Conv2d_4a_3x3,
|
74 |
+
nn.MaxPool2d(kernel_size=3, stride=2)
|
75 |
+
]
|
76 |
+
self.blocks.append(nn.Sequential(*block1))
|
77 |
+
|
78 |
+
# Block 2: maxpool2 to aux classifier
|
79 |
+
if self.last_needed_block >= 2:
|
80 |
+
block2 = [
|
81 |
+
inception.Mixed_5b,
|
82 |
+
inception.Mixed_5c,
|
83 |
+
inception.Mixed_5d,
|
84 |
+
inception.Mixed_6a,
|
85 |
+
inception.Mixed_6b,
|
86 |
+
inception.Mixed_6c,
|
87 |
+
inception.Mixed_6d,
|
88 |
+
inception.Mixed_6e,
|
89 |
+
]
|
90 |
+
self.blocks.append(nn.Sequential(*block2))
|
91 |
+
|
92 |
+
# Block 3: aux classifier to final avgpool
|
93 |
+
if self.last_needed_block >= 3:
|
94 |
+
block3 = [
|
95 |
+
inception.Mixed_7a,
|
96 |
+
inception.Mixed_7b,
|
97 |
+
inception.Mixed_7c,
|
98 |
+
nn.AdaptiveAvgPool2d(output_size=(1, 1))
|
99 |
+
]
|
100 |
+
self.blocks.append(nn.Sequential(*block3))
|
101 |
+
|
102 |
+
def forward(self, inp):
|
103 |
+
"""Get Inception feature maps
|
104 |
+
Parameters
|
105 |
+
----------
|
106 |
+
inp : torch.autograd.Variable
|
107 |
+
Input tensor of shape Bx3xHxW. Values are expected to be in
|
108 |
+
range (0, 1)
|
109 |
+
Returns
|
110 |
+
-------
|
111 |
+
List of torch.autograd.Variable, corresponding to the selected output
|
112 |
+
block, sorted ascending by index
|
113 |
+
"""
|
114 |
+
outp = []
|
115 |
+
x = inp
|
116 |
+
|
117 |
+
if self.resize_input:
|
118 |
+
x = F.interpolate(x, size=(299, 299), mode='bilinear')
|
119 |
+
|
120 |
+
x = x.clone()
|
121 |
+
# [-1.0, 1.0] --> [0, 1.0]
|
122 |
+
x = x * 0.5 + 0.5
|
123 |
+
x[:, 0] = x[:, 0] * (0.229 / 0.5) + (0.485 - 0.5) / 0.5
|
124 |
+
x[:, 1] = x[:, 1] * (0.224 / 0.5) + (0.456 - 0.5) / 0.5
|
125 |
+
x[:, 2] = x[:, 2] * (0.225 / 0.5) + (0.406 - 0.5) / 0.5
|
126 |
+
|
127 |
+
for idx, block in enumerate(self.blocks):
|
128 |
+
x = block(x)
|
129 |
+
if idx in self.output_blocks:
|
130 |
+
outp.append(x)
|
131 |
+
|
132 |
+
if idx == self.last_needed_block:
|
133 |
+
break
|
134 |
+
|
135 |
+
return outp
|
136 |
+
|
137 |
+
|
138 |
+
def get_activations(images, model, batch_size, verbose=False):
|
139 |
+
"""Calculates the activations of the pool_3 layer for all images.
|
140 |
+
Params:
|
141 |
+
-- images : Numpy array of dimension (n_images, 3, hi, wi). The values
|
142 |
+
must lie between 0 and 1.
|
143 |
+
-- model : Instance of inception model
|
144 |
+
-- batch_size : the images numpy array is split into batches with
|
145 |
+
batch size batch_size. A reasonable batch size depends
|
146 |
+
on the hardware.
|
147 |
+
-- verbose : If set to True and parameter out_step is given, the number
|
148 |
+
of calculated batches is reported.
|
149 |
+
Returns:
|
150 |
+
-- A numpy array of dimension (num images, dims) that contains the
|
151 |
+
activations of the given tensor when feeding inception with the
|
152 |
+
query tensor.
|
153 |
+
"""
|
154 |
+
model.eval()
|
155 |
+
|
156 |
+
#d0 = images.shape[0]
|
157 |
+
d0 = int(images.size(0))
|
158 |
+
if batch_size > d0:
|
159 |
+
print(('Warning: batch size is bigger than the data size. '
|
160 |
+
'Setting batch size to data size'))
|
161 |
+
batch_size = d0
|
162 |
+
|
163 |
+
n_batches = d0 // batch_size
|
164 |
+
n_used_imgs = n_batches * batch_size
|
165 |
+
|
166 |
+
pred_arr = np.empty((n_used_imgs, 2048))
|
167 |
+
for i in range(n_batches):
|
168 |
+
if verbose:
|
169 |
+
print('\rPropagating batch %d/%d' % (i + 1, n_batches), end='', flush=True)
|
170 |
+
start = i * batch_size
|
171 |
+
end = start + batch_size
|
172 |
+
|
173 |
+
'''batch = torch.from_numpy(images[start:end]).type(torch.FloatTensor)
|
174 |
+
batch = Variable(batch, volatile=True)
|
175 |
+
if cfg.CUDA:
|
176 |
+
batch = batch.cuda()'''
|
177 |
+
batch = images[start:end]
|
178 |
+
|
179 |
+
pred = model(batch)[0]
|
180 |
+
|
181 |
+
# If model output is not scalar, apply global spatial average pooling.
|
182 |
+
# This happens if you choose a dimensionality not equal 2048.
|
183 |
+
if pred.shape[2] != 1 or pred.shape[3] != 1:
|
184 |
+
pred = F.adaptive_avg_pool2d(pred, output_size=(1, 1))
|
185 |
+
|
186 |
+
pred_arr[start:end] = pred.cpu().data.numpy().reshape(batch_size, -1)
|
187 |
+
|
188 |
+
if verbose:
|
189 |
+
print(' done')
|
190 |
+
|
191 |
+
return pred_arr
|
192 |
+
|
193 |
+
|
194 |
+
def calculate_activation_statistics(act):
|
195 |
+
"""Calculation of the statistics used by the FID.
|
196 |
+
Params:
|
197 |
+
-- act : Numpy array of dimension (n_images, dim (e.g. 2048)).
|
198 |
+
Returns:
|
199 |
+
-- mu : The mean over samples of the activations of the pool_3 layer of
|
200 |
+
the inception model.
|
201 |
+
-- sigma : The covariance matrix of the activations of the pool_3 layer of
|
202 |
+
the inception model.
|
203 |
+
"""
|
204 |
+
mu = np.mean(act, axis=0)
|
205 |
+
sigma = np.cov(act, rowvar=False)
|
206 |
+
return mu, sigma
|
207 |
+
|
208 |
+
|
209 |
+
def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
|
210 |
+
"""Numpy implementation of the Frechet Distance.
|
211 |
+
The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1)
|
212 |
+
and X_2 ~ N(mu_2, C_2) is
|
213 |
+
d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)).
|
214 |
+
Stable version by Dougal J. Sutherland.
|
215 |
+
Params:
|
216 |
+
-- mu1 : Numpy array containing the activations of a layer of the
|
217 |
+
inception net (like returned by the function 'get_predictions')
|
218 |
+
for generated samples.
|
219 |
+
-- mu2 : The sample mean over activations, precalculated on an
|
220 |
+
representive data set.
|
221 |
+
-- sigma1: The covariance matrix over activations for generated samples.
|
222 |
+
-- sigma2: The covariance matrix over activations, precalculated on an
|
223 |
+
representive data set.
|
224 |
+
Returns:
|
225 |
+
-- : The Frechet Distance.
|
226 |
+
"""
|
227 |
+
|
228 |
+
mu1 = np.atleast_1d(mu1)
|
229 |
+
mu2 = np.atleast_1d(mu2)
|
230 |
+
|
231 |
+
sigma1 = np.atleast_2d(sigma1)
|
232 |
+
sigma2 = np.atleast_2d(sigma2)
|
233 |
+
|
234 |
+
assert mu1.shape == mu2.shape, \
|
235 |
+
'Training and test mean vectors have different lengths'
|
236 |
+
assert sigma1.shape == sigma2.shape, \
|
237 |
+
'Training and test covariances have different dimensions'
|
238 |
+
|
239 |
+
diff = mu1 - mu2
|
240 |
+
|
241 |
+
# Product might be almost singular
|
242 |
+
covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
|
243 |
+
if not np.isfinite(covmean).all():
|
244 |
+
msg = ('fid calculation produces singular product; '
|
245 |
+
'adding %s to diagonal of cov estimates') % eps
|
246 |
+
print(msg)
|
247 |
+
offset = np.eye(sigma1.shape[0]) * eps
|
248 |
+
covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))
|
249 |
+
|
250 |
+
# Numerical error might give slight imaginary component
|
251 |
+
if np.iscomplexobj(covmean):
|
252 |
+
if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):
|
253 |
+
m = np.max(np.abs(covmean.imag))
|
254 |
+
raise ValueError('Imaginary component {}'.format(m))
|
255 |
+
covmean = covmean.real
|
256 |
+
|
257 |
+
tr_covmean = np.trace(covmean)
|
258 |
+
|
259 |
+
return (diff.dot(diff) + np.trace(sigma1) +
|
260 |
+
np.trace(sigma2) - 2 * tr_covmean)
|
basicsr/metrics/metric_util.py
ADDED
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
|
3 |
+
from basicsr.utils.matlab_functions import bgr2ycbcr
|
4 |
+
|
5 |
+
|
6 |
+
def reorder_image(img, input_order='HWC'):
|
7 |
+
"""Reorder images to 'HWC' order.
|
8 |
+
|
9 |
+
If the input_order is (h, w), return (h, w, 1);
|
10 |
+
If the input_order is (c, h, w), return (h, w, c);
|
11 |
+
If the input_order is (h, w, c), return as it is.
|
12 |
+
|
13 |
+
Args:
|
14 |
+
img (ndarray): Input image.
|
15 |
+
input_order (str): Whether the input order is 'HWC' or 'CHW'.
|
16 |
+
If the input image shape is (h, w), input_order will not have
|
17 |
+
effects. Default: 'HWC'.
|
18 |
+
|
19 |
+
Returns:
|
20 |
+
ndarray: reordered image.
|
21 |
+
"""
|
22 |
+
|
23 |
+
if input_order not in ['HWC', 'CHW']:
|
24 |
+
raise ValueError(f'Wrong input_order {input_order}. Supported input_orders are ' "'HWC' and 'CHW'")
|
25 |
+
if len(img.shape) == 2:
|
26 |
+
img = img[..., None]
|
27 |
+
if input_order == 'CHW':
|
28 |
+
img = img.transpose(1, 2, 0)
|
29 |
+
return img
|
30 |
+
|
31 |
+
|
32 |
+
def to_y_channel(img):
|
33 |
+
"""Change to Y channel of YCbCr.
|
34 |
+
|
35 |
+
Args:
|
36 |
+
img (ndarray): Images with range [0, 255].
|
37 |
+
|
38 |
+
Returns:
|
39 |
+
(ndarray): Images with range [0, 255] (float type) without round.
|
40 |
+
"""
|
41 |
+
img = img.astype(np.float32) / 255.
|
42 |
+
if img.ndim == 3 and img.shape[2] == 3:
|
43 |
+
img = bgr2ycbcr(img, y_only=True)
|
44 |
+
img = img[..., None]
|
45 |
+
return img * 255.
|
basicsr/metrics/psnr_ssim.py
ADDED
@@ -0,0 +1,128 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import cv2
|
2 |
+
import numpy as np
|
3 |
+
|
4 |
+
from basicsr.metrics.metric_util import reorder_image, to_y_channel
|
5 |
+
from basicsr.utils.registry import METRIC_REGISTRY
|
6 |
+
|
7 |
+
|
8 |
+
@METRIC_REGISTRY.register()
|
9 |
+
def calculate_psnr(img, img2, crop_border, input_order='HWC', test_y_channel=False, **kwargs):
|
10 |
+
"""Calculate PSNR (Peak Signal-to-Noise Ratio).
|
11 |
+
|
12 |
+
Ref: https://en.wikipedia.org/wiki/Peak_signal-to-noise_ratio
|
13 |
+
|
14 |
+
Args:
|
15 |
+
img (ndarray): Images with range [0, 255].
|
16 |
+
img2 (ndarray): Images with range [0, 255].
|
17 |
+
crop_border (int): Cropped pixels in each edge of an image. These
|
18 |
+
pixels are not involved in the PSNR calculation.
|
19 |
+
input_order (str): Whether the input order is 'HWC' or 'CHW'.
|
20 |
+
Default: 'HWC'.
|
21 |
+
test_y_channel (bool): Test on Y channel of YCbCr. Default: False.
|
22 |
+
|
23 |
+
Returns:
|
24 |
+
float: psnr result.
|
25 |
+
"""
|
26 |
+
|
27 |
+
assert img.shape == img2.shape, (f'Image shapes are different: {img.shape}, {img2.shape}.')
|
28 |
+
if input_order not in ['HWC', 'CHW']:
|
29 |
+
raise ValueError(f'Wrong input_order {input_order}. Supported input_orders are ' '"HWC" and "CHW"')
|
30 |
+
img = reorder_image(img, input_order=input_order)
|
31 |
+
img2 = reorder_image(img2, input_order=input_order)
|
32 |
+
img = img.astype(np.float64)
|
33 |
+
img2 = img2.astype(np.float64)
|
34 |
+
|
35 |
+
if crop_border != 0:
|
36 |
+
img = img[crop_border:-crop_border, crop_border:-crop_border, ...]
|
37 |
+
img2 = img2[crop_border:-crop_border, crop_border:-crop_border, ...]
|
38 |
+
|
39 |
+
if test_y_channel:
|
40 |
+
img = to_y_channel(img)
|
41 |
+
img2 = to_y_channel(img2)
|
42 |
+
|
43 |
+
mse = np.mean((img - img2)**2)
|
44 |
+
if mse == 0:
|
45 |
+
return float('inf')
|
46 |
+
return 20. * np.log10(255. / np.sqrt(mse))
|
47 |
+
|
48 |
+
|
49 |
+
def _ssim(img, img2):
|
50 |
+
"""Calculate SSIM (structural similarity) for one channel images.
|
51 |
+
|
52 |
+
It is called by func:`calculate_ssim`.
|
53 |
+
|
54 |
+
Args:
|
55 |
+
img (ndarray): Images with range [0, 255] with order 'HWC'.
|
56 |
+
img2 (ndarray): Images with range [0, 255] with order 'HWC'.
|
57 |
+
|
58 |
+
Returns:
|
59 |
+
float: ssim result.
|
60 |
+
"""
|
61 |
+
|
62 |
+
c1 = (0.01 * 255)**2
|
63 |
+
c2 = (0.03 * 255)**2
|
64 |
+
|
65 |
+
img = img.astype(np.float64)
|
66 |
+
img2 = img2.astype(np.float64)
|
67 |
+
kernel = cv2.getGaussianKernel(11, 1.5)
|
68 |
+
window = np.outer(kernel, kernel.transpose())
|
69 |
+
|
70 |
+
mu1 = cv2.filter2D(img, -1, window)[5:-5, 5:-5]
|
71 |
+
mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5]
|
72 |
+
mu1_sq = mu1**2
|
73 |
+
mu2_sq = mu2**2
|
74 |
+
mu1_mu2 = mu1 * mu2
|
75 |
+
sigma1_sq = cv2.filter2D(img**2, -1, window)[5:-5, 5:-5] - mu1_sq
|
76 |
+
sigma2_sq = cv2.filter2D(img2**2, -1, window)[5:-5, 5:-5] - mu2_sq
|
77 |
+
sigma12 = cv2.filter2D(img * img2, -1, window)[5:-5, 5:-5] - mu1_mu2
|
78 |
+
|
79 |
+
ssim_map = ((2 * mu1_mu2 + c1) * (2 * sigma12 + c2)) / ((mu1_sq + mu2_sq + c1) * (sigma1_sq + sigma2_sq + c2))
|
80 |
+
return ssim_map.mean()
|
81 |
+
|
82 |
+
|
83 |
+
@METRIC_REGISTRY.register()
|
84 |
+
def calculate_ssim(img, img2, crop_border, input_order='HWC', test_y_channel=False, **kwargs):
|
85 |
+
"""Calculate SSIM (structural similarity).
|
86 |
+
|
87 |
+
Ref:
|
88 |
+
Image quality assessment: From error visibility to structural similarity
|
89 |
+
|
90 |
+
The results are the same as that of the official released MATLAB code in
|
91 |
+
https://ece.uwaterloo.ca/~z70wang/research/ssim/.
|
92 |
+
|
93 |
+
For three-channel images, SSIM is calculated for each channel and then
|
94 |
+
averaged.
|
95 |
+
|
96 |
+
Args:
|
97 |
+
img (ndarray): Images with range [0, 255].
|
98 |
+
img2 (ndarray): Images with range [0, 255].
|
99 |
+
crop_border (int): Cropped pixels in each edge of an image. These
|
100 |
+
pixels are not involved in the SSIM calculation.
|
101 |
+
input_order (str): Whether the input order is 'HWC' or 'CHW'.
|
102 |
+
Default: 'HWC'.
|
103 |
+
test_y_channel (bool): Test on Y channel of YCbCr. Default: False.
|
104 |
+
|
105 |
+
Returns:
|
106 |
+
float: ssim result.
|
107 |
+
"""
|
108 |
+
|
109 |
+
assert img.shape == img2.shape, (f'Image shapes are different: {img.shape}, {img2.shape}.')
|
110 |
+
if input_order not in ['HWC', 'CHW']:
|
111 |
+
raise ValueError(f'Wrong input_order {input_order}. Supported input_orders are ' '"HWC" and "CHW"')
|
112 |
+
img = reorder_image(img, input_order=input_order)
|
113 |
+
img2 = reorder_image(img2, input_order=input_order)
|
114 |
+
img = img.astype(np.float64)
|
115 |
+
img2 = img2.astype(np.float64)
|
116 |
+
|
117 |
+
if crop_border != 0:
|
118 |
+
img = img[crop_border:-crop_border, crop_border:-crop_border, ...]
|
119 |
+
img2 = img2[crop_border:-crop_border, crop_border:-crop_border, ...]
|
120 |
+
|
121 |
+
if test_y_channel:
|
122 |
+
img = to_y_channel(img)
|
123 |
+
img2 = to_y_channel(img2)
|
124 |
+
|
125 |
+
ssims = []
|
126 |
+
for i in range(img.shape[2]):
|
127 |
+
ssims.append(_ssim(img[..., i], img2[..., i]))
|
128 |
+
return np.array(ssims).mean()
|
basicsr/models/__init__.py
ADDED
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import importlib
|
2 |
+
from copy import deepcopy
|
3 |
+
from os import path as osp
|
4 |
+
|
5 |
+
from basicsr.utils import get_root_logger, scandir
|
6 |
+
from basicsr.utils.registry import MODEL_REGISTRY
|
7 |
+
|
8 |
+
__all__ = ['build_model']
|
9 |
+
|
10 |
+
# automatically scan and import model modules for registry
|
11 |
+
# scan all the files under the 'models' folder and collect files ending with
|
12 |
+
# '_model.py'
|
13 |
+
model_folder = osp.dirname(osp.abspath(__file__))
|
14 |
+
model_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(model_folder) if v.endswith('_model.py')]
|
15 |
+
# import all the model modules
|
16 |
+
_model_modules = [importlib.import_module(f'basicsr.models.{file_name}') for file_name in model_filenames]
|
17 |
+
|
18 |
+
|
19 |
+
def build_model(opt):
|
20 |
+
"""Build model from options.
|
21 |
+
|
22 |
+
Args:
|
23 |
+
opt (dict): Configuration. It must contain:
|
24 |
+
model_type (str): Model type.
|
25 |
+
"""
|
26 |
+
opt = deepcopy(opt)
|
27 |
+
model = MODEL_REGISTRY.get(opt['model_type'])(opt)
|
28 |
+
logger = get_root_logger()
|
29 |
+
logger.info(f'Model [{model.__class__.__name__}] is created.')
|
30 |
+
return model
|
basicsr/models/base_model.py
ADDED
@@ -0,0 +1,382 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import time
|
3 |
+
import torch
|
4 |
+
from collections import OrderedDict
|
5 |
+
from copy import deepcopy
|
6 |
+
from torch.nn.parallel import DataParallel, DistributedDataParallel
|
7 |
+
|
8 |
+
from basicsr.models import lr_scheduler as lr_scheduler
|
9 |
+
from basicsr.utils import get_root_logger
|
10 |
+
from basicsr.utils.dist_util import master_only
|
11 |
+
|
12 |
+
|
13 |
+
class BaseModel():
|
14 |
+
"""Base model."""
|
15 |
+
|
16 |
+
def __init__(self, opt):
|
17 |
+
self.opt = opt
|
18 |
+
self.device = torch.device('cuda' if opt['num_gpu'] != 0 else 'cpu')
|
19 |
+
self.is_train = opt['is_train']
|
20 |
+
self.schedulers = []
|
21 |
+
self.optimizers = []
|
22 |
+
|
23 |
+
def feed_data(self, data):
|
24 |
+
pass
|
25 |
+
|
26 |
+
def optimize_parameters(self):
|
27 |
+
pass
|
28 |
+
|
29 |
+
def get_current_visuals(self):
|
30 |
+
pass
|
31 |
+
|
32 |
+
def save(self, epoch, current_iter):
|
33 |
+
"""Save networks and training state."""
|
34 |
+
pass
|
35 |
+
|
36 |
+
def validation(self, dataloader, current_iter, tb_logger, save_img=False):
|
37 |
+
"""Validation function.
|
38 |
+
|
39 |
+
Args:
|
40 |
+
dataloader (torch.utils.data.DataLoader): Validation dataloader.
|
41 |
+
current_iter (int): Current iteration.
|
42 |
+
tb_logger (tensorboard logger): Tensorboard logger.
|
43 |
+
save_img (bool): Whether to save images. Default: False.
|
44 |
+
"""
|
45 |
+
if self.opt['dist']:
|
46 |
+
self.dist_validation(dataloader, current_iter, tb_logger, save_img)
|
47 |
+
else:
|
48 |
+
self.nondist_validation(dataloader, current_iter, tb_logger, save_img)
|
49 |
+
|
50 |
+
def _initialize_best_metric_results(self, dataset_name):
|
51 |
+
"""Initialize the best metric results dict for recording the best metric value and iteration."""
|
52 |
+
if hasattr(self, 'best_metric_results') and dataset_name in self.best_metric_results:
|
53 |
+
return
|
54 |
+
elif not hasattr(self, 'best_metric_results'):
|
55 |
+
self.best_metric_results = dict()
|
56 |
+
|
57 |
+
# add a dataset record
|
58 |
+
record = dict()
|
59 |
+
for metric, content in self.opt['val']['metrics'].items():
|
60 |
+
better = content.get('better', 'higher')
|
61 |
+
init_val = float('-inf') if better == 'higher' else float('inf')
|
62 |
+
record[metric] = dict(better=better, val=init_val, iter=-1)
|
63 |
+
self.best_metric_results[dataset_name] = record
|
64 |
+
|
65 |
+
def _update_best_metric_result(self, dataset_name, metric, val, current_iter):
|
66 |
+
if self.best_metric_results[dataset_name][metric]['better'] == 'higher':
|
67 |
+
if val >= self.best_metric_results[dataset_name][metric]['val']:
|
68 |
+
self.best_metric_results[dataset_name][metric]['val'] = val
|
69 |
+
self.best_metric_results[dataset_name][metric]['iter'] = current_iter
|
70 |
+
else:
|
71 |
+
if val <= self.best_metric_results[dataset_name][metric]['val']:
|
72 |
+
self.best_metric_results[dataset_name][metric]['val'] = val
|
73 |
+
self.best_metric_results[dataset_name][metric]['iter'] = current_iter
|
74 |
+
|
75 |
+
def model_ema(self, decay=0.999):
|
76 |
+
net_g = self.get_bare_model(self.net_g)
|
77 |
+
|
78 |
+
net_g_params = dict(net_g.named_parameters())
|
79 |
+
net_g_ema_params = dict(self.net_g_ema.named_parameters())
|
80 |
+
|
81 |
+
for k in net_g_ema_params.keys():
|
82 |
+
net_g_ema_params[k].data.mul_(decay).add_(net_g_params[k].data, alpha=1 - decay)
|
83 |
+
|
84 |
+
def get_current_log(self):
|
85 |
+
return self.log_dict
|
86 |
+
|
87 |
+
def model_to_device(self, net):
|
88 |
+
"""Model to device. It also warps models with DistributedDataParallel
|
89 |
+
or DataParallel.
|
90 |
+
|
91 |
+
Args:
|
92 |
+
net (nn.Module)
|
93 |
+
"""
|
94 |
+
net = net.to(self.device)
|
95 |
+
if self.opt['dist']:
|
96 |
+
find_unused_parameters = self.opt.get('find_unused_parameters', False)
|
97 |
+
net = DistributedDataParallel(
|
98 |
+
net, device_ids=[torch.cuda.current_device()], find_unused_parameters=find_unused_parameters)
|
99 |
+
elif self.opt['num_gpu'] > 1:
|
100 |
+
net = DataParallel(net)
|
101 |
+
return net
|
102 |
+
|
103 |
+
def get_optimizer(self, optim_type, params, lr, **kwargs):
|
104 |
+
if optim_type == 'Adam':
|
105 |
+
optimizer = torch.optim.Adam(params, lr, **kwargs)
|
106 |
+
elif optim_type == 'AdamW':
|
107 |
+
optimizer = torch.optim.AdamW(params, lr, **kwargs)
|
108 |
+
else:
|
109 |
+
raise NotImplementedError(f'optimizer {optim_type} is not supperted yet.')
|
110 |
+
return optimizer
|
111 |
+
|
112 |
+
def setup_schedulers(self):
|
113 |
+
"""Set up schedulers."""
|
114 |
+
train_opt = self.opt['train']
|
115 |
+
scheduler_type = train_opt['scheduler'].pop('type')
|
116 |
+
if scheduler_type in ['MultiStepLR', 'MultiStepRestartLR']:
|
117 |
+
for optimizer in self.optimizers:
|
118 |
+
self.schedulers.append(lr_scheduler.MultiStepRestartLR(optimizer, **train_opt['scheduler']))
|
119 |
+
elif scheduler_type == 'CosineAnnealingRestartLR':
|
120 |
+
for optimizer in self.optimizers:
|
121 |
+
self.schedulers.append(lr_scheduler.CosineAnnealingRestartLR(optimizer, **train_opt['scheduler']))
|
122 |
+
else:
|
123 |
+
raise NotImplementedError(f'Scheduler {scheduler_type} is not implemented yet.')
|
124 |
+
|
125 |
+
def get_bare_model(self, net):
|
126 |
+
"""Get bare model, especially under wrapping with
|
127 |
+
DistributedDataParallel or DataParallel.
|
128 |
+
"""
|
129 |
+
if isinstance(net, (DataParallel, DistributedDataParallel)):
|
130 |
+
net = net.module
|
131 |
+
return net
|
132 |
+
|
133 |
+
@master_only
|
134 |
+
def print_network(self, net):
|
135 |
+
"""Print the str and parameter number of a network.
|
136 |
+
|
137 |
+
Args:
|
138 |
+
net (nn.Module)
|
139 |
+
"""
|
140 |
+
if isinstance(net, (DataParallel, DistributedDataParallel)):
|
141 |
+
net_cls_str = f'{net.__class__.__name__} - {net.module.__class__.__name__}'
|
142 |
+
else:
|
143 |
+
net_cls_str = f'{net.__class__.__name__}'
|
144 |
+
|
145 |
+
net = self.get_bare_model(net)
|
146 |
+
net_str = str(net)
|
147 |
+
net_params = sum(map(lambda x: x.numel(), net.parameters()))
|
148 |
+
|
149 |
+
logger = get_root_logger()
|
150 |
+
logger.info(f'Network: {net_cls_str}, with parameters: {net_params:,d}')
|
151 |
+
logger.info(net_str)
|
152 |
+
|
153 |
+
def _set_lr(self, lr_groups_l):
|
154 |
+
"""Set learning rate for warmup.
|
155 |
+
|
156 |
+
Args:
|
157 |
+
lr_groups_l (list): List for lr_groups, each for an optimizer.
|
158 |
+
"""
|
159 |
+
for optimizer, lr_groups in zip(self.optimizers, lr_groups_l):
|
160 |
+
for param_group, lr in zip(optimizer.param_groups, lr_groups):
|
161 |
+
param_group['lr'] = lr
|
162 |
+
|
163 |
+
def _get_init_lr(self):
|
164 |
+
"""Get the initial lr, which is set by the scheduler.
|
165 |
+
"""
|
166 |
+
init_lr_groups_l = []
|
167 |
+
for optimizer in self.optimizers:
|
168 |
+
init_lr_groups_l.append([v['initial_lr'] for v in optimizer.param_groups])
|
169 |
+
return init_lr_groups_l
|
170 |
+
|
171 |
+
def update_learning_rate(self, current_iter, warmup_iter=-1):
|
172 |
+
"""Update learning rate.
|
173 |
+
|
174 |
+
Args:
|
175 |
+
current_iter (int): Current iteration.
|
176 |
+
warmup_iter (int): Warmup iter numbers. -1 for no warmup.
|
177 |
+
Default: -1.
|
178 |
+
"""
|
179 |
+
if current_iter > 1:
|
180 |
+
for scheduler in self.schedulers:
|
181 |
+
scheduler.step()
|
182 |
+
# set up warm-up learning rate
|
183 |
+
if current_iter < warmup_iter:
|
184 |
+
# get initial lr for each group
|
185 |
+
init_lr_g_l = self._get_init_lr()
|
186 |
+
# modify warming-up learning rates
|
187 |
+
# currently only support linearly warm up
|
188 |
+
warm_up_lr_l = []
|
189 |
+
for init_lr_g in init_lr_g_l:
|
190 |
+
warm_up_lr_l.append([v / warmup_iter * current_iter for v in init_lr_g])
|
191 |
+
# set learning rate
|
192 |
+
self._set_lr(warm_up_lr_l)
|
193 |
+
|
194 |
+
def get_current_learning_rate(self):
|
195 |
+
return [param_group['lr'] for param_group in self.optimizers[0].param_groups]
|
196 |
+
|
197 |
+
@master_only
|
198 |
+
def save_network(self, net, net_label, current_iter, param_key='params'):
|
199 |
+
"""Save networks.
|
200 |
+
|
201 |
+
Args:
|
202 |
+
net (nn.Module | list[nn.Module]): Network(s) to be saved.
|
203 |
+
net_label (str): Network label.
|
204 |
+
current_iter (int): Current iter number.
|
205 |
+
param_key (str | list[str]): The parameter key(s) to save network.
|
206 |
+
Default: 'params'.
|
207 |
+
"""
|
208 |
+
if current_iter == -1:
|
209 |
+
current_iter = 'latest'
|
210 |
+
save_filename = f'{net_label}_{current_iter}.pth'
|
211 |
+
save_path = os.path.join(self.opt['path']['models'], save_filename)
|
212 |
+
|
213 |
+
net = net if isinstance(net, list) else [net]
|
214 |
+
param_key = param_key if isinstance(param_key, list) else [param_key]
|
215 |
+
assert len(net) == len(param_key), 'The lengths of net and param_key should be the same.'
|
216 |
+
|
217 |
+
save_dict = {}
|
218 |
+
for net_, param_key_ in zip(net, param_key):
|
219 |
+
net_ = self.get_bare_model(net_)
|
220 |
+
state_dict = net_.state_dict()
|
221 |
+
for key, param in state_dict.items():
|
222 |
+
if key.startswith('module.'): # remove unnecessary 'module.'
|
223 |
+
key = key[7:]
|
224 |
+
state_dict[key] = param.cpu()
|
225 |
+
save_dict[param_key_] = state_dict
|
226 |
+
|
227 |
+
# avoid occasional writing errors
|
228 |
+
retry = 3
|
229 |
+
while retry > 0:
|
230 |
+
try:
|
231 |
+
torch.save(save_dict, save_path)
|
232 |
+
except Exception as e:
|
233 |
+
logger = get_root_logger()
|
234 |
+
logger.warning(f'Save model error: {e}, remaining retry times: {retry - 1}')
|
235 |
+
time.sleep(1)
|
236 |
+
else:
|
237 |
+
break
|
238 |
+
finally:
|
239 |
+
retry -= 1
|
240 |
+
if retry == 0:
|
241 |
+
logger.warning(f'Still cannot save {save_path}. Just ignore it.')
|
242 |
+
# raise IOError(f'Cannot save {save_path}.')
|
243 |
+
|
244 |
+
def _print_different_keys_loading(self, crt_net, load_net, strict=True):
|
245 |
+
"""Print keys with different name or different size when loading models.
|
246 |
+
|
247 |
+
1. Print keys with different names.
|
248 |
+
2. If strict=False, print the same key but with different tensor size.
|
249 |
+
It also ignore these keys with different sizes (not load).
|
250 |
+
|
251 |
+
Args:
|
252 |
+
crt_net (torch model): Current network.
|
253 |
+
load_net (dict): Loaded network.
|
254 |
+
strict (bool): Whether strictly loaded. Default: True.
|
255 |
+
"""
|
256 |
+
crt_net = self.get_bare_model(crt_net)
|
257 |
+
crt_net = crt_net.state_dict()
|
258 |
+
crt_net_keys = set(crt_net.keys())
|
259 |
+
load_net_keys = set(load_net.keys())
|
260 |
+
|
261 |
+
logger = get_root_logger()
|
262 |
+
if crt_net_keys != load_net_keys:
|
263 |
+
logger.warning('Current net - loaded net:')
|
264 |
+
for v in sorted(list(crt_net_keys - load_net_keys)):
|
265 |
+
logger.warning(f' {v}')
|
266 |
+
logger.warning('Loaded net - current net:')
|
267 |
+
for v in sorted(list(load_net_keys - crt_net_keys)):
|
268 |
+
logger.warning(f' {v}')
|
269 |
+
|
270 |
+
# check the size for the same keys
|
271 |
+
if not strict:
|
272 |
+
common_keys = crt_net_keys & load_net_keys
|
273 |
+
for k in common_keys:
|
274 |
+
if crt_net[k].size() != load_net[k].size():
|
275 |
+
logger.warning(f'Size different, ignore [{k}]: crt_net: '
|
276 |
+
f'{crt_net[k].shape}; load_net: {load_net[k].shape}')
|
277 |
+
load_net[k + '.ignore'] = load_net.pop(k)
|
278 |
+
|
279 |
+
def load_network(self, net, load_path, strict=True, param_key='params'):
|
280 |
+
"""Load network.
|
281 |
+
|
282 |
+
Args:
|
283 |
+
load_path (str): The path of networks to be loaded.
|
284 |
+
net (nn.Module): Network.
|
285 |
+
strict (bool): Whether strictly loaded.
|
286 |
+
param_key (str): The parameter key of loaded network. If set to
|
287 |
+
None, use the root 'path'.
|
288 |
+
Default: 'params'.
|
289 |
+
"""
|
290 |
+
logger = get_root_logger()
|
291 |
+
net = self.get_bare_model(net)
|
292 |
+
load_net = torch.load(load_path, map_location=lambda storage, loc: storage)
|
293 |
+
if param_key is not None:
|
294 |
+
if param_key not in load_net and 'params' in load_net:
|
295 |
+
param_key = 'params'
|
296 |
+
logger.info('Loading: params_ema does not exist, use params.')
|
297 |
+
load_net = load_net[param_key]
|
298 |
+
logger.info(f'Loading {net.__class__.__name__} model from {load_path}, with param key: [{param_key}].')
|
299 |
+
# remove unnecessary 'module.'
|
300 |
+
for k, v in deepcopy(load_net).items():
|
301 |
+
if k.startswith('module.'):
|
302 |
+
load_net[k[7:]] = v
|
303 |
+
load_net.pop(k)
|
304 |
+
self._print_different_keys_loading(net, load_net, strict)
|
305 |
+
net.load_state_dict(load_net, strict=strict)
|
306 |
+
|
307 |
+
@master_only
|
308 |
+
def save_training_state(self, epoch, current_iter):
|
309 |
+
"""Save training states during training, which will be used for
|
310 |
+
resuming.
|
311 |
+
|
312 |
+
Args:
|
313 |
+
epoch (int): Current epoch.
|
314 |
+
current_iter (int): Current iteration.
|
315 |
+
"""
|
316 |
+
if current_iter != -1:
|
317 |
+
state = {'epoch': epoch, 'iter': current_iter, 'optimizers': [], 'schedulers': []}
|
318 |
+
for o in self.optimizers:
|
319 |
+
state['optimizers'].append(o.state_dict())
|
320 |
+
for s in self.schedulers:
|
321 |
+
state['schedulers'].append(s.state_dict())
|
322 |
+
save_filename = f'{current_iter}.state'
|
323 |
+
save_path = os.path.join(self.opt['path']['training_states'], save_filename)
|
324 |
+
|
325 |
+
# avoid occasional writing errors
|
326 |
+
retry = 3
|
327 |
+
while retry > 0:
|
328 |
+
try:
|
329 |
+
torch.save(state, save_path)
|
330 |
+
except Exception as e:
|
331 |
+
logger = get_root_logger()
|
332 |
+
logger.warning(f'Save training state error: {e}, remaining retry times: {retry - 1}')
|
333 |
+
time.sleep(1)
|
334 |
+
else:
|
335 |
+
break
|
336 |
+
finally:
|
337 |
+
retry -= 1
|
338 |
+
if retry == 0:
|
339 |
+
logger.warning(f'Still cannot save {save_path}. Just ignore it.')
|
340 |
+
# raise IOError(f'Cannot save {save_path}.')
|
341 |
+
|
342 |
+
def resume_training(self, resume_state):
|
343 |
+
"""Reload the optimizers and schedulers for resumed training.
|
344 |
+
|
345 |
+
Args:
|
346 |
+
resume_state (dict): Resume state.
|
347 |
+
"""
|
348 |
+
resume_optimizers = resume_state['optimizers']
|
349 |
+
resume_schedulers = resume_state['schedulers']
|
350 |
+
assert len(resume_optimizers) == len(self.optimizers), 'Wrong lengths of optimizers'
|
351 |
+
assert len(resume_schedulers) == len(self.schedulers), 'Wrong lengths of schedulers'
|
352 |
+
for i, o in enumerate(resume_optimizers):
|
353 |
+
self.optimizers[i].load_state_dict(o)
|
354 |
+
for i, s in enumerate(resume_schedulers):
|
355 |
+
self.schedulers[i].load_state_dict(s)
|
356 |
+
|
357 |
+
def reduce_loss_dict(self, loss_dict):
|
358 |
+
"""reduce loss dict.
|
359 |
+
|
360 |
+
In distributed training, it averages the losses among different GPUs .
|
361 |
+
|
362 |
+
Args:
|
363 |
+
loss_dict (OrderedDict): Loss dict.
|
364 |
+
"""
|
365 |
+
with torch.no_grad():
|
366 |
+
if self.opt['dist']:
|
367 |
+
keys = []
|
368 |
+
losses = []
|
369 |
+
for name, value in loss_dict.items():
|
370 |
+
keys.append(name)
|
371 |
+
losses.append(value)
|
372 |
+
losses = torch.stack(losses, 0)
|
373 |
+
torch.distributed.reduce(losses, dst=0)
|
374 |
+
if self.opt['rank'] == 0:
|
375 |
+
losses /= self.opt['world_size']
|
376 |
+
loss_dict = {key: loss for key, loss in zip(keys, losses)}
|
377 |
+
|
378 |
+
log_dict = OrderedDict()
|
379 |
+
for name, value in loss_dict.items():
|
380 |
+
log_dict[name] = value.mean().item()
|
381 |
+
|
382 |
+
return log_dict
|
basicsr/models/color_model.py
ADDED
@@ -0,0 +1,369 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import torch
|
3 |
+
from collections import OrderedDict
|
4 |
+
from os import path as osp
|
5 |
+
from tqdm import tqdm
|
6 |
+
import numpy as np
|
7 |
+
|
8 |
+
from basicsr.archs import build_network
|
9 |
+
from basicsr.losses import build_loss
|
10 |
+
from basicsr.metrics import calculate_metric
|
11 |
+
from basicsr.utils import get_root_logger, imwrite, tensor2img
|
12 |
+
from basicsr.utils.img_util import tensor_lab2rgb
|
13 |
+
from basicsr.utils.dist_util import master_only
|
14 |
+
from basicsr.utils.registry import MODEL_REGISTRY
|
15 |
+
from .base_model import BaseModel
|
16 |
+
from basicsr.metrics.custom_fid import INCEPTION_V3_FID, get_activations, calculate_activation_statistics, calculate_frechet_distance
|
17 |
+
from basicsr.utils.color_enhance import color_enhacne_blend
|
18 |
+
|
19 |
+
|
20 |
+
@MODEL_REGISTRY.register()
|
21 |
+
class ColorModel(BaseModel):
|
22 |
+
"""Colorization model for single image colorization."""
|
23 |
+
|
24 |
+
def __init__(self, opt):
|
25 |
+
super(ColorModel, self).__init__(opt)
|
26 |
+
|
27 |
+
# define network net_g
|
28 |
+
self.net_g = build_network(opt['network_g'])
|
29 |
+
self.net_g = self.model_to_device(self.net_g)
|
30 |
+
self.print_network(self.net_g)
|
31 |
+
|
32 |
+
# load pretrained model for net_g
|
33 |
+
load_path = self.opt['path'].get('pretrain_network_g', None)
|
34 |
+
if load_path is not None:
|
35 |
+
param_key = self.opt['path'].get('param_key_g', 'params')
|
36 |
+
self.load_network(self.net_g, load_path, self.opt['path'].get('strict_load_g', True), param_key)
|
37 |
+
|
38 |
+
if self.is_train:
|
39 |
+
self.init_training_settings()
|
40 |
+
|
41 |
+
def init_training_settings(self):
|
42 |
+
train_opt = self.opt['train']
|
43 |
+
|
44 |
+
self.ema_decay = train_opt.get('ema_decay', 0)
|
45 |
+
if self.ema_decay > 0:
|
46 |
+
logger = get_root_logger()
|
47 |
+
logger.info(f'Use Exponential Moving Average with decay: {self.ema_decay}')
|
48 |
+
# define network net_g with Exponential Moving Average (EMA)
|
49 |
+
# net_g_ema is used only for testing on one GPU and saving
|
50 |
+
# There is no need to wrap with DistributedDataParallel
|
51 |
+
self.net_g_ema = build_network(self.opt['network_g']).to(self.device)
|
52 |
+
# load pretrained model
|
53 |
+
load_path = self.opt['path'].get('pretrain_network_g', None)
|
54 |
+
if load_path is not None:
|
55 |
+
self.load_network(self.net_g_ema, load_path, self.opt['path'].get('strict_load_g', True), 'params_ema')
|
56 |
+
else:
|
57 |
+
self.model_ema(0) # copy net_g weight
|
58 |
+
self.net_g_ema.eval()
|
59 |
+
|
60 |
+
# define network net_d
|
61 |
+
self.net_d = build_network(self.opt['network_d'])
|
62 |
+
self.net_d = self.model_to_device(self.net_d)
|
63 |
+
self.print_network(self.net_d)
|
64 |
+
|
65 |
+
# load pretrained model for net_d
|
66 |
+
load_path = self.opt['path'].get('pretrain_network_d', None)
|
67 |
+
if load_path is not None:
|
68 |
+
param_key = self.opt['path'].get('param_key_d', 'params')
|
69 |
+
self.load_network(self.net_d, load_path, self.opt['path'].get('strict_load_d', True), param_key)
|
70 |
+
|
71 |
+
self.net_g.train()
|
72 |
+
self.net_d.train()
|
73 |
+
|
74 |
+
# define losses
|
75 |
+
if train_opt.get('pixel_opt'):
|
76 |
+
self.cri_pix = build_loss(train_opt['pixel_opt']).to(self.device)
|
77 |
+
else:
|
78 |
+
self.cri_pix = None
|
79 |
+
|
80 |
+
if train_opt.get('perceptual_opt'):
|
81 |
+
self.cri_perceptual = build_loss(train_opt['perceptual_opt']).to(self.device)
|
82 |
+
else:
|
83 |
+
self.cri_perceptual = None
|
84 |
+
|
85 |
+
if train_opt.get('gan_opt'):
|
86 |
+
self.cri_gan = build_loss(train_opt['gan_opt']).to(self.device)
|
87 |
+
else:
|
88 |
+
self.cri_gan = None
|
89 |
+
|
90 |
+
if self.cri_pix is None and self.cri_perceptual is None:
|
91 |
+
raise ValueError('Both pixel and perceptual losses are None.')
|
92 |
+
|
93 |
+
if train_opt.get('colorfulness_opt'):
|
94 |
+
self.cri_colorfulness = build_loss(train_opt['colorfulness_opt']).to(self.device)
|
95 |
+
else:
|
96 |
+
self.cri_colorfulness = None
|
97 |
+
|
98 |
+
# set up optimizers and schedulers
|
99 |
+
self.setup_optimizers()
|
100 |
+
self.setup_schedulers()
|
101 |
+
|
102 |
+
# set real dataset cache for fid metric computing
|
103 |
+
self.real_mu, self.real_sigma = None, None
|
104 |
+
if self.opt['val'].get('metrics') is not None and self.opt['val']['metrics'].get('fid') is not None:
|
105 |
+
self._prepare_inception_model_fid()
|
106 |
+
|
107 |
+
def setup_optimizers(self):
|
108 |
+
train_opt = self.opt['train']
|
109 |
+
# optim_params_g = []
|
110 |
+
# for k, v in self.net_g.named_parameters():
|
111 |
+
# if v.requires_grad:
|
112 |
+
# optim_params_g.append(v)
|
113 |
+
# else:
|
114 |
+
# logger = get_root_logger()
|
115 |
+
# logger.warning(f'Params {k} will not be optimized.')
|
116 |
+
optim_params_g = self.net_g.parameters()
|
117 |
+
|
118 |
+
# optimizer g
|
119 |
+
optim_type = train_opt['optim_g'].pop('type')
|
120 |
+
self.optimizer_g = self.get_optimizer(optim_type, optim_params_g, **train_opt['optim_g'])
|
121 |
+
self.optimizers.append(self.optimizer_g)
|
122 |
+
|
123 |
+
# optimizer d
|
124 |
+
optim_type = train_opt['optim_d'].pop('type')
|
125 |
+
self.optimizer_d = self.get_optimizer(optim_type, self.net_d.parameters(), **train_opt['optim_d'])
|
126 |
+
self.optimizers.append(self.optimizer_d)
|
127 |
+
|
128 |
+
def feed_data(self, data):
|
129 |
+
self.lq = data['lq'].to(self.device)
|
130 |
+
self.lq_rgb = tensor_lab2rgb(torch.cat([self.lq, torch.zeros_like(self.lq), torch.zeros_like(self.lq)], dim=1))
|
131 |
+
if 'gt' in data:
|
132 |
+
self.gt = data['gt'].to(self.device)
|
133 |
+
self.gt_lab = torch.cat([self.lq, self.gt], dim=1)
|
134 |
+
self.gt_rgb = tensor_lab2rgb(self.gt_lab)
|
135 |
+
|
136 |
+
if self.opt['train'].get('color_enhance', False):
|
137 |
+
for i in range(self.gt_rgb.shape[0]):
|
138 |
+
self.gt_rgb[i] = color_enhacne_blend(self.gt_rgb[i], factor=self.opt['train'].get('color_enhance_factor'))
|
139 |
+
|
140 |
+
def optimize_parameters(self, current_iter):
|
141 |
+
# optimize net_g
|
142 |
+
for p in self.net_d.parameters():
|
143 |
+
p.requires_grad = False
|
144 |
+
self.optimizer_g.zero_grad()
|
145 |
+
|
146 |
+
self.output_ab = self.net_g(self.lq_rgb)
|
147 |
+
self.output_lab = torch.cat([self.lq, self.output_ab], dim=1)
|
148 |
+
self.output_rgb = tensor_lab2rgb(self.output_lab)
|
149 |
+
|
150 |
+
l_g_total = 0
|
151 |
+
loss_dict = OrderedDict()
|
152 |
+
# pixel loss
|
153 |
+
if self.cri_pix:
|
154 |
+
l_g_pix = self.cri_pix(self.output_ab, self.gt)
|
155 |
+
l_g_total += l_g_pix
|
156 |
+
loss_dict['l_g_pix'] = l_g_pix
|
157 |
+
|
158 |
+
# perceptual loss
|
159 |
+
if self.cri_perceptual:
|
160 |
+
l_g_percep, l_g_style = self.cri_perceptual(self.output_rgb, self.gt_rgb)
|
161 |
+
if l_g_percep is not None:
|
162 |
+
l_g_total += l_g_percep
|
163 |
+
loss_dict['l_g_percep'] = l_g_percep
|
164 |
+
if l_g_style is not None:
|
165 |
+
l_g_total += l_g_style
|
166 |
+
loss_dict['l_g_style'] = l_g_style
|
167 |
+
# gan loss
|
168 |
+
if self.cri_gan:
|
169 |
+
fake_g_pred = self.net_d(self.output_rgb)
|
170 |
+
l_g_gan = self.cri_gan(fake_g_pred, target_is_real=True, is_disc=False)
|
171 |
+
l_g_total += l_g_gan
|
172 |
+
loss_dict['l_g_gan'] = l_g_gan
|
173 |
+
# colorfulness loss
|
174 |
+
if self.cri_colorfulness:
|
175 |
+
l_g_color = self.cri_colorfulness(self.output_rgb)
|
176 |
+
l_g_total += l_g_color
|
177 |
+
loss_dict['l_g_color'] = l_g_color
|
178 |
+
|
179 |
+
l_g_total.backward()
|
180 |
+
self.optimizer_g.step()
|
181 |
+
|
182 |
+
# optimize net_d
|
183 |
+
for p in self.net_d.parameters():
|
184 |
+
p.requires_grad = True
|
185 |
+
self.optimizer_d.zero_grad()
|
186 |
+
|
187 |
+
real_d_pred = self.net_d(self.gt_rgb)
|
188 |
+
fake_d_pred = self.net_d(self.output_rgb.detach())
|
189 |
+
l_d = self.cri_gan(real_d_pred, target_is_real=True, is_disc=True) + self.cri_gan(fake_d_pred, target_is_real=False, is_disc=True)
|
190 |
+
loss_dict['l_d'] = l_d
|
191 |
+
loss_dict['real_score'] = real_d_pred.detach().mean()
|
192 |
+
loss_dict['fake_score'] = fake_d_pred.detach().mean()
|
193 |
+
|
194 |
+
l_d.backward()
|
195 |
+
self.optimizer_d.step()
|
196 |
+
|
197 |
+
self.log_dict = self.reduce_loss_dict(loss_dict)
|
198 |
+
|
199 |
+
if self.ema_decay > 0:
|
200 |
+
self.model_ema(decay=self.ema_decay)
|
201 |
+
|
202 |
+
def get_current_visuals(self):
|
203 |
+
out_dict = OrderedDict()
|
204 |
+
out_dict['lq'] = self.lq_rgb.detach().cpu()
|
205 |
+
out_dict['result'] = self.output_rgb.detach().cpu()
|
206 |
+
if self.opt['logger'].get('save_snapshot_verbose', False): # only for verbose
|
207 |
+
self.output_lab_chroma = torch.cat([torch.ones_like(self.lq) * 50, self.output_ab], dim=1)
|
208 |
+
self.output_rgb_chroma = tensor_lab2rgb(self.output_lab_chroma)
|
209 |
+
out_dict['result_chroma'] = self.output_rgb_chroma.detach().cpu()
|
210 |
+
|
211 |
+
if hasattr(self, 'gt'):
|
212 |
+
out_dict['gt'] = self.gt_rgb.detach().cpu()
|
213 |
+
if self.opt['logger'].get('save_snapshot_verbose', False): # only for verbose
|
214 |
+
self.gt_lab_chroma = torch.cat([torch.ones_like(self.lq) * 50, self.gt], dim=1)
|
215 |
+
self.gt_rgb_chroma = tensor_lab2rgb(self.gt_lab_chroma)
|
216 |
+
out_dict['gt_chroma'] = self.gt_rgb_chroma.detach().cpu()
|
217 |
+
return out_dict
|
218 |
+
|
219 |
+
def test(self):
|
220 |
+
if hasattr(self, 'net_g_ema'):
|
221 |
+
self.net_g_ema.eval()
|
222 |
+
with torch.no_grad():
|
223 |
+
self.output_ab = self.net_g_ema(self.lq_rgb)
|
224 |
+
self.output_lab = torch.cat([self.lq, self.output_ab], dim=1)
|
225 |
+
self.output_rgb = tensor_lab2rgb(self.output_lab)
|
226 |
+
else:
|
227 |
+
self.net_g.eval()
|
228 |
+
with torch.no_grad():
|
229 |
+
self.output_ab = self.net_g(self.lq_rgb)
|
230 |
+
self.output_lab = torch.cat([self.lq, self.output_ab], dim=1)
|
231 |
+
self.output_rgb = tensor_lab2rgb(self.output_lab)
|
232 |
+
self.net_g.train()
|
233 |
+
|
234 |
+
def dist_validation(self, dataloader, current_iter, tb_logger, save_img):
|
235 |
+
if self.opt['rank'] == 0:
|
236 |
+
self.nondist_validation(dataloader, current_iter, tb_logger, save_img)
|
237 |
+
|
238 |
+
def nondist_validation(self, dataloader, current_iter, tb_logger, save_img):
|
239 |
+
dataset_name = dataloader.dataset.opt['name']
|
240 |
+
with_metrics = self.opt['val'].get('metrics') is not None
|
241 |
+
use_pbar = self.opt['val'].get('pbar', False)
|
242 |
+
|
243 |
+
if with_metrics and not hasattr(self, 'metric_results'): # only execute in the first run
|
244 |
+
self.metric_results = {metric: 0 for metric in self.opt['val']['metrics'].keys()}
|
245 |
+
# initialize the best metric results for each dataset_name (supporting multiple validation datasets)
|
246 |
+
if with_metrics:
|
247 |
+
self._initialize_best_metric_results(dataset_name)
|
248 |
+
# zero self.metric_results
|
249 |
+
if with_metrics:
|
250 |
+
self.metric_results = {metric: 0 for metric in self.metric_results}
|
251 |
+
|
252 |
+
metric_data = dict()
|
253 |
+
if use_pbar:
|
254 |
+
pbar = tqdm(total=len(dataloader), unit='image')
|
255 |
+
|
256 |
+
if self.opt['val']['metrics'].get('fid') is not None:
|
257 |
+
fake_acts_set, acts_set = [], []
|
258 |
+
|
259 |
+
for idx, val_data in enumerate(dataloader):
|
260 |
+
# if idx == 100:
|
261 |
+
# break
|
262 |
+
img_name = osp.splitext(osp.basename(val_data['lq_path'][0]))[0]
|
263 |
+
if hasattr(self, 'gt'):
|
264 |
+
del self.gt
|
265 |
+
self.feed_data(val_data)
|
266 |
+
self.test()
|
267 |
+
|
268 |
+
visuals = self.get_current_visuals()
|
269 |
+
sr_img = tensor2img([visuals['result']])
|
270 |
+
metric_data['img'] = sr_img
|
271 |
+
if 'gt' in visuals:
|
272 |
+
gt_img = tensor2img([visuals['gt']])
|
273 |
+
metric_data['img2'] = gt_img
|
274 |
+
|
275 |
+
torch.cuda.empty_cache()
|
276 |
+
|
277 |
+
if save_img:
|
278 |
+
if self.opt['is_train']:
|
279 |
+
save_dir = osp.join(self.opt['path']['visualization'], img_name)
|
280 |
+
for key in visuals:
|
281 |
+
save_path = os.path.join(save_dir, '{}_{}.png'.format(current_iter, key))
|
282 |
+
img = tensor2img(visuals[key])
|
283 |
+
imwrite(img, save_path)
|
284 |
+
else:
|
285 |
+
if self.opt['val']['suffix']:
|
286 |
+
save_img_path = osp.join(self.opt['path']['visualization'], dataset_name,
|
287 |
+
f'{img_name}_{self.opt["val"]["suffix"]}.png')
|
288 |
+
else:
|
289 |
+
save_img_path = osp.join(self.opt['path']['visualization'], dataset_name,
|
290 |
+
f'{img_name}_{self.opt["name"]}.png')
|
291 |
+
imwrite(sr_img, save_img_path)
|
292 |
+
|
293 |
+
if with_metrics:
|
294 |
+
# calculate metrics
|
295 |
+
for name, opt_ in self.opt['val']['metrics'].items():
|
296 |
+
if name == 'fid':
|
297 |
+
pred, gt = visuals['result'].cuda(), visuals['gt'].cuda()
|
298 |
+
fake_act = get_activations(pred, self.inception_model_fid, 1)
|
299 |
+
fake_acts_set.append(fake_act)
|
300 |
+
if self.real_mu is None:
|
301 |
+
real_act = get_activations(gt, self.inception_model_fid, 1)
|
302 |
+
acts_set.append(real_act)
|
303 |
+
else:
|
304 |
+
self.metric_results[name] += calculate_metric(metric_data, opt_)
|
305 |
+
if use_pbar:
|
306 |
+
pbar.update(1)
|
307 |
+
pbar.set_description(f'Test {img_name}')
|
308 |
+
if use_pbar:
|
309 |
+
pbar.close()
|
310 |
+
|
311 |
+
if with_metrics:
|
312 |
+
if self.opt['val']['metrics'].get('fid') is not None:
|
313 |
+
if self.real_mu is None:
|
314 |
+
acts_set = np.concatenate(acts_set, 0)
|
315 |
+
self.real_mu, self.real_sigma = calculate_activation_statistics(acts_set)
|
316 |
+
fake_acts_set = np.concatenate(fake_acts_set, 0)
|
317 |
+
fake_mu, fake_sigma = calculate_activation_statistics(fake_acts_set)
|
318 |
+
|
319 |
+
fid_score = calculate_frechet_distance(self.real_mu, self.real_sigma, fake_mu, fake_sigma)
|
320 |
+
self.metric_results['fid'] = fid_score
|
321 |
+
|
322 |
+
for metric in self.metric_results.keys():
|
323 |
+
if metric != 'fid':
|
324 |
+
self.metric_results[metric] /= (idx + 1)
|
325 |
+
# update the best metric result
|
326 |
+
self._update_best_metric_result(dataset_name, metric, self.metric_results[metric], current_iter)
|
327 |
+
|
328 |
+
self._log_validation_metric_values(current_iter, dataset_name, tb_logger)
|
329 |
+
|
330 |
+
def _log_validation_metric_values(self, current_iter, dataset_name, tb_logger):
|
331 |
+
log_str = f'Validation {dataset_name}\n'
|
332 |
+
for metric, value in self.metric_results.items():
|
333 |
+
log_str += f'\t # {metric}: {value:.4f}'
|
334 |
+
if hasattr(self, 'best_metric_results'):
|
335 |
+
log_str += (f'\tBest: {self.best_metric_results[dataset_name][metric]["val"]:.4f} @ '
|
336 |
+
f'{self.best_metric_results[dataset_name][metric]["iter"]} iter')
|
337 |
+
log_str += '\n'
|
338 |
+
|
339 |
+
logger = get_root_logger()
|
340 |
+
logger.info(log_str)
|
341 |
+
if tb_logger:
|
342 |
+
for metric, value in self.metric_results.items():
|
343 |
+
tb_logger.add_scalar(f'metrics/{dataset_name}/{metric}', value, current_iter)
|
344 |
+
|
345 |
+
def _prepare_inception_model_fid(self, path='pretrain/inception_v3_google-1a9a5a14.pth'):
|
346 |
+
incep_state_dict = torch.load(path, map_location='cpu')
|
347 |
+
block_idx = INCEPTION_V3_FID.BLOCK_INDEX_BY_DIM[2048]
|
348 |
+
self.inception_model_fid = INCEPTION_V3_FID(incep_state_dict, [block_idx])
|
349 |
+
self.inception_model_fid.cuda()
|
350 |
+
self.inception_model_fid.eval()
|
351 |
+
|
352 |
+
@master_only
|
353 |
+
def save_training_images(self, current_iter):
|
354 |
+
visuals = self.get_current_visuals()
|
355 |
+
save_dir = osp.join(self.opt['root_path'], 'experiments', self.opt['name'], 'training_images_snapshot')
|
356 |
+
os.makedirs(save_dir, exist_ok=True)
|
357 |
+
|
358 |
+
for key in visuals:
|
359 |
+
save_path = os.path.join(save_dir, '{}_{}.png'.format(current_iter, key))
|
360 |
+
img = tensor2img(visuals[key])
|
361 |
+
imwrite(img, save_path)
|
362 |
+
|
363 |
+
def save(self, epoch, current_iter):
|
364 |
+
if hasattr(self, 'net_g_ema'):
|
365 |
+
self.save_network([self.net_g, self.net_g_ema], 'net_g', current_iter, param_key=['params', 'params_ema'])
|
366 |
+
else:
|
367 |
+
self.save_network(self.net_g, 'net_g', current_iter)
|
368 |
+
self.save_network(self.net_d, 'net_d', current_iter)
|
369 |
+
self.save_training_state(epoch, current_iter)
|
basicsr/models/lr_scheduler.py
ADDED
@@ -0,0 +1,96 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
from collections import Counter
|
3 |
+
from torch.optim.lr_scheduler import _LRScheduler
|
4 |
+
|
5 |
+
|
6 |
+
class MultiStepRestartLR(_LRScheduler):
|
7 |
+
""" MultiStep with restarts learning rate scheme.
|
8 |
+
|
9 |
+
Args:
|
10 |
+
optimizer (torch.nn.optimizer): Torch optimizer.
|
11 |
+
milestones (list): Iterations that will decrease learning rate.
|
12 |
+
gamma (float): Decrease ratio. Default: 0.1.
|
13 |
+
restarts (list): Restart iterations. Default: [0].
|
14 |
+
restart_weights (list): Restart weights at each restart iteration.
|
15 |
+
Default: [1].
|
16 |
+
last_epoch (int): Used in _LRScheduler. Default: -1.
|
17 |
+
"""
|
18 |
+
|
19 |
+
def __init__(self, optimizer, milestones, gamma=0.1, restarts=(0, ), restart_weights=(1, ), last_epoch=-1):
|
20 |
+
self.milestones = Counter(milestones)
|
21 |
+
self.gamma = gamma
|
22 |
+
self.restarts = restarts
|
23 |
+
self.restart_weights = restart_weights
|
24 |
+
assert len(self.restarts) == len(self.restart_weights), 'restarts and their weights do not match.'
|
25 |
+
super(MultiStepRestartLR, self).__init__(optimizer, last_epoch)
|
26 |
+
|
27 |
+
def get_lr(self):
|
28 |
+
if self.last_epoch in self.restarts:
|
29 |
+
weight = self.restart_weights[self.restarts.index(self.last_epoch)]
|
30 |
+
return [group['initial_lr'] * weight for group in self.optimizer.param_groups]
|
31 |
+
if self.last_epoch not in self.milestones:
|
32 |
+
return [group['lr'] for group in self.optimizer.param_groups]
|
33 |
+
return [group['lr'] * self.gamma**self.milestones[self.last_epoch] for group in self.optimizer.param_groups]
|
34 |
+
|
35 |
+
|
36 |
+
def get_position_from_periods(iteration, cumulative_period):
|
37 |
+
"""Get the position from a period list.
|
38 |
+
|
39 |
+
It will return the index of the right-closest number in the period list.
|
40 |
+
For example, the cumulative_period = [100, 200, 300, 400],
|
41 |
+
if iteration == 50, return 0;
|
42 |
+
if iteration == 210, return 2;
|
43 |
+
if iteration == 300, return 2.
|
44 |
+
|
45 |
+
Args:
|
46 |
+
iteration (int): Current iteration.
|
47 |
+
cumulative_period (list[int]): Cumulative period list.
|
48 |
+
|
49 |
+
Returns:
|
50 |
+
int: The position of the right-closest number in the period list.
|
51 |
+
"""
|
52 |
+
for i, period in enumerate(cumulative_period):
|
53 |
+
if iteration <= period:
|
54 |
+
return i
|
55 |
+
|
56 |
+
|
57 |
+
class CosineAnnealingRestartLR(_LRScheduler):
|
58 |
+
""" Cosine annealing with restarts learning rate scheme.
|
59 |
+
|
60 |
+
An example of config:
|
61 |
+
periods = [10, 10, 10, 10]
|
62 |
+
restart_weights = [1, 0.5, 0.5, 0.5]
|
63 |
+
eta_min=1e-7
|
64 |
+
|
65 |
+
It has four cycles, each has 10 iterations. At 10th, 20th, 30th, the
|
66 |
+
scheduler will restart with the weights in restart_weights.
|
67 |
+
|
68 |
+
Args:
|
69 |
+
optimizer (torch.nn.optimizer): Torch optimizer.
|
70 |
+
periods (list): Period for each cosine anneling cycle.
|
71 |
+
restart_weights (list): Restart weights at each restart iteration.
|
72 |
+
Default: [1].
|
73 |
+
eta_min (float): The minimum lr. Default: 0.
|
74 |
+
last_epoch (int): Used in _LRScheduler. Default: -1.
|
75 |
+
"""
|
76 |
+
|
77 |
+
def __init__(self, optimizer, periods, restart_weights=(1, ), eta_min=0, last_epoch=-1):
|
78 |
+
self.periods = periods
|
79 |
+
self.restart_weights = restart_weights
|
80 |
+
self.eta_min = eta_min
|
81 |
+
assert (len(self.periods) == len(
|
82 |
+
self.restart_weights)), 'periods and restart_weights should have the same length.'
|
83 |
+
self.cumulative_period = [sum(self.periods[0:i + 1]) for i in range(0, len(self.periods))]
|
84 |
+
super(CosineAnnealingRestartLR, self).__init__(optimizer, last_epoch)
|
85 |
+
|
86 |
+
def get_lr(self):
|
87 |
+
idx = get_position_from_periods(self.last_epoch, self.cumulative_period)
|
88 |
+
current_weight = self.restart_weights[idx]
|
89 |
+
nearest_restart = 0 if idx == 0 else self.cumulative_period[idx - 1]
|
90 |
+
current_period = self.periods[idx]
|
91 |
+
|
92 |
+
return [
|
93 |
+
self.eta_min + current_weight * 0.5 * (base_lr - self.eta_min) *
|
94 |
+
(1 + math.cos(math.pi * ((self.last_epoch - nearest_restart) / current_period)))
|
95 |
+
for base_lr in self.base_lrs
|
96 |
+
]
|
basicsr/train.py
ADDED
@@ -0,0 +1,224 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import datetime
|
2 |
+
import logging
|
3 |
+
import math
|
4 |
+
import time
|
5 |
+
import torch
|
6 |
+
import warnings
|
7 |
+
|
8 |
+
warnings.filterwarnings("ignore")
|
9 |
+
|
10 |
+
from os import path as osp
|
11 |
+
|
12 |
+
from basicsr.data import build_dataloader, build_dataset
|
13 |
+
from basicsr.data.data_sampler import EnlargedSampler
|
14 |
+
from basicsr.data.prefetch_dataloader import CPUPrefetcher, CUDAPrefetcher
|
15 |
+
from basicsr.models import build_model
|
16 |
+
from basicsr.utils import (AvgTimer, MessageLogger, check_resume, get_env_info, get_root_logger, get_time_str,
|
17 |
+
init_tb_logger, init_wandb_logger, make_exp_dirs, mkdir_and_rename, scandir)
|
18 |
+
from basicsr.utils.options import copy_opt_file, dict2str, parse_options
|
19 |
+
|
20 |
+
|
21 |
+
def init_tb_loggers(opt):
|
22 |
+
# initialize wandb logger before tensorboard logger to allow proper sync
|
23 |
+
if (opt['logger'].get('wandb') is not None) and (opt['logger']['wandb'].get('project')
|
24 |
+
is not None) and ('debug' not in opt['name']):
|
25 |
+
assert opt['logger'].get('use_tb_logger') is True, ('should turn on tensorboard when using wandb')
|
26 |
+
init_wandb_logger(opt)
|
27 |
+
tb_logger = None
|
28 |
+
if opt['logger'].get('use_tb_logger') and 'debug' not in opt['name']:
|
29 |
+
tb_logger = init_tb_logger(log_dir=osp.join(opt['root_path'], 'tb_logger', opt['name']))
|
30 |
+
return tb_logger
|
31 |
+
|
32 |
+
|
33 |
+
def create_train_val_dataloader(opt, logger):
|
34 |
+
# create train and val dataloaders
|
35 |
+
train_loader, val_loaders = None, []
|
36 |
+
for phase, dataset_opt in opt['datasets'].items():
|
37 |
+
if phase == 'train':
|
38 |
+
dataset_enlarge_ratio = dataset_opt.get('dataset_enlarge_ratio', 1)
|
39 |
+
train_set = build_dataset(dataset_opt)
|
40 |
+
train_sampler = EnlargedSampler(train_set, opt['world_size'], opt['rank'], dataset_enlarge_ratio)
|
41 |
+
train_loader = build_dataloader(
|
42 |
+
train_set,
|
43 |
+
dataset_opt,
|
44 |
+
num_gpu=opt['num_gpu'],
|
45 |
+
dist=opt['dist'],
|
46 |
+
sampler=train_sampler,
|
47 |
+
seed=opt['manual_seed'])
|
48 |
+
|
49 |
+
num_iter_per_epoch = math.ceil(
|
50 |
+
len(train_set) * dataset_enlarge_ratio / (dataset_opt['batch_size_per_gpu'] * opt['world_size']))
|
51 |
+
total_iters = int(opt['train']['total_iter'])
|
52 |
+
total_epochs = math.ceil(total_iters / (num_iter_per_epoch))
|
53 |
+
logger.info('Training statistics:'
|
54 |
+
f'\n\tNumber of train images: {len(train_set)}'
|
55 |
+
f'\n\tDataset enlarge ratio: {dataset_enlarge_ratio}'
|
56 |
+
f'\n\tBatch size per gpu: {dataset_opt["batch_size_per_gpu"]}'
|
57 |
+
f'\n\tWorld size (gpu number): {opt["world_size"]}'
|
58 |
+
f'\n\tRequire iter number per epoch: {num_iter_per_epoch}'
|
59 |
+
f'\n\tTotal epochs: {total_epochs}; iters: {total_iters}.')
|
60 |
+
elif phase.split('_')[0] == 'val':
|
61 |
+
val_set = build_dataset(dataset_opt)
|
62 |
+
val_loader = build_dataloader(
|
63 |
+
val_set, dataset_opt, num_gpu=opt['num_gpu'], dist=opt['dist'], sampler=None, seed=opt['manual_seed'])
|
64 |
+
logger.info(f'Number of val images/folders in {dataset_opt["name"]}: {len(val_set)}')
|
65 |
+
val_loaders.append(val_loader)
|
66 |
+
else:
|
67 |
+
raise ValueError(f'Dataset phase {phase} is not recognized.')
|
68 |
+
|
69 |
+
return train_loader, train_sampler, val_loaders, total_epochs, total_iters
|
70 |
+
|
71 |
+
|
72 |
+
def load_resume_state(opt):
|
73 |
+
resume_state_path = None
|
74 |
+
if opt['auto_resume']:
|
75 |
+
state_path = osp.join(opt['root_path'], 'experiments', opt['name'], 'training_states')
|
76 |
+
if osp.isdir(state_path):
|
77 |
+
states = list(scandir(state_path, suffix='state', recursive=False, full_path=False))
|
78 |
+
if len(states) != 0:
|
79 |
+
states = [float(v.split('.state')[0]) for v in states]
|
80 |
+
resume_state_path = osp.join(state_path, f'{max(states):.0f}.state')
|
81 |
+
opt['path']['resume_state'] = resume_state_path
|
82 |
+
else:
|
83 |
+
if opt['path'].get('resume_state'):
|
84 |
+
resume_state_path = opt['path']['resume_state']
|
85 |
+
|
86 |
+
if resume_state_path is None:
|
87 |
+
resume_state = None
|
88 |
+
else:
|
89 |
+
device_id = torch.cuda.current_device()
|
90 |
+
resume_state = torch.load(resume_state_path, map_location=lambda storage, loc: storage.cuda(device_id))
|
91 |
+
check_resume(opt, resume_state['iter'])
|
92 |
+
return resume_state
|
93 |
+
|
94 |
+
|
95 |
+
def train_pipeline(root_path):
|
96 |
+
# parse options, set distributed setting, set ramdom seed
|
97 |
+
opt, args = parse_options(root_path, is_train=True)
|
98 |
+
opt['root_path'] = root_path
|
99 |
+
|
100 |
+
torch.backends.cudnn.benchmark = True
|
101 |
+
# torch.backends.cudnn.deterministic = True
|
102 |
+
|
103 |
+
# load resume states if necessary
|
104 |
+
resume_state = load_resume_state(opt)
|
105 |
+
# mkdir for experiments and logger
|
106 |
+
if resume_state is None:
|
107 |
+
make_exp_dirs(opt)
|
108 |
+
if opt['logger'].get('use_tb_logger') and 'debug' not in opt['name'] and opt['rank'] == 0:
|
109 |
+
mkdir_and_rename(osp.join(opt['root_path'], 'tb_logger', opt['name']))
|
110 |
+
|
111 |
+
# copy the yml file to the experiment root
|
112 |
+
copy_opt_file(args.opt, opt['path']['experiments_root'])
|
113 |
+
|
114 |
+
# WARNING: should not use get_root_logger in the above codes, including the called functions
|
115 |
+
# Otherwise the logger will not be properly initialized
|
116 |
+
log_file = osp.join(opt['path']['log'], f"train_{opt['name']}_{get_time_str()}.log")
|
117 |
+
logger = get_root_logger(logger_name='basicsr', log_level=logging.INFO, log_file=log_file)
|
118 |
+
logger.info(get_env_info())
|
119 |
+
logger.info(dict2str(opt))
|
120 |
+
# initialize wandb and tb loggers
|
121 |
+
tb_logger = init_tb_loggers(opt)
|
122 |
+
|
123 |
+
# create train and validation dataloaders
|
124 |
+
result = create_train_val_dataloader(opt, logger)
|
125 |
+
train_loader, train_sampler, val_loaders, total_epochs, total_iters = result
|
126 |
+
|
127 |
+
# create model
|
128 |
+
model = build_model(opt)
|
129 |
+
if resume_state: # resume training
|
130 |
+
model.resume_training(resume_state) # handle optimizers and schedulers
|
131 |
+
logger.info(f"Resuming training from epoch: {resume_state['epoch']}, " f"iter: {resume_state['iter']}.")
|
132 |
+
start_epoch = resume_state['epoch']
|
133 |
+
current_iter = resume_state['iter']
|
134 |
+
else:
|
135 |
+
start_epoch = 0
|
136 |
+
current_iter = 0
|
137 |
+
|
138 |
+
# create message logger (formatted outputs)
|
139 |
+
msg_logger = MessageLogger(opt, current_iter, tb_logger)
|
140 |
+
|
141 |
+
# dataloader prefetcher
|
142 |
+
prefetch_mode = opt['datasets']['train'].get('prefetch_mode')
|
143 |
+
if prefetch_mode is None or prefetch_mode == 'cpu':
|
144 |
+
prefetcher = CPUPrefetcher(train_loader)
|
145 |
+
elif prefetch_mode == 'cuda':
|
146 |
+
prefetcher = CUDAPrefetcher(train_loader, opt)
|
147 |
+
logger.info(f'Use {prefetch_mode} prefetch dataloader')
|
148 |
+
if opt['datasets']['train'].get('pin_memory') is not True:
|
149 |
+
raise ValueError('Please set pin_memory=True for CUDAPrefetcher.')
|
150 |
+
else:
|
151 |
+
raise ValueError(f'Wrong prefetch_mode {prefetch_mode}.' "Supported ones are: None, 'cuda', 'cpu'.")
|
152 |
+
|
153 |
+
# training
|
154 |
+
logger.info(f'Start training from epoch: {start_epoch}, iter: {current_iter}')
|
155 |
+
data_timer, iter_timer = AvgTimer(), AvgTimer()
|
156 |
+
start_time = time.time()
|
157 |
+
|
158 |
+
for epoch in range(start_epoch, total_epochs + 1):
|
159 |
+
train_sampler.set_epoch(epoch)
|
160 |
+
prefetcher.reset()
|
161 |
+
train_data = prefetcher.next()
|
162 |
+
|
163 |
+
while train_data is not None:
|
164 |
+
data_timer.record()
|
165 |
+
|
166 |
+
current_iter += 1
|
167 |
+
if current_iter > total_iters:
|
168 |
+
break
|
169 |
+
# update learning rate
|
170 |
+
model.update_learning_rate(current_iter, warmup_iter=opt['train'].get('warmup_iter', -1))
|
171 |
+
# training
|
172 |
+
model.feed_data(train_data)
|
173 |
+
model.optimize_parameters(current_iter)
|
174 |
+
iter_timer.record()
|
175 |
+
if current_iter == 1:
|
176 |
+
# reset start time in msg_logger for more accurate eta_time
|
177 |
+
# not work in resume mode
|
178 |
+
msg_logger.reset_start_time()
|
179 |
+
# log
|
180 |
+
if current_iter % opt['logger']['print_freq'] == 0:
|
181 |
+
log_vars = {'epoch': epoch, 'iter': current_iter}
|
182 |
+
log_vars.update({'lrs': model.get_current_learning_rate()})
|
183 |
+
log_vars.update({'time': iter_timer.get_avg_time(), 'data_time': data_timer.get_avg_time()})
|
184 |
+
log_vars.update(model.get_current_log())
|
185 |
+
msg_logger(log_vars)
|
186 |
+
|
187 |
+
# save training images snapshot save_snapshot_freq
|
188 |
+
if opt['logger'][
|
189 |
+
'save_snapshot_freq'] is not None and current_iter % opt['logger']['save_snapshot_freq'] == 0:
|
190 |
+
model.save_training_images(current_iter)
|
191 |
+
|
192 |
+
# save models and training states
|
193 |
+
if current_iter % opt['logger']['save_checkpoint_freq'] == 0:
|
194 |
+
logger.info('Saving models and training states.')
|
195 |
+
model.save(epoch, current_iter)
|
196 |
+
|
197 |
+
# validation
|
198 |
+
if opt.get('val') is not None and (current_iter % opt['val']['val_freq'] == 0):
|
199 |
+
if len(val_loaders) > 1:
|
200 |
+
logger.warning('Multiple validation datasets are *only* supported by SRModel.')
|
201 |
+
for val_loader in val_loaders:
|
202 |
+
model.validation(val_loader, current_iter, tb_logger, opt['val']['save_img'])
|
203 |
+
|
204 |
+
data_timer.start()
|
205 |
+
iter_timer.start()
|
206 |
+
train_data = prefetcher.next()
|
207 |
+
# end of iter
|
208 |
+
|
209 |
+
# end of epoch
|
210 |
+
|
211 |
+
consumed_time = str(datetime.timedelta(seconds=int(time.time() - start_time)))
|
212 |
+
logger.info(f'End of training. Time consumed: {consumed_time}')
|
213 |
+
logger.info('Save the latest model.')
|
214 |
+
model.save(epoch=-1, current_iter=-1) # -1 stands for the latest
|
215 |
+
if opt.get('val') is not None:
|
216 |
+
for val_loader in val_loaders:
|
217 |
+
model.validation(val_loader, current_iter, tb_logger, opt['val']['save_img'])
|
218 |
+
if tb_logger:
|
219 |
+
tb_logger.close()
|
220 |
+
|
221 |
+
|
222 |
+
if __name__ == '__main__':
|
223 |
+
root_path = osp.abspath(osp.join(__file__, osp.pardir, osp.pardir))
|
224 |
+
train_pipeline(root_path)
|
basicsr/utils/__init__.py
ADDED
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .diffjpeg import DiffJPEG
|
2 |
+
from .file_client import FileClient
|
3 |
+
from .img_process_util import USMSharp, usm_sharp
|
4 |
+
from .img_util import crop_border, imfrombytes, img2tensor, imwrite, tensor2img
|
5 |
+
from .logger import AvgTimer, MessageLogger, get_env_info, get_root_logger, init_tb_logger, init_wandb_logger
|
6 |
+
from .misc import check_resume, get_time_str, make_exp_dirs, mkdir_and_rename, scandir, set_random_seed, sizeof_fmt
|
7 |
+
|
8 |
+
__all__ = [
|
9 |
+
# file_client.py
|
10 |
+
'FileClient',
|
11 |
+
# img_util.py
|
12 |
+
'img2tensor',
|
13 |
+
'tensor2img',
|
14 |
+
'imfrombytes',
|
15 |
+
'imwrite',
|
16 |
+
'crop_border',
|
17 |
+
# logger.py
|
18 |
+
'MessageLogger',
|
19 |
+
'AvgTimer',
|
20 |
+
'init_tb_logger',
|
21 |
+
'init_wandb_logger',
|
22 |
+
'get_root_logger',
|
23 |
+
'get_env_info',
|
24 |
+
# misc.py
|
25 |
+
'set_random_seed',
|
26 |
+
'get_time_str',
|
27 |
+
'mkdir_and_rename',
|
28 |
+
'make_exp_dirs',
|
29 |
+
'scandir',
|
30 |
+
'check_resume',
|
31 |
+
'sizeof_fmt',
|
32 |
+
# diffjpeg
|
33 |
+
'DiffJPEG',
|
34 |
+
# img_process_util
|
35 |
+
'USMSharp',
|
36 |
+
'usm_sharp'
|
37 |
+
]
|
basicsr/utils/color_enhance.py
ADDED
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from torchvision.transforms import ToTensor, Grayscale
|
2 |
+
|
3 |
+
|
4 |
+
def color_enhacne_blend(x, factor=1.2):
|
5 |
+
x_g = Grayscale(3)(x)
|
6 |
+
out = x_g * (1.0 - factor) + x * factor
|
7 |
+
out[out < 0] = 0
|
8 |
+
out[out > 1] = 1
|
9 |
+
return out
|
basicsr/utils/diffjpeg.py
ADDED
@@ -0,0 +1,515 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Modified from https://github.com/mlomnitz/DiffJPEG
|
3 |
+
|
4 |
+
For images not divisible by 8
|
5 |
+
https://dsp.stackexchange.com/questions/35339/jpeg-dct-padding/35343#35343
|
6 |
+
"""
|
7 |
+
import itertools
|
8 |
+
import numpy as np
|
9 |
+
import torch
|
10 |
+
import torch.nn as nn
|
11 |
+
from torch.nn import functional as F
|
12 |
+
|
13 |
+
# ------------------------ utils ------------------------#
|
14 |
+
y_table = np.array(
|
15 |
+
[[16, 11, 10, 16, 24, 40, 51, 61], [12, 12, 14, 19, 26, 58, 60, 55], [14, 13, 16, 24, 40, 57, 69, 56],
|
16 |
+
[14, 17, 22, 29, 51, 87, 80, 62], [18, 22, 37, 56, 68, 109, 103, 77], [24, 35, 55, 64, 81, 104, 113, 92],
|
17 |
+
[49, 64, 78, 87, 103, 121, 120, 101], [72, 92, 95, 98, 112, 100, 103, 99]],
|
18 |
+
dtype=np.float32).T
|
19 |
+
y_table = nn.Parameter(torch.from_numpy(y_table))
|
20 |
+
c_table = np.empty((8, 8), dtype=np.float32)
|
21 |
+
c_table.fill(99)
|
22 |
+
c_table[:4, :4] = np.array([[17, 18, 24, 47], [18, 21, 26, 66], [24, 26, 56, 99], [47, 66, 99, 99]]).T
|
23 |
+
c_table = nn.Parameter(torch.from_numpy(c_table))
|
24 |
+
|
25 |
+
|
26 |
+
def diff_round(x):
|
27 |
+
""" Differentiable rounding function
|
28 |
+
"""
|
29 |
+
return torch.round(x) + (x - torch.round(x))**3
|
30 |
+
|
31 |
+
|
32 |
+
def quality_to_factor(quality):
|
33 |
+
""" Calculate factor corresponding to quality
|
34 |
+
|
35 |
+
Args:
|
36 |
+
quality(float): Quality for jpeg compression.
|
37 |
+
|
38 |
+
Returns:
|
39 |
+
float: Compression factor.
|
40 |
+
"""
|
41 |
+
if quality < 50:
|
42 |
+
quality = 5000. / quality
|
43 |
+
else:
|
44 |
+
quality = 200. - quality * 2
|
45 |
+
return quality / 100.
|
46 |
+
|
47 |
+
|
48 |
+
# ------------------------ compression ------------------------#
|
49 |
+
class RGB2YCbCrJpeg(nn.Module):
|
50 |
+
""" Converts RGB image to YCbCr
|
51 |
+
"""
|
52 |
+
|
53 |
+
def __init__(self):
|
54 |
+
super(RGB2YCbCrJpeg, self).__init__()
|
55 |
+
matrix = np.array([[0.299, 0.587, 0.114], [-0.168736, -0.331264, 0.5], [0.5, -0.418688, -0.081312]],
|
56 |
+
dtype=np.float32).T
|
57 |
+
self.shift = nn.Parameter(torch.tensor([0., 128., 128.]))
|
58 |
+
self.matrix = nn.Parameter(torch.from_numpy(matrix))
|
59 |
+
|
60 |
+
def forward(self, image):
|
61 |
+
"""
|
62 |
+
Args:
|
63 |
+
image(Tensor): batch x 3 x height x width
|
64 |
+
|
65 |
+
Returns:
|
66 |
+
Tensor: batch x height x width x 3
|
67 |
+
"""
|
68 |
+
image = image.permute(0, 2, 3, 1)
|
69 |
+
result = torch.tensordot(image, self.matrix, dims=1) + self.shift
|
70 |
+
return result.view(image.shape)
|
71 |
+
|
72 |
+
|
73 |
+
class ChromaSubsampling(nn.Module):
|
74 |
+
""" Chroma subsampling on CbCr channels
|
75 |
+
"""
|
76 |
+
|
77 |
+
def __init__(self):
|
78 |
+
super(ChromaSubsampling, self).__init__()
|
79 |
+
|
80 |
+
def forward(self, image):
|
81 |
+
"""
|
82 |
+
Args:
|
83 |
+
image(tensor): batch x height x width x 3
|
84 |
+
|
85 |
+
Returns:
|
86 |
+
y(tensor): batch x height x width
|
87 |
+
cb(tensor): batch x height/2 x width/2
|
88 |
+
cr(tensor): batch x height/2 x width/2
|
89 |
+
"""
|
90 |
+
image_2 = image.permute(0, 3, 1, 2).clone()
|
91 |
+
cb = F.avg_pool2d(image_2[:, 1, :, :].unsqueeze(1), kernel_size=2, stride=(2, 2), count_include_pad=False)
|
92 |
+
cr = F.avg_pool2d(image_2[:, 2, :, :].unsqueeze(1), kernel_size=2, stride=(2, 2), count_include_pad=False)
|
93 |
+
cb = cb.permute(0, 2, 3, 1)
|
94 |
+
cr = cr.permute(0, 2, 3, 1)
|
95 |
+
return image[:, :, :, 0], cb.squeeze(3), cr.squeeze(3)
|
96 |
+
|
97 |
+
|
98 |
+
class BlockSplitting(nn.Module):
|
99 |
+
""" Splitting image into patches
|
100 |
+
"""
|
101 |
+
|
102 |
+
def __init__(self):
|
103 |
+
super(BlockSplitting, self).__init__()
|
104 |
+
self.k = 8
|
105 |
+
|
106 |
+
def forward(self, image):
|
107 |
+
"""
|
108 |
+
Args:
|
109 |
+
image(tensor): batch x height x width
|
110 |
+
|
111 |
+
Returns:
|
112 |
+
Tensor: batch x h*w/64 x h x w
|
113 |
+
"""
|
114 |
+
height, _ = image.shape[1:3]
|
115 |
+
batch_size = image.shape[0]
|
116 |
+
image_reshaped = image.view(batch_size, height // self.k, self.k, -1, self.k)
|
117 |
+
image_transposed = image_reshaped.permute(0, 1, 3, 2, 4)
|
118 |
+
return image_transposed.contiguous().view(batch_size, -1, self.k, self.k)
|
119 |
+
|
120 |
+
|
121 |
+
class DCT8x8(nn.Module):
|
122 |
+
""" Discrete Cosine Transformation
|
123 |
+
"""
|
124 |
+
|
125 |
+
def __init__(self):
|
126 |
+
super(DCT8x8, self).__init__()
|
127 |
+
tensor = np.zeros((8, 8, 8, 8), dtype=np.float32)
|
128 |
+
for x, y, u, v in itertools.product(range(8), repeat=4):
|
129 |
+
tensor[x, y, u, v] = np.cos((2 * x + 1) * u * np.pi / 16) * np.cos((2 * y + 1) * v * np.pi / 16)
|
130 |
+
alpha = np.array([1. / np.sqrt(2)] + [1] * 7)
|
131 |
+
self.tensor = nn.Parameter(torch.from_numpy(tensor).float())
|
132 |
+
self.scale = nn.Parameter(torch.from_numpy(np.outer(alpha, alpha) * 0.25).float())
|
133 |
+
|
134 |
+
def forward(self, image):
|
135 |
+
"""
|
136 |
+
Args:
|
137 |
+
image(tensor): batch x height x width
|
138 |
+
|
139 |
+
Returns:
|
140 |
+
Tensor: batch x height x width
|
141 |
+
"""
|
142 |
+
image = image - 128
|
143 |
+
result = self.scale * torch.tensordot(image, self.tensor, dims=2)
|
144 |
+
result.view(image.shape)
|
145 |
+
return result
|
146 |
+
|
147 |
+
|
148 |
+
class YQuantize(nn.Module):
|
149 |
+
""" JPEG Quantization for Y channel
|
150 |
+
|
151 |
+
Args:
|
152 |
+
rounding(function): rounding function to use
|
153 |
+
"""
|
154 |
+
|
155 |
+
def __init__(self, rounding):
|
156 |
+
super(YQuantize, self).__init__()
|
157 |
+
self.rounding = rounding
|
158 |
+
self.y_table = y_table
|
159 |
+
|
160 |
+
def forward(self, image, factor=1):
|
161 |
+
"""
|
162 |
+
Args:
|
163 |
+
image(tensor): batch x height x width
|
164 |
+
|
165 |
+
Returns:
|
166 |
+
Tensor: batch x height x width
|
167 |
+
"""
|
168 |
+
if isinstance(factor, (int, float)):
|
169 |
+
image = image.float() / (self.y_table * factor)
|
170 |
+
else:
|
171 |
+
b = factor.size(0)
|
172 |
+
table = self.y_table.expand(b, 1, 8, 8) * factor.view(b, 1, 1, 1)
|
173 |
+
image = image.float() / table
|
174 |
+
image = self.rounding(image)
|
175 |
+
return image
|
176 |
+
|
177 |
+
|
178 |
+
class CQuantize(nn.Module):
|
179 |
+
""" JPEG Quantization for CbCr channels
|
180 |
+
|
181 |
+
Args:
|
182 |
+
rounding(function): rounding function to use
|
183 |
+
"""
|
184 |
+
|
185 |
+
def __init__(self, rounding):
|
186 |
+
super(CQuantize, self).__init__()
|
187 |
+
self.rounding = rounding
|
188 |
+
self.c_table = c_table
|
189 |
+
|
190 |
+
def forward(self, image, factor=1):
|
191 |
+
"""
|
192 |
+
Args:
|
193 |
+
image(tensor): batch x height x width
|
194 |
+
|
195 |
+
Returns:
|
196 |
+
Tensor: batch x height x width
|
197 |
+
"""
|
198 |
+
if isinstance(factor, (int, float)):
|
199 |
+
image = image.float() / (self.c_table * factor)
|
200 |
+
else:
|
201 |
+
b = factor.size(0)
|
202 |
+
table = self.c_table.expand(b, 1, 8, 8) * factor.view(b, 1, 1, 1)
|
203 |
+
image = image.float() / table
|
204 |
+
image = self.rounding(image)
|
205 |
+
return image
|
206 |
+
|
207 |
+
|
208 |
+
class CompressJpeg(nn.Module):
|
209 |
+
"""Full JPEG compression algorithm
|
210 |
+
|
211 |
+
Args:
|
212 |
+
rounding(function): rounding function to use
|
213 |
+
"""
|
214 |
+
|
215 |
+
def __init__(self, rounding=torch.round):
|
216 |
+
super(CompressJpeg, self).__init__()
|
217 |
+
self.l1 = nn.Sequential(RGB2YCbCrJpeg(), ChromaSubsampling())
|
218 |
+
self.l2 = nn.Sequential(BlockSplitting(), DCT8x8())
|
219 |
+
self.c_quantize = CQuantize(rounding=rounding)
|
220 |
+
self.y_quantize = YQuantize(rounding=rounding)
|
221 |
+
|
222 |
+
def forward(self, image, factor=1):
|
223 |
+
"""
|
224 |
+
Args:
|
225 |
+
image(tensor): batch x 3 x height x width
|
226 |
+
|
227 |
+
Returns:
|
228 |
+
dict(tensor): Compressed tensor with batch x h*w/64 x 8 x 8.
|
229 |
+
"""
|
230 |
+
y, cb, cr = self.l1(image * 255)
|
231 |
+
components = {'y': y, 'cb': cb, 'cr': cr}
|
232 |
+
for k in components.keys():
|
233 |
+
comp = self.l2(components[k])
|
234 |
+
if k in ('cb', 'cr'):
|
235 |
+
comp = self.c_quantize(comp, factor=factor)
|
236 |
+
else:
|
237 |
+
comp = self.y_quantize(comp, factor=factor)
|
238 |
+
|
239 |
+
components[k] = comp
|
240 |
+
|
241 |
+
return components['y'], components['cb'], components['cr']
|
242 |
+
|
243 |
+
|
244 |
+
# ------------------------ decompression ------------------------#
|
245 |
+
|
246 |
+
|
247 |
+
class YDequantize(nn.Module):
|
248 |
+
"""Dequantize Y channel
|
249 |
+
"""
|
250 |
+
|
251 |
+
def __init__(self):
|
252 |
+
super(YDequantize, self).__init__()
|
253 |
+
self.y_table = y_table
|
254 |
+
|
255 |
+
def forward(self, image, factor=1):
|
256 |
+
"""
|
257 |
+
Args:
|
258 |
+
image(tensor): batch x height x width
|
259 |
+
|
260 |
+
Returns:
|
261 |
+
Tensor: batch x height x width
|
262 |
+
"""
|
263 |
+
if isinstance(factor, (int, float)):
|
264 |
+
out = image * (self.y_table * factor)
|
265 |
+
else:
|
266 |
+
b = factor.size(0)
|
267 |
+
table = self.y_table.expand(b, 1, 8, 8) * factor.view(b, 1, 1, 1)
|
268 |
+
out = image * table
|
269 |
+
return out
|
270 |
+
|
271 |
+
|
272 |
+
class CDequantize(nn.Module):
|
273 |
+
"""Dequantize CbCr channel
|
274 |
+
"""
|
275 |
+
|
276 |
+
def __init__(self):
|
277 |
+
super(CDequantize, self).__init__()
|
278 |
+
self.c_table = c_table
|
279 |
+
|
280 |
+
def forward(self, image, factor=1):
|
281 |
+
"""
|
282 |
+
Args:
|
283 |
+
image(tensor): batch x height x width
|
284 |
+
|
285 |
+
Returns:
|
286 |
+
Tensor: batch x height x width
|
287 |
+
"""
|
288 |
+
if isinstance(factor, (int, float)):
|
289 |
+
out = image * (self.c_table * factor)
|
290 |
+
else:
|
291 |
+
b = factor.size(0)
|
292 |
+
table = self.c_table.expand(b, 1, 8, 8) * factor.view(b, 1, 1, 1)
|
293 |
+
out = image * table
|
294 |
+
return out
|
295 |
+
|
296 |
+
|
297 |
+
class iDCT8x8(nn.Module):
|
298 |
+
"""Inverse discrete Cosine Transformation
|
299 |
+
"""
|
300 |
+
|
301 |
+
def __init__(self):
|
302 |
+
super(iDCT8x8, self).__init__()
|
303 |
+
alpha = np.array([1. / np.sqrt(2)] + [1] * 7)
|
304 |
+
self.alpha = nn.Parameter(torch.from_numpy(np.outer(alpha, alpha)).float())
|
305 |
+
tensor = np.zeros((8, 8, 8, 8), dtype=np.float32)
|
306 |
+
for x, y, u, v in itertools.product(range(8), repeat=4):
|
307 |
+
tensor[x, y, u, v] = np.cos((2 * u + 1) * x * np.pi / 16) * np.cos((2 * v + 1) * y * np.pi / 16)
|
308 |
+
self.tensor = nn.Parameter(torch.from_numpy(tensor).float())
|
309 |
+
|
310 |
+
def forward(self, image):
|
311 |
+
"""
|
312 |
+
Args:
|
313 |
+
image(tensor): batch x height x width
|
314 |
+
|
315 |
+
Returns:
|
316 |
+
Tensor: batch x height x width
|
317 |
+
"""
|
318 |
+
image = image * self.alpha
|
319 |
+
result = 0.25 * torch.tensordot(image, self.tensor, dims=2) + 128
|
320 |
+
result.view(image.shape)
|
321 |
+
return result
|
322 |
+
|
323 |
+
|
324 |
+
class BlockMerging(nn.Module):
|
325 |
+
"""Merge patches into image
|
326 |
+
"""
|
327 |
+
|
328 |
+
def __init__(self):
|
329 |
+
super(BlockMerging, self).__init__()
|
330 |
+
|
331 |
+
def forward(self, patches, height, width):
|
332 |
+
"""
|
333 |
+
Args:
|
334 |
+
patches(tensor) batch x height*width/64, height x width
|
335 |
+
height(int)
|
336 |
+
width(int)
|
337 |
+
|
338 |
+
Returns:
|
339 |
+
Tensor: batch x height x width
|
340 |
+
"""
|
341 |
+
k = 8
|
342 |
+
batch_size = patches.shape[0]
|
343 |
+
image_reshaped = patches.view(batch_size, height // k, width // k, k, k)
|
344 |
+
image_transposed = image_reshaped.permute(0, 1, 3, 2, 4)
|
345 |
+
return image_transposed.contiguous().view(batch_size, height, width)
|
346 |
+
|
347 |
+
|
348 |
+
class ChromaUpsampling(nn.Module):
|
349 |
+
"""Upsample chroma layers
|
350 |
+
"""
|
351 |
+
|
352 |
+
def __init__(self):
|
353 |
+
super(ChromaUpsampling, self).__init__()
|
354 |
+
|
355 |
+
def forward(self, y, cb, cr):
|
356 |
+
"""
|
357 |
+
Args:
|
358 |
+
y(tensor): y channel image
|
359 |
+
cb(tensor): cb channel
|
360 |
+
cr(tensor): cr channel
|
361 |
+
|
362 |
+
Returns:
|
363 |
+
Tensor: batch x height x width x 3
|
364 |
+
"""
|
365 |
+
|
366 |
+
def repeat(x, k=2):
|
367 |
+
height, width = x.shape[1:3]
|
368 |
+
x = x.unsqueeze(-1)
|
369 |
+
x = x.repeat(1, 1, k, k)
|
370 |
+
x = x.view(-1, height * k, width * k)
|
371 |
+
return x
|
372 |
+
|
373 |
+
cb = repeat(cb)
|
374 |
+
cr = repeat(cr)
|
375 |
+
return torch.cat([y.unsqueeze(3), cb.unsqueeze(3), cr.unsqueeze(3)], dim=3)
|
376 |
+
|
377 |
+
|
378 |
+
class YCbCr2RGBJpeg(nn.Module):
|
379 |
+
"""Converts YCbCr image to RGB JPEG
|
380 |
+
"""
|
381 |
+
|
382 |
+
def __init__(self):
|
383 |
+
super(YCbCr2RGBJpeg, self).__init__()
|
384 |
+
|
385 |
+
matrix = np.array([[1., 0., 1.402], [1, -0.344136, -0.714136], [1, 1.772, 0]], dtype=np.float32).T
|
386 |
+
self.shift = nn.Parameter(torch.tensor([0, -128., -128.]))
|
387 |
+
self.matrix = nn.Parameter(torch.from_numpy(matrix))
|
388 |
+
|
389 |
+
def forward(self, image):
|
390 |
+
"""
|
391 |
+
Args:
|
392 |
+
image(tensor): batch x height x width x 3
|
393 |
+
|
394 |
+
Returns:
|
395 |
+
Tensor: batch x 3 x height x width
|
396 |
+
"""
|
397 |
+
result = torch.tensordot(image + self.shift, self.matrix, dims=1)
|
398 |
+
return result.view(image.shape).permute(0, 3, 1, 2)
|
399 |
+
|
400 |
+
|
401 |
+
class DeCompressJpeg(nn.Module):
|
402 |
+
"""Full JPEG decompression algorithm
|
403 |
+
|
404 |
+
Args:
|
405 |
+
rounding(function): rounding function to use
|
406 |
+
"""
|
407 |
+
|
408 |
+
def __init__(self, rounding=torch.round):
|
409 |
+
super(DeCompressJpeg, self).__init__()
|
410 |
+
self.c_dequantize = CDequantize()
|
411 |
+
self.y_dequantize = YDequantize()
|
412 |
+
self.idct = iDCT8x8()
|
413 |
+
self.merging = BlockMerging()
|
414 |
+
self.chroma = ChromaUpsampling()
|
415 |
+
self.colors = YCbCr2RGBJpeg()
|
416 |
+
|
417 |
+
def forward(self, y, cb, cr, imgh, imgw, factor=1):
|
418 |
+
"""
|
419 |
+
Args:
|
420 |
+
compressed(dict(tensor)): batch x h*w/64 x 8 x 8
|
421 |
+
imgh(int)
|
422 |
+
imgw(int)
|
423 |
+
factor(float)
|
424 |
+
|
425 |
+
Returns:
|
426 |
+
Tensor: batch x 3 x height x width
|
427 |
+
"""
|
428 |
+
components = {'y': y, 'cb': cb, 'cr': cr}
|
429 |
+
for k in components.keys():
|
430 |
+
if k in ('cb', 'cr'):
|
431 |
+
comp = self.c_dequantize(components[k], factor=factor)
|
432 |
+
height, width = int(imgh / 2), int(imgw / 2)
|
433 |
+
else:
|
434 |
+
comp = self.y_dequantize(components[k], factor=factor)
|
435 |
+
height, width = imgh, imgw
|
436 |
+
comp = self.idct(comp)
|
437 |
+
components[k] = self.merging(comp, height, width)
|
438 |
+
#
|
439 |
+
image = self.chroma(components['y'], components['cb'], components['cr'])
|
440 |
+
image = self.colors(image)
|
441 |
+
|
442 |
+
image = torch.min(255 * torch.ones_like(image), torch.max(torch.zeros_like(image), image))
|
443 |
+
return image / 255
|
444 |
+
|
445 |
+
|
446 |
+
# ------------------------ main DiffJPEG ------------------------ #
|
447 |
+
|
448 |
+
|
449 |
+
class DiffJPEG(nn.Module):
|
450 |
+
"""This JPEG algorithm result is slightly different from cv2.
|
451 |
+
DiffJPEG supports batch processing.
|
452 |
+
|
453 |
+
Args:
|
454 |
+
differentiable(bool): If True, uses custom differentiable rounding function, if False, uses standard torch.round
|
455 |
+
"""
|
456 |
+
|
457 |
+
def __init__(self, differentiable=True):
|
458 |
+
super(DiffJPEG, self).__init__()
|
459 |
+
if differentiable:
|
460 |
+
rounding = diff_round
|
461 |
+
else:
|
462 |
+
rounding = torch.round
|
463 |
+
|
464 |
+
self.compress = CompressJpeg(rounding=rounding)
|
465 |
+
self.decompress = DeCompressJpeg(rounding=rounding)
|
466 |
+
|
467 |
+
def forward(self, x, quality):
|
468 |
+
"""
|
469 |
+
Args:
|
470 |
+
x (Tensor): Input image, bchw, rgb, [0, 1]
|
471 |
+
quality(float): Quality factor for jpeg compression scheme.
|
472 |
+
"""
|
473 |
+
factor = quality
|
474 |
+
if isinstance(factor, (int, float)):
|
475 |
+
factor = quality_to_factor(factor)
|
476 |
+
else:
|
477 |
+
for i in range(factor.size(0)):
|
478 |
+
factor[i] = quality_to_factor(factor[i])
|
479 |
+
h, w = x.size()[-2:]
|
480 |
+
h_pad, w_pad = 0, 0
|
481 |
+
# why should use 16
|
482 |
+
if h % 16 != 0:
|
483 |
+
h_pad = 16 - h % 16
|
484 |
+
if w % 16 != 0:
|
485 |
+
w_pad = 16 - w % 16
|
486 |
+
x = F.pad(x, (0, w_pad, 0, h_pad), mode='constant', value=0)
|
487 |
+
|
488 |
+
y, cb, cr = self.compress(x, factor=factor)
|
489 |
+
recovered = self.decompress(y, cb, cr, (h + h_pad), (w + w_pad), factor=factor)
|
490 |
+
recovered = recovered[:, :, 0:h, 0:w]
|
491 |
+
return recovered
|
492 |
+
|
493 |
+
|
494 |
+
if __name__ == '__main__':
|
495 |
+
import cv2
|
496 |
+
|
497 |
+
from basicsr.utils import img2tensor, tensor2img
|
498 |
+
|
499 |
+
img_gt = cv2.imread('test.png') / 255.
|
500 |
+
|
501 |
+
# -------------- cv2 -------------- #
|
502 |
+
encode_param = [int(cv2.IMWRITE_JPEG_QUALITY), 20]
|
503 |
+
_, encimg = cv2.imencode('.jpg', img_gt * 255., encode_param)
|
504 |
+
img_lq = np.float32(cv2.imdecode(encimg, 1))
|
505 |
+
cv2.imwrite('cv2_JPEG_20.png', img_lq)
|
506 |
+
|
507 |
+
# -------------- DiffJPEG -------------- #
|
508 |
+
jpeger = DiffJPEG(differentiable=False).cuda()
|
509 |
+
img_gt = img2tensor(img_gt)
|
510 |
+
img_gt = torch.stack([img_gt, img_gt]).cuda()
|
511 |
+
quality = img_gt.new_tensor([20, 40])
|
512 |
+
out = jpeger(img_gt, quality=quality)
|
513 |
+
|
514 |
+
cv2.imwrite('pt_JPEG_20.png', tensor2img(out[0]))
|
515 |
+
cv2.imwrite('pt_JPEG_40.png', tensor2img(out[1]))
|
basicsr/utils/dist_util.py
ADDED
@@ -0,0 +1,82 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Modified from https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/dist_utils.py # noqa: E501
|
2 |
+
import functools
|
3 |
+
import os
|
4 |
+
import subprocess
|
5 |
+
import torch
|
6 |
+
import torch.distributed as dist
|
7 |
+
import torch.multiprocessing as mp
|
8 |
+
|
9 |
+
|
10 |
+
def init_dist(launcher, backend='nccl', **kwargs):
|
11 |
+
if mp.get_start_method(allow_none=True) is None:
|
12 |
+
mp.set_start_method('spawn')
|
13 |
+
if launcher == 'pytorch':
|
14 |
+
_init_dist_pytorch(backend, **kwargs)
|
15 |
+
elif launcher == 'slurm':
|
16 |
+
_init_dist_slurm(backend, **kwargs)
|
17 |
+
else:
|
18 |
+
raise ValueError(f'Invalid launcher type: {launcher}')
|
19 |
+
|
20 |
+
|
21 |
+
def _init_dist_pytorch(backend, **kwargs):
|
22 |
+
rank = int(os.environ['RANK'])
|
23 |
+
num_gpus = torch.cuda.device_count()
|
24 |
+
torch.cuda.set_device(rank % num_gpus)
|
25 |
+
dist.init_process_group(backend=backend, **kwargs)
|
26 |
+
|
27 |
+
|
28 |
+
def _init_dist_slurm(backend, port=None):
|
29 |
+
"""Initialize slurm distributed training environment.
|
30 |
+
|
31 |
+
If argument ``port`` is not specified, then the master port will be system
|
32 |
+
environment variable ``MASTER_PORT``. If ``MASTER_PORT`` is not in system
|
33 |
+
environment variable, then a default port ``29500`` will be used.
|
34 |
+
|
35 |
+
Args:
|
36 |
+
backend (str): Backend of torch.distributed.
|
37 |
+
port (int, optional): Master port. Defaults to None.
|
38 |
+
"""
|
39 |
+
proc_id = int(os.environ['SLURM_PROCID'])
|
40 |
+
ntasks = int(os.environ['SLURM_NTASKS'])
|
41 |
+
node_list = os.environ['SLURM_NODELIST']
|
42 |
+
num_gpus = torch.cuda.device_count()
|
43 |
+
torch.cuda.set_device(proc_id % num_gpus)
|
44 |
+
addr = subprocess.getoutput(f'scontrol show hostname {node_list} | head -n1')
|
45 |
+
# specify master port
|
46 |
+
if port is not None:
|
47 |
+
os.environ['MASTER_PORT'] = str(port)
|
48 |
+
elif 'MASTER_PORT' in os.environ:
|
49 |
+
pass # use MASTER_PORT in the environment variable
|
50 |
+
else:
|
51 |
+
# 29500 is torch.distributed default port
|
52 |
+
os.environ['MASTER_PORT'] = '29500'
|
53 |
+
os.environ['MASTER_ADDR'] = addr
|
54 |
+
os.environ['WORLD_SIZE'] = str(ntasks)
|
55 |
+
os.environ['LOCAL_RANK'] = str(proc_id % num_gpus)
|
56 |
+
os.environ['RANK'] = str(proc_id)
|
57 |
+
dist.init_process_group(backend=backend)
|
58 |
+
|
59 |
+
|
60 |
+
def get_dist_info():
|
61 |
+
if dist.is_available():
|
62 |
+
initialized = dist.is_initialized()
|
63 |
+
else:
|
64 |
+
initialized = False
|
65 |
+
if initialized:
|
66 |
+
rank = dist.get_rank()
|
67 |
+
world_size = dist.get_world_size()
|
68 |
+
else:
|
69 |
+
rank = 0
|
70 |
+
world_size = 1
|
71 |
+
return rank, world_size
|
72 |
+
|
73 |
+
|
74 |
+
def master_only(func):
|
75 |
+
|
76 |
+
@functools.wraps(func)
|
77 |
+
def wrapper(*args, **kwargs):
|
78 |
+
rank, _ = get_dist_info()
|
79 |
+
if rank == 0:
|
80 |
+
return func(*args, **kwargs)
|
81 |
+
|
82 |
+
return wrapper
|
basicsr/utils/download_util.py
ADDED
@@ -0,0 +1,64 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
import requests
|
3 |
+
from tqdm import tqdm
|
4 |
+
|
5 |
+
from .misc import sizeof_fmt
|
6 |
+
|
7 |
+
|
8 |
+
def download_file_from_google_drive(file_id, save_path):
|
9 |
+
"""Download files from google drive.
|
10 |
+
|
11 |
+
Ref:
|
12 |
+
https://stackoverflow.com/questions/25010369/wget-curl-large-file-from-google-drive # noqa E501
|
13 |
+
|
14 |
+
Args:
|
15 |
+
file_id (str): File id.
|
16 |
+
save_path (str): Save path.
|
17 |
+
"""
|
18 |
+
|
19 |
+
session = requests.Session()
|
20 |
+
URL = 'https://docs.google.com/uc?export=download'
|
21 |
+
params = {'id': file_id}
|
22 |
+
|
23 |
+
response = session.get(URL, params=params, stream=True)
|
24 |
+
token = get_confirm_token(response)
|
25 |
+
if token:
|
26 |
+
params['confirm'] = token
|
27 |
+
response = session.get(URL, params=params, stream=True)
|
28 |
+
|
29 |
+
# get file size
|
30 |
+
response_file_size = session.get(URL, params=params, stream=True, headers={'Range': 'bytes=0-2'})
|
31 |
+
if 'Content-Range' in response_file_size.headers:
|
32 |
+
file_size = int(response_file_size.headers['Content-Range'].split('/')[1])
|
33 |
+
else:
|
34 |
+
file_size = None
|
35 |
+
|
36 |
+
save_response_content(response, save_path, file_size)
|
37 |
+
|
38 |
+
|
39 |
+
def get_confirm_token(response):
|
40 |
+
for key, value in response.cookies.items():
|
41 |
+
if key.startswith('download_warning'):
|
42 |
+
return value
|
43 |
+
return None
|
44 |
+
|
45 |
+
|
46 |
+
def save_response_content(response, destination, file_size=None, chunk_size=32768):
|
47 |
+
if file_size is not None:
|
48 |
+
pbar = tqdm(total=math.ceil(file_size / chunk_size), unit='chunk')
|
49 |
+
|
50 |
+
readable_file_size = sizeof_fmt(file_size)
|
51 |
+
else:
|
52 |
+
pbar = None
|
53 |
+
|
54 |
+
with open(destination, 'wb') as f:
|
55 |
+
downloaded_size = 0
|
56 |
+
for chunk in response.iter_content(chunk_size):
|
57 |
+
downloaded_size += chunk_size
|
58 |
+
if pbar is not None:
|
59 |
+
pbar.update(1)
|
60 |
+
pbar.set_description(f'Download {sizeof_fmt(downloaded_size)} / {readable_file_size}')
|
61 |
+
if chunk: # filter out keep-alive new chunks
|
62 |
+
f.write(chunk)
|
63 |
+
if pbar is not None:
|
64 |
+
pbar.close()
|
basicsr/utils/face_util.py
ADDED
@@ -0,0 +1,192 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import cv2
|
2 |
+
import numpy as np
|
3 |
+
import os
|
4 |
+
import torch
|
5 |
+
from skimage import transform as trans
|
6 |
+
|
7 |
+
from basicsr.utils import imwrite
|
8 |
+
|
9 |
+
try:
|
10 |
+
import dlib
|
11 |
+
except ImportError:
|
12 |
+
print('Please install dlib before testing face restoration.' 'Reference: https://github.com/davisking/dlib')
|
13 |
+
|
14 |
+
|
15 |
+
class FaceRestorationHelper(object):
|
16 |
+
"""Helper for the face restoration pipeline."""
|
17 |
+
|
18 |
+
def __init__(self, upscale_factor, face_size=512):
|
19 |
+
self.upscale_factor = upscale_factor
|
20 |
+
self.face_size = (face_size, face_size)
|
21 |
+
|
22 |
+
# standard 5 landmarks for FFHQ faces with 1024 x 1024
|
23 |
+
self.face_template = np.array([[686.77227723, 488.62376238], [586.77227723, 493.59405941],
|
24 |
+
[337.91089109, 488.38613861], [437.95049505, 493.51485149],
|
25 |
+
[513.58415842, 678.5049505]])
|
26 |
+
self.face_template = self.face_template / (1024 // face_size)
|
27 |
+
# for estimation the 2D similarity transformation
|
28 |
+
self.similarity_trans = trans.SimilarityTransform()
|
29 |
+
|
30 |
+
self.all_landmarks_5 = []
|
31 |
+
self.all_landmarks_68 = []
|
32 |
+
self.affine_matrices = []
|
33 |
+
self.inverse_affine_matrices = []
|
34 |
+
self.cropped_faces = []
|
35 |
+
self.restored_faces = []
|
36 |
+
self.save_png = True
|
37 |
+
|
38 |
+
def init_dlib(self, detection_path, landmark5_path, landmark68_path):
|
39 |
+
"""Initialize the dlib detectors and predictors."""
|
40 |
+
self.face_detector = dlib.cnn_face_detection_model_v1(detection_path)
|
41 |
+
self.shape_predictor_5 = dlib.shape_predictor(landmark5_path)
|
42 |
+
self.shape_predictor_68 = dlib.shape_predictor(landmark68_path)
|
43 |
+
|
44 |
+
def free_dlib_gpu_memory(self):
|
45 |
+
del self.face_detector
|
46 |
+
del self.shape_predictor_5
|
47 |
+
del self.shape_predictor_68
|
48 |
+
|
49 |
+
def read_input_image(self, img_path):
|
50 |
+
# self.input_img is Numpy array, (h, w, c) with RGB order
|
51 |
+
self.input_img = dlib.load_rgb_image(img_path)
|
52 |
+
|
53 |
+
def detect_faces(self, img_path, upsample_num_times=1, only_keep_largest=False):
|
54 |
+
"""
|
55 |
+
Args:
|
56 |
+
img_path (str): Image path.
|
57 |
+
upsample_num_times (int): Upsamples the image before running the
|
58 |
+
face detector
|
59 |
+
|
60 |
+
Returns:
|
61 |
+
int: Number of detected faces.
|
62 |
+
"""
|
63 |
+
self.read_input_image(img_path)
|
64 |
+
det_faces = self.face_detector(self.input_img, upsample_num_times)
|
65 |
+
if len(det_faces) == 0:
|
66 |
+
print('No face detected. Try to increase upsample_num_times.')
|
67 |
+
else:
|
68 |
+
if only_keep_largest:
|
69 |
+
print('Detect several faces and only keep the largest.')
|
70 |
+
face_areas = []
|
71 |
+
for i in range(len(det_faces)):
|
72 |
+
face_area = (det_faces[i].rect.right() - det_faces[i].rect.left()) * (
|
73 |
+
det_faces[i].rect.bottom() - det_faces[i].rect.top())
|
74 |
+
face_areas.append(face_area)
|
75 |
+
largest_idx = face_areas.index(max(face_areas))
|
76 |
+
self.det_faces = [det_faces[largest_idx]]
|
77 |
+
else:
|
78 |
+
self.det_faces = det_faces
|
79 |
+
return len(self.det_faces)
|
80 |
+
|
81 |
+
def get_face_landmarks_5(self):
|
82 |
+
for face in self.det_faces:
|
83 |
+
shape = self.shape_predictor_5(self.input_img, face.rect)
|
84 |
+
landmark = np.array([[part.x, part.y] for part in shape.parts()])
|
85 |
+
self.all_landmarks_5.append(landmark)
|
86 |
+
return len(self.all_landmarks_5)
|
87 |
+
|
88 |
+
def get_face_landmarks_68(self):
|
89 |
+
"""Get 68 densemarks for cropped images.
|
90 |
+
|
91 |
+
Should only have one face at most in the cropped image.
|
92 |
+
"""
|
93 |
+
num_detected_face = 0
|
94 |
+
for idx, face in enumerate(self.cropped_faces):
|
95 |
+
# face detection
|
96 |
+
det_face = self.face_detector(face, 1) # TODO: can we remove it?
|
97 |
+
if len(det_face) == 0:
|
98 |
+
print(f'Cannot find faces in cropped image with index {idx}.')
|
99 |
+
self.all_landmarks_68.append(None)
|
100 |
+
else:
|
101 |
+
if len(det_face) > 1:
|
102 |
+
print('Detect several faces in the cropped face. Use the '
|
103 |
+
' largest one. Note that it will also cause overlap '
|
104 |
+
'during paste_faces_to_input_image.')
|
105 |
+
face_areas = []
|
106 |
+
for i in range(len(det_face)):
|
107 |
+
face_area = (det_face[i].rect.right() - det_face[i].rect.left()) * (
|
108 |
+
det_face[i].rect.bottom() - det_face[i].rect.top())
|
109 |
+
face_areas.append(face_area)
|
110 |
+
largest_idx = face_areas.index(max(face_areas))
|
111 |
+
face_rect = det_face[largest_idx].rect
|
112 |
+
else:
|
113 |
+
face_rect = det_face[0].rect
|
114 |
+
shape = self.shape_predictor_68(face, face_rect)
|
115 |
+
landmark = np.array([[part.x, part.y] for part in shape.parts()])
|
116 |
+
self.all_landmarks_68.append(landmark)
|
117 |
+
num_detected_face += 1
|
118 |
+
|
119 |
+
return num_detected_face
|
120 |
+
|
121 |
+
def warp_crop_faces(self, save_cropped_path=None, save_inverse_affine_path=None):
|
122 |
+
"""Get affine matrix, warp and cropped faces.
|
123 |
+
|
124 |
+
Also get inverse affine matrix for post-processing.
|
125 |
+
"""
|
126 |
+
for idx, landmark in enumerate(self.all_landmarks_5):
|
127 |
+
# use 5 landmarks to get affine matrix
|
128 |
+
self.similarity_trans.estimate(landmark, self.face_template)
|
129 |
+
affine_matrix = self.similarity_trans.params[0:2, :]
|
130 |
+
self.affine_matrices.append(affine_matrix)
|
131 |
+
# warp and crop faces
|
132 |
+
cropped_face = cv2.warpAffine(self.input_img, affine_matrix, self.face_size)
|
133 |
+
self.cropped_faces.append(cropped_face)
|
134 |
+
# save the cropped face
|
135 |
+
if save_cropped_path is not None:
|
136 |
+
path, ext = os.path.splitext(save_cropped_path)
|
137 |
+
if self.save_png:
|
138 |
+
save_path = f'{path}_{idx:02d}.png'
|
139 |
+
else:
|
140 |
+
save_path = f'{path}_{idx:02d}{ext}'
|
141 |
+
|
142 |
+
imwrite(cv2.cvtColor(cropped_face, cv2.COLOR_RGB2BGR), save_path)
|
143 |
+
|
144 |
+
# get inverse affine matrix
|
145 |
+
self.similarity_trans.estimate(self.face_template, landmark * self.upscale_factor)
|
146 |
+
inverse_affine = self.similarity_trans.params[0:2, :]
|
147 |
+
self.inverse_affine_matrices.append(inverse_affine)
|
148 |
+
# save inverse affine matrices
|
149 |
+
if save_inverse_affine_path is not None:
|
150 |
+
path, _ = os.path.splitext(save_inverse_affine_path)
|
151 |
+
save_path = f'{path}_{idx:02d}.pth'
|
152 |
+
torch.save(inverse_affine, save_path)
|
153 |
+
|
154 |
+
def add_restored_face(self, face):
|
155 |
+
self.restored_faces.append(face)
|
156 |
+
|
157 |
+
def paste_faces_to_input_image(self, save_path):
|
158 |
+
# operate in the BGR order
|
159 |
+
input_img = cv2.cvtColor(self.input_img, cv2.COLOR_RGB2BGR)
|
160 |
+
h, w, _ = input_img.shape
|
161 |
+
h_up, w_up = h * self.upscale_factor, w * self.upscale_factor
|
162 |
+
# simply resize the background
|
163 |
+
upsample_img = cv2.resize(input_img, (w_up, h_up))
|
164 |
+
assert len(self.restored_faces) == len(
|
165 |
+
self.inverse_affine_matrices), ('length of restored_faces and affine_matrices are different.')
|
166 |
+
for restored_face, inverse_affine in zip(self.restored_faces, self.inverse_affine_matrices):
|
167 |
+
inv_restored = cv2.warpAffine(restored_face, inverse_affine, (w_up, h_up))
|
168 |
+
mask = np.ones((*self.face_size, 3), dtype=np.float32)
|
169 |
+
inv_mask = cv2.warpAffine(mask, inverse_affine, (w_up, h_up))
|
170 |
+
# remove the black borders
|
171 |
+
inv_mask_erosion = cv2.erode(inv_mask, np.ones((2 * self.upscale_factor, 2 * self.upscale_factor),
|
172 |
+
np.uint8))
|
173 |
+
inv_restored_remove_border = inv_mask_erosion * inv_restored
|
174 |
+
total_face_area = np.sum(inv_mask_erosion) // 3
|
175 |
+
# compute the fusion edge based on the area of face
|
176 |
+
w_edge = int(total_face_area**0.5) // 20
|
177 |
+
erosion_radius = w_edge * 2
|
178 |
+
inv_mask_center = cv2.erode(inv_mask_erosion, np.ones((erosion_radius, erosion_radius), np.uint8))
|
179 |
+
blur_size = w_edge * 2
|
180 |
+
inv_soft_mask = cv2.GaussianBlur(inv_mask_center, (blur_size + 1, blur_size + 1), 0)
|
181 |
+
upsample_img = inv_soft_mask * inv_restored_remove_border + (1 - inv_soft_mask) * upsample_img
|
182 |
+
if self.save_png:
|
183 |
+
save_path = save_path.replace('.jpg', '.png').replace('.jpeg', '.png')
|
184 |
+
imwrite(upsample_img.astype(np.uint8), save_path)
|
185 |
+
|
186 |
+
def clean_all(self):
|
187 |
+
self.all_landmarks_5 = []
|
188 |
+
self.all_landmarks_68 = []
|
189 |
+
self.restored_faces = []
|
190 |
+
self.affine_matrices = []
|
191 |
+
self.cropped_faces = []
|
192 |
+
self.inverse_affine_matrices = []
|
basicsr/utils/file_client.py
ADDED
@@ -0,0 +1,167 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Modified from https://github.com/open-mmlab/mmcv/blob/master/mmcv/fileio/file_client.py # noqa: E501
|
2 |
+
from abc import ABCMeta, abstractmethod
|
3 |
+
|
4 |
+
|
5 |
+
class BaseStorageBackend(metaclass=ABCMeta):
|
6 |
+
"""Abstract class of storage backends.
|
7 |
+
|
8 |
+
All backends need to implement two apis: ``get()`` and ``get_text()``.
|
9 |
+
``get()`` reads the file as a byte stream and ``get_text()`` reads the file
|
10 |
+
as texts.
|
11 |
+
"""
|
12 |
+
|
13 |
+
@abstractmethod
|
14 |
+
def get(self, filepath):
|
15 |
+
pass
|
16 |
+
|
17 |
+
@abstractmethod
|
18 |
+
def get_text(self, filepath):
|
19 |
+
pass
|
20 |
+
|
21 |
+
|
22 |
+
class MemcachedBackend(BaseStorageBackend):
|
23 |
+
"""Memcached storage backend.
|
24 |
+
|
25 |
+
Attributes:
|
26 |
+
server_list_cfg (str): Config file for memcached server list.
|
27 |
+
client_cfg (str): Config file for memcached client.
|
28 |
+
sys_path (str | None): Additional path to be appended to `sys.path`.
|
29 |
+
Default: None.
|
30 |
+
"""
|
31 |
+
|
32 |
+
def __init__(self, server_list_cfg, client_cfg, sys_path=None):
|
33 |
+
if sys_path is not None:
|
34 |
+
import sys
|
35 |
+
sys.path.append(sys_path)
|
36 |
+
try:
|
37 |
+
import mc
|
38 |
+
except ImportError:
|
39 |
+
raise ImportError('Please install memcached to enable MemcachedBackend.')
|
40 |
+
|
41 |
+
self.server_list_cfg = server_list_cfg
|
42 |
+
self.client_cfg = client_cfg
|
43 |
+
self._client = mc.MemcachedClient.GetInstance(self.server_list_cfg, self.client_cfg)
|
44 |
+
# mc.pyvector servers as a point which points to a memory cache
|
45 |
+
self._mc_buffer = mc.pyvector()
|
46 |
+
|
47 |
+
def get(self, filepath):
|
48 |
+
filepath = str(filepath)
|
49 |
+
import mc
|
50 |
+
self._client.Get(filepath, self._mc_buffer)
|
51 |
+
value_buf = mc.ConvertBuffer(self._mc_buffer)
|
52 |
+
return value_buf
|
53 |
+
|
54 |
+
def get_text(self, filepath):
|
55 |
+
raise NotImplementedError
|
56 |
+
|
57 |
+
|
58 |
+
class HardDiskBackend(BaseStorageBackend):
|
59 |
+
"""Raw hard disks storage backend."""
|
60 |
+
|
61 |
+
def get(self, filepath):
|
62 |
+
filepath = str(filepath)
|
63 |
+
with open(filepath, 'rb') as f:
|
64 |
+
value_buf = f.read()
|
65 |
+
return value_buf
|
66 |
+
|
67 |
+
def get_text(self, filepath):
|
68 |
+
filepath = str(filepath)
|
69 |
+
with open(filepath, 'r') as f:
|
70 |
+
value_buf = f.read()
|
71 |
+
return value_buf
|
72 |
+
|
73 |
+
|
74 |
+
class LmdbBackend(BaseStorageBackend):
|
75 |
+
"""Lmdb storage backend.
|
76 |
+
|
77 |
+
Args:
|
78 |
+
db_paths (str | list[str]): Lmdb database paths.
|
79 |
+
client_keys (str | list[str]): Lmdb client keys. Default: 'default'.
|
80 |
+
readonly (bool, optional): Lmdb environment parameter. If True,
|
81 |
+
disallow any write operations. Default: True.
|
82 |
+
lock (bool, optional): Lmdb environment parameter. If False, when
|
83 |
+
concurrent access occurs, do not lock the database. Default: False.
|
84 |
+
readahead (bool, optional): Lmdb environment parameter. If False,
|
85 |
+
disable the OS filesystem readahead mechanism, which may improve
|
86 |
+
random read performance when a database is larger than RAM.
|
87 |
+
Default: False.
|
88 |
+
|
89 |
+
Attributes:
|
90 |
+
db_paths (list): Lmdb database path.
|
91 |
+
_client (list): A list of several lmdb envs.
|
92 |
+
"""
|
93 |
+
|
94 |
+
def __init__(self, db_paths, client_keys='default', readonly=True, lock=False, readahead=False, **kwargs):
|
95 |
+
try:
|
96 |
+
import lmdb
|
97 |
+
except ImportError:
|
98 |
+
raise ImportError('Please install lmdb to enable LmdbBackend.')
|
99 |
+
|
100 |
+
if isinstance(client_keys, str):
|
101 |
+
client_keys = [client_keys]
|
102 |
+
|
103 |
+
if isinstance(db_paths, list):
|
104 |
+
self.db_paths = [str(v) for v in db_paths]
|
105 |
+
elif isinstance(db_paths, str):
|
106 |
+
self.db_paths = [str(db_paths)]
|
107 |
+
assert len(client_keys) == len(self.db_paths), ('client_keys and db_paths should have the same length, '
|
108 |
+
f'but received {len(client_keys)} and {len(self.db_paths)}.')
|
109 |
+
|
110 |
+
self._client = {}
|
111 |
+
for client, path in zip(client_keys, self.db_paths):
|
112 |
+
self._client[client] = lmdb.open(path, readonly=readonly, lock=lock, readahead=readahead, **kwargs)
|
113 |
+
|
114 |
+
def get(self, filepath, client_key):
|
115 |
+
"""Get values according to the filepath from one lmdb named client_key.
|
116 |
+
|
117 |
+
Args:
|
118 |
+
filepath (str | obj:`Path`): Here, filepath is the lmdb key.
|
119 |
+
client_key (str): Used for distinguishing different lmdb envs.
|
120 |
+
"""
|
121 |
+
filepath = str(filepath)
|
122 |
+
assert client_key in self._client, (f'client_key {client_key} is not ' 'in lmdb clients.')
|
123 |
+
client = self._client[client_key]
|
124 |
+
with client.begin(write=False) as txn:
|
125 |
+
value_buf = txn.get(filepath.encode('ascii'))
|
126 |
+
return value_buf
|
127 |
+
|
128 |
+
def get_text(self, filepath):
|
129 |
+
raise NotImplementedError
|
130 |
+
|
131 |
+
|
132 |
+
class FileClient(object):
|
133 |
+
"""A general file client to access files in different backend.
|
134 |
+
|
135 |
+
The client loads a file or text in a specified backend from its path
|
136 |
+
and return it as a binary file. it can also register other backend
|
137 |
+
accessor with a given name and backend class.
|
138 |
+
|
139 |
+
Attributes:
|
140 |
+
backend (str): The storage backend type. Options are "disk",
|
141 |
+
"memcached" and "lmdb".
|
142 |
+
client (:obj:`BaseStorageBackend`): The backend object.
|
143 |
+
"""
|
144 |
+
|
145 |
+
_backends = {
|
146 |
+
'disk': HardDiskBackend,
|
147 |
+
'memcached': MemcachedBackend,
|
148 |
+
'lmdb': LmdbBackend,
|
149 |
+
}
|
150 |
+
|
151 |
+
def __init__(self, backend='disk', **kwargs):
|
152 |
+
if backend not in self._backends:
|
153 |
+
raise ValueError(f'Backend {backend} is not supported. Currently supported ones'
|
154 |
+
f' are {list(self._backends.keys())}')
|
155 |
+
self.backend = backend
|
156 |
+
self.client = self._backends[backend](**kwargs)
|
157 |
+
|
158 |
+
def get(self, filepath, client_key='default'):
|
159 |
+
# client_key is used only for lmdb, where different fileclients have
|
160 |
+
# different lmdb environments.
|
161 |
+
if self.backend == 'lmdb':
|
162 |
+
return self.client.get(filepath, client_key)
|
163 |
+
else:
|
164 |
+
return self.client.get(filepath)
|
165 |
+
|
166 |
+
def get_text(self, filepath):
|
167 |
+
return self.client.get_text(filepath)
|
basicsr/utils/flow_util.py
ADDED
@@ -0,0 +1,170 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Modified from https://github.com/open-mmlab/mmcv/blob/master/mmcv/video/optflow.py # noqa: E501
|
2 |
+
import cv2
|
3 |
+
import numpy as np
|
4 |
+
import os
|
5 |
+
|
6 |
+
|
7 |
+
def flowread(flow_path, quantize=False, concat_axis=0, *args, **kwargs):
|
8 |
+
"""Read an optical flow map.
|
9 |
+
|
10 |
+
Args:
|
11 |
+
flow_path (ndarray or str): Flow path.
|
12 |
+
quantize (bool): whether to read quantized pair, if set to True,
|
13 |
+
remaining args will be passed to :func:`dequantize_flow`.
|
14 |
+
concat_axis (int): The axis that dx and dy are concatenated,
|
15 |
+
can be either 0 or 1. Ignored if quantize is False.
|
16 |
+
|
17 |
+
Returns:
|
18 |
+
ndarray: Optical flow represented as a (h, w, 2) numpy array
|
19 |
+
"""
|
20 |
+
if quantize:
|
21 |
+
assert concat_axis in [0, 1]
|
22 |
+
cat_flow = cv2.imread(flow_path, cv2.IMREAD_UNCHANGED)
|
23 |
+
if cat_flow.ndim != 2:
|
24 |
+
raise IOError(f'{flow_path} is not a valid quantized flow file, its dimension is {cat_flow.ndim}.')
|
25 |
+
assert cat_flow.shape[concat_axis] % 2 == 0
|
26 |
+
dx, dy = np.split(cat_flow, 2, axis=concat_axis)
|
27 |
+
flow = dequantize_flow(dx, dy, *args, **kwargs)
|
28 |
+
else:
|
29 |
+
with open(flow_path, 'rb') as f:
|
30 |
+
try:
|
31 |
+
header = f.read(4).decode('utf-8')
|
32 |
+
except Exception:
|
33 |
+
raise IOError(f'Invalid flow file: {flow_path}')
|
34 |
+
else:
|
35 |
+
if header != 'PIEH':
|
36 |
+
raise IOError(f'Invalid flow file: {flow_path}, ' 'header does not contain PIEH')
|
37 |
+
|
38 |
+
w = np.fromfile(f, np.int32, 1).squeeze()
|
39 |
+
h = np.fromfile(f, np.int32, 1).squeeze()
|
40 |
+
flow = np.fromfile(f, np.float32, w * h * 2).reshape((h, w, 2))
|
41 |
+
|
42 |
+
return flow.astype(np.float32)
|
43 |
+
|
44 |
+
|
45 |
+
def flowwrite(flow, filename, quantize=False, concat_axis=0, *args, **kwargs):
|
46 |
+
"""Write optical flow to file.
|
47 |
+
|
48 |
+
If the flow is not quantized, it will be saved as a .flo file losslessly,
|
49 |
+
otherwise a jpeg image which is lossy but of much smaller size. (dx and dy
|
50 |
+
will be concatenated horizontally into a single image if quantize is True.)
|
51 |
+
|
52 |
+
Args:
|
53 |
+
flow (ndarray): (h, w, 2) array of optical flow.
|
54 |
+
filename (str): Output filepath.
|
55 |
+
quantize (bool): Whether to quantize the flow and save it to 2 jpeg
|
56 |
+
images. If set to True, remaining args will be passed to
|
57 |
+
:func:`quantize_flow`.
|
58 |
+
concat_axis (int): The axis that dx and dy are concatenated,
|
59 |
+
can be either 0 or 1. Ignored if quantize is False.
|
60 |
+
"""
|
61 |
+
if not quantize:
|
62 |
+
with open(filename, 'wb') as f:
|
63 |
+
f.write('PIEH'.encode('utf-8'))
|
64 |
+
np.array([flow.shape[1], flow.shape[0]], dtype=np.int32).tofile(f)
|
65 |
+
flow = flow.astype(np.float32)
|
66 |
+
flow.tofile(f)
|
67 |
+
f.flush()
|
68 |
+
else:
|
69 |
+
assert concat_axis in [0, 1]
|
70 |
+
dx, dy = quantize_flow(flow, *args, **kwargs)
|
71 |
+
dxdy = np.concatenate((dx, dy), axis=concat_axis)
|
72 |
+
os.makedirs(os.path.dirname(filename), exist_ok=True)
|
73 |
+
cv2.imwrite(filename, dxdy)
|
74 |
+
|
75 |
+
|
76 |
+
def quantize_flow(flow, max_val=0.02, norm=True):
|
77 |
+
"""Quantize flow to [0, 255].
|
78 |
+
|
79 |
+
After this step, the size of flow will be much smaller, and can be
|
80 |
+
dumped as jpeg images.
|
81 |
+
|
82 |
+
Args:
|
83 |
+
flow (ndarray): (h, w, 2) array of optical flow.
|
84 |
+
max_val (float): Maximum value of flow, values beyond
|
85 |
+
[-max_val, max_val] will be truncated.
|
86 |
+
norm (bool): Whether to divide flow values by image width/height.
|
87 |
+
|
88 |
+
Returns:
|
89 |
+
tuple[ndarray]: Quantized dx and dy.
|
90 |
+
"""
|
91 |
+
h, w, _ = flow.shape
|
92 |
+
dx = flow[..., 0]
|
93 |
+
dy = flow[..., 1]
|
94 |
+
if norm:
|
95 |
+
dx = dx / w # avoid inplace operations
|
96 |
+
dy = dy / h
|
97 |
+
# use 255 levels instead of 256 to make sure 0 is 0 after dequantization.
|
98 |
+
flow_comps = [quantize(d, -max_val, max_val, 255, np.uint8) for d in [dx, dy]]
|
99 |
+
return tuple(flow_comps)
|
100 |
+
|
101 |
+
|
102 |
+
def dequantize_flow(dx, dy, max_val=0.02, denorm=True):
|
103 |
+
"""Recover from quantized flow.
|
104 |
+
|
105 |
+
Args:
|
106 |
+
dx (ndarray): Quantized dx.
|
107 |
+
dy (ndarray): Quantized dy.
|
108 |
+
max_val (float): Maximum value used when quantizing.
|
109 |
+
denorm (bool): Whether to multiply flow values with width/height.
|
110 |
+
|
111 |
+
Returns:
|
112 |
+
ndarray: Dequantized flow.
|
113 |
+
"""
|
114 |
+
assert dx.shape == dy.shape
|
115 |
+
assert dx.ndim == 2 or (dx.ndim == 3 and dx.shape[-1] == 1)
|
116 |
+
|
117 |
+
dx, dy = [dequantize(d, -max_val, max_val, 255) for d in [dx, dy]]
|
118 |
+
|
119 |
+
if denorm:
|
120 |
+
dx *= dx.shape[1]
|
121 |
+
dy *= dx.shape[0]
|
122 |
+
flow = np.dstack((dx, dy))
|
123 |
+
return flow
|
124 |
+
|
125 |
+
|
126 |
+
def quantize(arr, min_val, max_val, levels, dtype=np.int64):
|
127 |
+
"""Quantize an array of (-inf, inf) to [0, levels-1].
|
128 |
+
|
129 |
+
Args:
|
130 |
+
arr (ndarray): Input array.
|
131 |
+
min_val (scalar): Minimum value to be clipped.
|
132 |
+
max_val (scalar): Maximum value to be clipped.
|
133 |
+
levels (int): Quantization levels.
|
134 |
+
dtype (np.type): The type of the quantized array.
|
135 |
+
|
136 |
+
Returns:
|
137 |
+
tuple: Quantized array.
|
138 |
+
"""
|
139 |
+
if not (isinstance(levels, int) and levels > 1):
|
140 |
+
raise ValueError(f'levels must be a positive integer, but got {levels}')
|
141 |
+
if min_val >= max_val:
|
142 |
+
raise ValueError(f'min_val ({min_val}) must be smaller than max_val ({max_val})')
|
143 |
+
|
144 |
+
arr = np.clip(arr, min_val, max_val) - min_val
|
145 |
+
quantized_arr = np.minimum(np.floor(levels * arr / (max_val - min_val)).astype(dtype), levels - 1)
|
146 |
+
|
147 |
+
return quantized_arr
|
148 |
+
|
149 |
+
|
150 |
+
def dequantize(arr, min_val, max_val, levels, dtype=np.float64):
|
151 |
+
"""Dequantize an array.
|
152 |
+
|
153 |
+
Args:
|
154 |
+
arr (ndarray): Input array.
|
155 |
+
min_val (scalar): Minimum value to be clipped.
|
156 |
+
max_val (scalar): Maximum value to be clipped.
|
157 |
+
levels (int): Quantization levels.
|
158 |
+
dtype (np.type): The type of the dequantized array.
|
159 |
+
|
160 |
+
Returns:
|
161 |
+
tuple: Dequantized array.
|
162 |
+
"""
|
163 |
+
if not (isinstance(levels, int) and levels > 1):
|
164 |
+
raise ValueError(f'levels must be a positive integer, but got {levels}')
|
165 |
+
if min_val >= max_val:
|
166 |
+
raise ValueError(f'min_val ({min_val}) must be smaller than max_val ({max_val})')
|
167 |
+
|
168 |
+
dequantized_arr = (arr + 0.5).astype(dtype) * (max_val - min_val) / levels + min_val
|
169 |
+
|
170 |
+
return dequantized_arr
|
basicsr/utils/img_process_util.py
ADDED
@@ -0,0 +1,83 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import cv2
|
2 |
+
import numpy as np
|
3 |
+
import torch
|
4 |
+
from torch.nn import functional as F
|
5 |
+
|
6 |
+
|
7 |
+
def filter2D(img, kernel):
|
8 |
+
"""PyTorch version of cv2.filter2D
|
9 |
+
|
10 |
+
Args:
|
11 |
+
img (Tensor): (b, c, h, w)
|
12 |
+
kernel (Tensor): (b, k, k)
|
13 |
+
"""
|
14 |
+
k = kernel.size(-1)
|
15 |
+
b, c, h, w = img.size()
|
16 |
+
if k % 2 == 1:
|
17 |
+
img = F.pad(img, (k // 2, k // 2, k // 2, k // 2), mode='reflect')
|
18 |
+
else:
|
19 |
+
raise ValueError('Wrong kernel size')
|
20 |
+
|
21 |
+
ph, pw = img.size()[-2:]
|
22 |
+
|
23 |
+
if kernel.size(0) == 1:
|
24 |
+
# apply the same kernel to all batch images
|
25 |
+
img = img.view(b * c, 1, ph, pw)
|
26 |
+
kernel = kernel.view(1, 1, k, k)
|
27 |
+
return F.conv2d(img, kernel, padding=0).view(b, c, h, w)
|
28 |
+
else:
|
29 |
+
img = img.view(1, b * c, ph, pw)
|
30 |
+
kernel = kernel.view(b, 1, k, k).repeat(1, c, 1, 1).view(b * c, 1, k, k)
|
31 |
+
return F.conv2d(img, kernel, groups=b * c).view(b, c, h, w)
|
32 |
+
|
33 |
+
|
34 |
+
def usm_sharp(img, weight=0.5, radius=50, threshold=10):
|
35 |
+
"""USM sharpening.
|
36 |
+
|
37 |
+
Input image: I; Blurry image: B.
|
38 |
+
1. sharp = I + weight * (I - B)
|
39 |
+
2. Mask = 1 if abs(I - B) > threshold, else: 0
|
40 |
+
3. Blur mask:
|
41 |
+
4. Out = Mask * sharp + (1 - Mask) * I
|
42 |
+
|
43 |
+
|
44 |
+
Args:
|
45 |
+
img (Numpy array): Input image, HWC, BGR; float32, [0, 1].
|
46 |
+
weight (float): Sharp weight. Default: 1.
|
47 |
+
radius (float): Kernel size of Gaussian blur. Default: 50.
|
48 |
+
threshold (int):
|
49 |
+
"""
|
50 |
+
if radius % 2 == 0:
|
51 |
+
radius += 1
|
52 |
+
blur = cv2.GaussianBlur(img, (radius, radius), 0)
|
53 |
+
residual = img - blur
|
54 |
+
mask = np.abs(residual) * 255 > threshold
|
55 |
+
mask = mask.astype('float32')
|
56 |
+
soft_mask = cv2.GaussianBlur(mask, (radius, radius), 0)
|
57 |
+
|
58 |
+
sharp = img + weight * residual
|
59 |
+
sharp = np.clip(sharp, 0, 1)
|
60 |
+
return soft_mask * sharp + (1 - soft_mask) * img
|
61 |
+
|
62 |
+
|
63 |
+
class USMSharp(torch.nn.Module):
|
64 |
+
|
65 |
+
def __init__(self, radius=50, sigma=0):
|
66 |
+
super(USMSharp, self).__init__()
|
67 |
+
if radius % 2 == 0:
|
68 |
+
radius += 1
|
69 |
+
self.radius = radius
|
70 |
+
kernel = cv2.getGaussianKernel(radius, sigma)
|
71 |
+
kernel = torch.FloatTensor(np.dot(kernel, kernel.transpose())).unsqueeze_(0)
|
72 |
+
self.register_buffer('kernel', kernel)
|
73 |
+
|
74 |
+
def forward(self, img, weight=0.5, threshold=10):
|
75 |
+
blur = filter2D(img, self.kernel)
|
76 |
+
residual = img - blur
|
77 |
+
|
78 |
+
mask = torch.abs(residual) * 255 > threshold
|
79 |
+
mask = mask.float()
|
80 |
+
soft_mask = filter2D(mask, self.kernel)
|
81 |
+
sharp = img + weight * residual
|
82 |
+
sharp = torch.clip(sharp, 0, 1)
|
83 |
+
return soft_mask * sharp + (1 - soft_mask) * img
|
basicsr/utils/img_util.py
ADDED
@@ -0,0 +1,227 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import cv2
|
2 |
+
import math
|
3 |
+
import numpy as np
|
4 |
+
import os
|
5 |
+
import torch
|
6 |
+
from torchvision.utils import make_grid
|
7 |
+
|
8 |
+
|
9 |
+
def img2tensor(imgs, bgr2rgb=True, float32=True):
|
10 |
+
"""Numpy array to tensor.
|
11 |
+
|
12 |
+
Args:
|
13 |
+
imgs (list[ndarray] | ndarray): Input images.
|
14 |
+
bgr2rgb (bool): Whether to change bgr to rgb.
|
15 |
+
float32 (bool): Whether to change to float32.
|
16 |
+
|
17 |
+
Returns:
|
18 |
+
list[tensor] | tensor: Tensor images. If returned results only have
|
19 |
+
one element, just return tensor.
|
20 |
+
"""
|
21 |
+
|
22 |
+
def _totensor(img, bgr2rgb, float32):
|
23 |
+
if img.shape[2] == 3 and bgr2rgb:
|
24 |
+
if img.dtype == 'float64':
|
25 |
+
img = img.astype('float32')
|
26 |
+
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
|
27 |
+
img = torch.from_numpy(img.transpose(2, 0, 1))
|
28 |
+
if float32:
|
29 |
+
img = img.float()
|
30 |
+
return img
|
31 |
+
|
32 |
+
if isinstance(imgs, list):
|
33 |
+
return [_totensor(img, bgr2rgb, float32) for img in imgs]
|
34 |
+
else:
|
35 |
+
return _totensor(imgs, bgr2rgb, float32)
|
36 |
+
|
37 |
+
|
38 |
+
def tensor2img(tensor, rgb2bgr=True, out_type=np.uint8, min_max=(0, 1)):
|
39 |
+
"""Convert torch Tensors into image numpy arrays.
|
40 |
+
|
41 |
+
After clamping to [min, max], values will be normalized to [0, 1].
|
42 |
+
|
43 |
+
Args:
|
44 |
+
tensor (Tensor or list[Tensor]): Accept shapes:
|
45 |
+
1) 4D mini-batch Tensor of shape (B x 3/1 x H x W);
|
46 |
+
2) 3D Tensor of shape (3/1 x H x W);
|
47 |
+
3) 2D Tensor of shape (H x W).
|
48 |
+
Tensor channel should be in RGB order.
|
49 |
+
rgb2bgr (bool): Whether to change rgb to bgr.
|
50 |
+
out_type (numpy type): output types. If ``np.uint8``, transform outputs
|
51 |
+
to uint8 type with range [0, 255]; otherwise, float type with
|
52 |
+
range [0, 1]. Default: ``np.uint8``.
|
53 |
+
min_max (tuple[int]): min and max values for clamp.
|
54 |
+
|
55 |
+
Returns:
|
56 |
+
(Tensor or list): 3D ndarray of shape (H x W x C) OR 2D ndarray of
|
57 |
+
shape (H x W). The channel order is BGR.
|
58 |
+
"""
|
59 |
+
if not (torch.is_tensor(tensor) or (isinstance(tensor, list) and all(torch.is_tensor(t) for t in tensor))):
|
60 |
+
raise TypeError(f'tensor or list of tensors expected, got {type(tensor)}')
|
61 |
+
|
62 |
+
if torch.is_tensor(tensor):
|
63 |
+
tensor = [tensor]
|
64 |
+
result = []
|
65 |
+
for _tensor in tensor:
|
66 |
+
_tensor = _tensor.squeeze(0).float().detach().cpu().clamp_(*min_max)
|
67 |
+
_tensor = (_tensor - min_max[0]) / (min_max[1] - min_max[0])
|
68 |
+
|
69 |
+
n_dim = _tensor.dim()
|
70 |
+
if n_dim == 4:
|
71 |
+
img_np = make_grid(_tensor, nrow=int(math.sqrt(_tensor.size(0))), normalize=False).numpy()
|
72 |
+
img_np = img_np.transpose(1, 2, 0)
|
73 |
+
if rgb2bgr:
|
74 |
+
img_np = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR)
|
75 |
+
elif n_dim == 3:
|
76 |
+
img_np = _tensor.numpy()
|
77 |
+
img_np = img_np.transpose(1, 2, 0)
|
78 |
+
if img_np.shape[2] == 1: # gray image
|
79 |
+
img_np = np.squeeze(img_np, axis=2)
|
80 |
+
else:
|
81 |
+
if rgb2bgr:
|
82 |
+
img_np = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR)
|
83 |
+
elif n_dim == 2:
|
84 |
+
img_np = _tensor.numpy()
|
85 |
+
else:
|
86 |
+
raise TypeError(f'Only support 4D, 3D or 2D tensor. But received with dimension: {n_dim}')
|
87 |
+
if out_type == np.uint8:
|
88 |
+
# Unlike MATLAB, numpy.unit8() WILL NOT round by default.
|
89 |
+
img_np = (img_np * 255.0).round()
|
90 |
+
img_np = img_np.astype(out_type)
|
91 |
+
result.append(img_np)
|
92 |
+
if len(result) == 1:
|
93 |
+
result = result[0]
|
94 |
+
return result
|
95 |
+
|
96 |
+
|
97 |
+
def tensor2img_fast(tensor, rgb2bgr=True, min_max=(0, 1)):
|
98 |
+
"""This implementation is slightly faster than tensor2img.
|
99 |
+
It now only supports torch tensor with shape (1, c, h, w).
|
100 |
+
|
101 |
+
Args:
|
102 |
+
tensor (Tensor): Now only support torch tensor with (1, c, h, w).
|
103 |
+
rgb2bgr (bool): Whether to change rgb to bgr. Default: True.
|
104 |
+
min_max (tuple[int]): min and max values for clamp.
|
105 |
+
"""
|
106 |
+
output = tensor.squeeze(0).detach().clamp_(*min_max).permute(1, 2, 0)
|
107 |
+
output = (output - min_max[0]) / (min_max[1] - min_max[0]) * 255
|
108 |
+
output = output.type(torch.uint8).cpu().numpy()
|
109 |
+
if rgb2bgr:
|
110 |
+
output = cv2.cvtColor(output, cv2.COLOR_RGB2BGR)
|
111 |
+
return output
|
112 |
+
|
113 |
+
|
114 |
+
def imfrombytes(content, flag='color', float32=False):
|
115 |
+
"""Read an image from bytes.
|
116 |
+
|
117 |
+
Args:
|
118 |
+
content (bytes): Image bytes got from files or other streams.
|
119 |
+
flag (str): Flags specifying the color type of a loaded image,
|
120 |
+
candidates are `color`, `grayscale` and `unchanged`.
|
121 |
+
float32 (bool): Whether to change to float32., If True, will also norm
|
122 |
+
to [0, 1]. Default: False.
|
123 |
+
|
124 |
+
Returns:
|
125 |
+
ndarray: Loaded image array.
|
126 |
+
"""
|
127 |
+
img_np = np.frombuffer(content, np.uint8)
|
128 |
+
imread_flags = {'color': cv2.IMREAD_COLOR, 'grayscale': cv2.IMREAD_GRAYSCALE, 'unchanged': cv2.IMREAD_UNCHANGED}
|
129 |
+
img = cv2.imdecode(img_np, imread_flags[flag])
|
130 |
+
if float32:
|
131 |
+
img = img.astype(np.float32) / 255.
|
132 |
+
return img
|
133 |
+
|
134 |
+
|
135 |
+
def imwrite(img, file_path, params=None, auto_mkdir=True):
|
136 |
+
"""Write image to file.
|
137 |
+
|
138 |
+
Args:
|
139 |
+
img (ndarray): Image array to be written.
|
140 |
+
file_path (str): Image file path.
|
141 |
+
params (None or list): Same as opencv's :func:`imwrite` interface.
|
142 |
+
auto_mkdir (bool): If the parent folder of `file_path` does not exist,
|
143 |
+
whether to create it automatically.
|
144 |
+
|
145 |
+
Returns:
|
146 |
+
bool: Successful or not.
|
147 |
+
"""
|
148 |
+
if auto_mkdir:
|
149 |
+
dir_name = os.path.abspath(os.path.dirname(file_path))
|
150 |
+
os.makedirs(dir_name, exist_ok=True)
|
151 |
+
ok = cv2.imwrite(file_path, img, params)
|
152 |
+
if not ok:
|
153 |
+
raise IOError('Failed in writing images.')
|
154 |
+
|
155 |
+
|
156 |
+
def crop_border(imgs, crop_border):
|
157 |
+
"""Crop borders of images.
|
158 |
+
|
159 |
+
Args:
|
160 |
+
imgs (list[ndarray] | ndarray): Images with shape (h, w, c).
|
161 |
+
crop_border (int): Crop border for each end of height and weight.
|
162 |
+
|
163 |
+
Returns:
|
164 |
+
list[ndarray]: Cropped images.
|
165 |
+
"""
|
166 |
+
if crop_border == 0:
|
167 |
+
return imgs
|
168 |
+
else:
|
169 |
+
if isinstance(imgs, list):
|
170 |
+
return [v[crop_border:-crop_border, crop_border:-crop_border, ...] for v in imgs]
|
171 |
+
else:
|
172 |
+
return imgs[crop_border:-crop_border, crop_border:-crop_border, ...]
|
173 |
+
|
174 |
+
|
175 |
+
def tensor_lab2rgb(labs, illuminant="D65", observer="2"):
|
176 |
+
"""
|
177 |
+
Args:
|
178 |
+
lab : (B, C, H, W)
|
179 |
+
Returns:
|
180 |
+
tuple : (C, H, W)
|
181 |
+
"""
|
182 |
+
illuminants = \
|
183 |
+
{"A": {'2': (1.098466069456375, 1, 0.3558228003436005),
|
184 |
+
'10': (1.111420406956693, 1, 0.3519978321919493)},
|
185 |
+
"D50": {'2': (0.9642119944211994, 1, 0.8251882845188288),
|
186 |
+
'10': (0.9672062750333777, 1, 0.8142801513128616)},
|
187 |
+
"D55": {'2': (0.956797052643698, 1, 0.9214805860173273),
|
188 |
+
'10': (0.9579665682254781, 1, 0.9092525159847462)},
|
189 |
+
"D65": {'2': (0.95047, 1., 1.08883), # This was: `lab_ref_white`
|
190 |
+
'10': (0.94809667673716, 1, 1.0730513595166162)},
|
191 |
+
"D75": {'2': (0.9497220898840717, 1, 1.226393520724154),
|
192 |
+
'10': (0.9441713925645873, 1, 1.2064272211720228)},
|
193 |
+
"E": {'2': (1.0, 1.0, 1.0),
|
194 |
+
'10': (1.0, 1.0, 1.0)}}
|
195 |
+
xyz_from_rgb = np.array([[0.412453, 0.357580, 0.180423], [0.212671, 0.715160, 0.072169],
|
196 |
+
[0.019334, 0.119193, 0.950227]])
|
197 |
+
|
198 |
+
rgb_from_xyz = np.array([[3.240481340, -0.96925495, 0.055646640], [-1.53715152, 1.875990000, -0.20404134],
|
199 |
+
[-0.49853633, 0.041555930, 1.057311070]])
|
200 |
+
B, C, H, W = labs.shape
|
201 |
+
arrs = labs.permute((0, 2, 3, 1)).contiguous() # (B, 3, H, W) -> (B, H, W, 3)
|
202 |
+
L, a, b = arrs[:, :, :, 0:1], arrs[:, :, :, 1:2], arrs[:, :, :, 2:]
|
203 |
+
y = (L + 16.) / 116.
|
204 |
+
x = (a / 500.) + y
|
205 |
+
z = y - (b / 200.)
|
206 |
+
invalid = z.data < 0
|
207 |
+
z[invalid] = 0
|
208 |
+
xyz = torch.cat([x, y, z], dim=3)
|
209 |
+
mask = xyz.data > 0.2068966
|
210 |
+
mask_xyz = xyz.clone()
|
211 |
+
mask_xyz[mask] = torch.pow(xyz[mask], 3.0)
|
212 |
+
mask_xyz[~mask] = (xyz[~mask] - 16.0 / 116.) / 7.787
|
213 |
+
xyz_ref_white = illuminants[illuminant][observer]
|
214 |
+
for i in range(C):
|
215 |
+
mask_xyz[:, :, :, i] = mask_xyz[:, :, :, i] * xyz_ref_white[i]
|
216 |
+
|
217 |
+
rgb_trans = torch.mm(mask_xyz.view(-1, 3), torch.from_numpy(rgb_from_xyz).type_as(xyz)).view(B, H, W, C)
|
218 |
+
rgb = rgb_trans.permute((0, 3, 1, 2)).contiguous()
|
219 |
+
mask = rgb.data > 0.0031308
|
220 |
+
mask_rgb = rgb.clone()
|
221 |
+
mask_rgb[mask] = 1.055 * torch.pow(rgb[mask], 1 / 2.4) - 0.055
|
222 |
+
mask_rgb[~mask] = rgb[~mask] * 12.92
|
223 |
+
neg_mask = mask_rgb.data < 0
|
224 |
+
large_mask = mask_rgb.data > 1
|
225 |
+
mask_rgb[neg_mask] = 0
|
226 |
+
mask_rgb[large_mask] = 1
|
227 |
+
return mask_rgb
|
basicsr/utils/lmdb_util.py
ADDED
@@ -0,0 +1,196 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import cv2
|
2 |
+
import lmdb
|
3 |
+
import sys
|
4 |
+
from multiprocessing import Pool
|
5 |
+
from os import path as osp
|
6 |
+
from tqdm import tqdm
|
7 |
+
|
8 |
+
|
9 |
+
def make_lmdb_from_imgs(data_path,
|
10 |
+
lmdb_path,
|
11 |
+
img_path_list,
|
12 |
+
keys,
|
13 |
+
batch=5000,
|
14 |
+
compress_level=1,
|
15 |
+
multiprocessing_read=False,
|
16 |
+
n_thread=40,
|
17 |
+
map_size=None):
|
18 |
+
"""Make lmdb from images.
|
19 |
+
|
20 |
+
Contents of lmdb. The file structure is:
|
21 |
+
example.lmdb
|
22 |
+
├── data.mdb
|
23 |
+
├── lock.mdb
|
24 |
+
├── meta_info.txt
|
25 |
+
|
26 |
+
The data.mdb and lock.mdb are standard lmdb files and you can refer to
|
27 |
+
https://lmdb.readthedocs.io/en/release/ for more details.
|
28 |
+
|
29 |
+
The meta_info.txt is a specified txt file to record the meta information
|
30 |
+
of our datasets. It will be automatically created when preparing
|
31 |
+
datasets by our provided dataset tools.
|
32 |
+
Each line in the txt file records 1)image name (with extension),
|
33 |
+
2)image shape, and 3)compression level, separated by a white space.
|
34 |
+
|
35 |
+
For example, the meta information could be:
|
36 |
+
`000_00000000.png (720,1280,3) 1`, which means:
|
37 |
+
1) image name (with extension): 000_00000000.png;
|
38 |
+
2) image shape: (720,1280,3);
|
39 |
+
3) compression level: 1
|
40 |
+
|
41 |
+
We use the image name without extension as the lmdb key.
|
42 |
+
|
43 |
+
If `multiprocessing_read` is True, it will read all the images to memory
|
44 |
+
using multiprocessing. Thus, your server needs to have enough memory.
|
45 |
+
|
46 |
+
Args:
|
47 |
+
data_path (str): Data path for reading images.
|
48 |
+
lmdb_path (str): Lmdb save path.
|
49 |
+
img_path_list (str): Image path list.
|
50 |
+
keys (str): Used for lmdb keys.
|
51 |
+
batch (int): After processing batch images, lmdb commits.
|
52 |
+
Default: 5000.
|
53 |
+
compress_level (int): Compress level when encoding images. Default: 1.
|
54 |
+
multiprocessing_read (bool): Whether use multiprocessing to read all
|
55 |
+
the images to memory. Default: False.
|
56 |
+
n_thread (int): For multiprocessing.
|
57 |
+
map_size (int | None): Map size for lmdb env. If None, use the
|
58 |
+
estimated size from images. Default: None
|
59 |
+
"""
|
60 |
+
|
61 |
+
assert len(img_path_list) == len(keys), ('img_path_list and keys should have the same length, '
|
62 |
+
f'but got {len(img_path_list)} and {len(keys)}')
|
63 |
+
print(f'Create lmdb for {data_path}, save to {lmdb_path}...')
|
64 |
+
print(f'Totoal images: {len(img_path_list)}')
|
65 |
+
if not lmdb_path.endswith('.lmdb'):
|
66 |
+
raise ValueError("lmdb_path must end with '.lmdb'.")
|
67 |
+
if osp.exists(lmdb_path):
|
68 |
+
print(f'Folder {lmdb_path} already exists. Exit.')
|
69 |
+
sys.exit(1)
|
70 |
+
|
71 |
+
if multiprocessing_read:
|
72 |
+
# read all the images to memory (multiprocessing)
|
73 |
+
dataset = {} # use dict to keep the order for multiprocessing
|
74 |
+
shapes = {}
|
75 |
+
print(f'Read images with multiprocessing, #thread: {n_thread} ...')
|
76 |
+
pbar = tqdm(total=len(img_path_list), unit='image')
|
77 |
+
|
78 |
+
def callback(arg):
|
79 |
+
"""get the image data and update pbar."""
|
80 |
+
key, dataset[key], shapes[key] = arg
|
81 |
+
pbar.update(1)
|
82 |
+
pbar.set_description(f'Read {key}')
|
83 |
+
|
84 |
+
pool = Pool(n_thread)
|
85 |
+
for path, key in zip(img_path_list, keys):
|
86 |
+
pool.apply_async(read_img_worker, args=(osp.join(data_path, path), key, compress_level), callback=callback)
|
87 |
+
pool.close()
|
88 |
+
pool.join()
|
89 |
+
pbar.close()
|
90 |
+
print(f'Finish reading {len(img_path_list)} images.')
|
91 |
+
|
92 |
+
# create lmdb environment
|
93 |
+
if map_size is None:
|
94 |
+
# obtain data size for one image
|
95 |
+
img = cv2.imread(osp.join(data_path, img_path_list[0]), cv2.IMREAD_UNCHANGED)
|
96 |
+
_, img_byte = cv2.imencode('.png', img, [cv2.IMWRITE_PNG_COMPRESSION, compress_level])
|
97 |
+
data_size_per_img = img_byte.nbytes
|
98 |
+
print('Data size per image is: ', data_size_per_img)
|
99 |
+
data_size = data_size_per_img * len(img_path_list)
|
100 |
+
map_size = data_size * 10
|
101 |
+
|
102 |
+
env = lmdb.open(lmdb_path, map_size=map_size)
|
103 |
+
|
104 |
+
# write data to lmdb
|
105 |
+
pbar = tqdm(total=len(img_path_list), unit='chunk')
|
106 |
+
txn = env.begin(write=True)
|
107 |
+
txt_file = open(osp.join(lmdb_path, 'meta_info.txt'), 'w')
|
108 |
+
for idx, (path, key) in enumerate(zip(img_path_list, keys)):
|
109 |
+
pbar.update(1)
|
110 |
+
pbar.set_description(f'Write {key}')
|
111 |
+
key_byte = key.encode('ascii')
|
112 |
+
if multiprocessing_read:
|
113 |
+
img_byte = dataset[key]
|
114 |
+
h, w, c = shapes[key]
|
115 |
+
else:
|
116 |
+
_, img_byte, img_shape = read_img_worker(osp.join(data_path, path), key, compress_level)
|
117 |
+
h, w, c = img_shape
|
118 |
+
|
119 |
+
txn.put(key_byte, img_byte)
|
120 |
+
# write meta information
|
121 |
+
txt_file.write(f'{key}.png ({h},{w},{c}) {compress_level}\n')
|
122 |
+
if idx % batch == 0:
|
123 |
+
txn.commit()
|
124 |
+
txn = env.begin(write=True)
|
125 |
+
pbar.close()
|
126 |
+
txn.commit()
|
127 |
+
env.close()
|
128 |
+
txt_file.close()
|
129 |
+
print('\nFinish writing lmdb.')
|
130 |
+
|
131 |
+
|
132 |
+
def read_img_worker(path, key, compress_level):
|
133 |
+
"""Read image worker.
|
134 |
+
|
135 |
+
Args:
|
136 |
+
path (str): Image path.
|
137 |
+
key (str): Image key.
|
138 |
+
compress_level (int): Compress level when encoding images.
|
139 |
+
|
140 |
+
Returns:
|
141 |
+
str: Image key.
|
142 |
+
byte: Image byte.
|
143 |
+
tuple[int]: Image shape.
|
144 |
+
"""
|
145 |
+
|
146 |
+
img = cv2.imread(path, cv2.IMREAD_UNCHANGED)
|
147 |
+
if img.ndim == 2:
|
148 |
+
h, w = img.shape
|
149 |
+
c = 1
|
150 |
+
else:
|
151 |
+
h, w, c = img.shape
|
152 |
+
_, img_byte = cv2.imencode('.png', img, [cv2.IMWRITE_PNG_COMPRESSION, compress_level])
|
153 |
+
return (key, img_byte, (h, w, c))
|
154 |
+
|
155 |
+
|
156 |
+
class LmdbMaker():
|
157 |
+
"""LMDB Maker.
|
158 |
+
|
159 |
+
Args:
|
160 |
+
lmdb_path (str): Lmdb save path.
|
161 |
+
map_size (int): Map size for lmdb env. Default: 1024 ** 4, 1TB.
|
162 |
+
batch (int): After processing batch images, lmdb commits.
|
163 |
+
Default: 5000.
|
164 |
+
compress_level (int): Compress level when encoding images. Default: 1.
|
165 |
+
"""
|
166 |
+
|
167 |
+
def __init__(self, lmdb_path, map_size=1024**4, batch=5000, compress_level=1):
|
168 |
+
if not lmdb_path.endswith('.lmdb'):
|
169 |
+
raise ValueError("lmdb_path must end with '.lmdb'.")
|
170 |
+
if osp.exists(lmdb_path):
|
171 |
+
print(f'Folder {lmdb_path} already exists. Exit.')
|
172 |
+
sys.exit(1)
|
173 |
+
|
174 |
+
self.lmdb_path = lmdb_path
|
175 |
+
self.batch = batch
|
176 |
+
self.compress_level = compress_level
|
177 |
+
self.env = lmdb.open(lmdb_path, map_size=map_size)
|
178 |
+
self.txn = self.env.begin(write=True)
|
179 |
+
self.txt_file = open(osp.join(lmdb_path, 'meta_info.txt'), 'w')
|
180 |
+
self.counter = 0
|
181 |
+
|
182 |
+
def put(self, img_byte, key, img_shape):
|
183 |
+
self.counter += 1
|
184 |
+
key_byte = key.encode('ascii')
|
185 |
+
self.txn.put(key_byte, img_byte)
|
186 |
+
# write meta information
|
187 |
+
h, w, c = img_shape
|
188 |
+
self.txt_file.write(f'{key}.png ({h},{w},{c}) {self.compress_level}\n')
|
189 |
+
if self.counter % self.batch == 0:
|
190 |
+
self.txn.commit()
|
191 |
+
self.txn = self.env.begin(write=True)
|
192 |
+
|
193 |
+
def close(self):
|
194 |
+
self.txn.commit()
|
195 |
+
self.env.close()
|
196 |
+
self.txt_file.close()
|
basicsr/utils/logger.py
ADDED
@@ -0,0 +1,209 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import datetime
|
2 |
+
import logging
|
3 |
+
import time
|
4 |
+
|
5 |
+
from .dist_util import get_dist_info, master_only
|
6 |
+
|
7 |
+
initialized_logger = {}
|
8 |
+
|
9 |
+
|
10 |
+
class AvgTimer():
|
11 |
+
|
12 |
+
def __init__(self, window=200):
|
13 |
+
self.window = window # average window
|
14 |
+
self.current_time = 0
|
15 |
+
self.total_time = 0
|
16 |
+
self.count = 0
|
17 |
+
self.avg_time = 0
|
18 |
+
self.start()
|
19 |
+
|
20 |
+
def start(self):
|
21 |
+
self.start_time = time.time()
|
22 |
+
|
23 |
+
def record(self):
|
24 |
+
self.count += 1
|
25 |
+
self.current_time = time.time() - self.start_time
|
26 |
+
self.total_time += self.current_time
|
27 |
+
# calculate average time
|
28 |
+
self.avg_time = self.total_time / self.count
|
29 |
+
# reset
|
30 |
+
if self.count > self.window:
|
31 |
+
self.count = 0
|
32 |
+
self.total_time = 0
|
33 |
+
|
34 |
+
def get_current_time(self):
|
35 |
+
return self.current_time
|
36 |
+
|
37 |
+
def get_avg_time(self):
|
38 |
+
return self.avg_time
|
39 |
+
|
40 |
+
|
41 |
+
class MessageLogger():
|
42 |
+
"""Message logger for printing.
|
43 |
+
|
44 |
+
Args:
|
45 |
+
opt (dict): Config. It contains the following keys:
|
46 |
+
name (str): Exp name.
|
47 |
+
logger (dict): Contains 'print_freq' (str) for logger interval.
|
48 |
+
train (dict): Contains 'total_iter' (int) for total iters.
|
49 |
+
use_tb_logger (bool): Use tensorboard logger.
|
50 |
+
start_iter (int): Start iter. Default: 1.
|
51 |
+
tb_logger (obj:`tb_logger`): Tensorboard logger. Default: None.
|
52 |
+
"""
|
53 |
+
|
54 |
+
def __init__(self, opt, start_iter=1, tb_logger=None):
|
55 |
+
self.exp_name = opt['name']
|
56 |
+
self.interval = opt['logger']['print_freq']
|
57 |
+
self.start_iter = start_iter
|
58 |
+
self.max_iters = opt['train']['total_iter']
|
59 |
+
self.use_tb_logger = opt['logger']['use_tb_logger']
|
60 |
+
self.tb_logger = tb_logger
|
61 |
+
self.start_time = time.time()
|
62 |
+
self.logger = get_root_logger()
|
63 |
+
|
64 |
+
def reset_start_time(self):
|
65 |
+
self.start_time = time.time()
|
66 |
+
|
67 |
+
@master_only
|
68 |
+
def __call__(self, log_vars):
|
69 |
+
"""Format logging message.
|
70 |
+
|
71 |
+
Args:
|
72 |
+
log_vars (dict): It contains the following keys:
|
73 |
+
epoch (int): Epoch number.
|
74 |
+
iter (int): Current iter.
|
75 |
+
lrs (list): List for learning rates.
|
76 |
+
|
77 |
+
time (float): Iter time.
|
78 |
+
data_time (float): Data time for each iter.
|
79 |
+
"""
|
80 |
+
# epoch, iter, learning rates
|
81 |
+
epoch = log_vars.pop('epoch')
|
82 |
+
current_iter = log_vars.pop('iter')
|
83 |
+
lrs = log_vars.pop('lrs')
|
84 |
+
|
85 |
+
message = (f'[{self.exp_name[:5]}..][epoch:{epoch:3d}, iter:{current_iter:8,d}, lr:(')
|
86 |
+
for v in lrs:
|
87 |
+
message += f'{v:.3e},'
|
88 |
+
message += ')] '
|
89 |
+
|
90 |
+
# time and estimated time
|
91 |
+
if 'time' in log_vars.keys():
|
92 |
+
iter_time = log_vars.pop('time')
|
93 |
+
data_time = log_vars.pop('data_time')
|
94 |
+
|
95 |
+
total_time = time.time() - self.start_time
|
96 |
+
time_sec_avg = total_time / (current_iter - self.start_iter + 1)
|
97 |
+
eta_sec = time_sec_avg * (self.max_iters - current_iter - 1)
|
98 |
+
eta_str = str(datetime.timedelta(seconds=int(eta_sec)))
|
99 |
+
message += f'[eta: {eta_str}, '
|
100 |
+
message += f'time (data): {iter_time:.3f} ({data_time:.3f})] '
|
101 |
+
|
102 |
+
# other items, especially losses
|
103 |
+
for k, v in log_vars.items():
|
104 |
+
message += f'{k}: {v:.4e} '
|
105 |
+
# tensorboard logger
|
106 |
+
if self.use_tb_logger and 'debug' not in self.exp_name:
|
107 |
+
if k.startswith('l_'):
|
108 |
+
self.tb_logger.add_scalar(f'losses/{k}', v, current_iter)
|
109 |
+
else:
|
110 |
+
self.tb_logger.add_scalar(k, v, current_iter)
|
111 |
+
self.logger.info(message)
|
112 |
+
|
113 |
+
|
114 |
+
@master_only
|
115 |
+
def init_tb_logger(log_dir):
|
116 |
+
from torch.utils.tensorboard import SummaryWriter
|
117 |
+
tb_logger = SummaryWriter(log_dir=log_dir)
|
118 |
+
return tb_logger
|
119 |
+
|
120 |
+
|
121 |
+
@master_only
|
122 |
+
def init_wandb_logger(opt):
|
123 |
+
"""We now only use wandb to sync tensorboard log."""
|
124 |
+
import wandb
|
125 |
+
logger = get_root_logger()
|
126 |
+
|
127 |
+
project = opt['logger']['wandb']['project']
|
128 |
+
resume_id = opt['logger']['wandb'].get('resume_id')
|
129 |
+
if resume_id:
|
130 |
+
wandb_id = resume_id
|
131 |
+
resume = 'allow'
|
132 |
+
logger.warning(f'Resume wandb logger with id={wandb_id}.')
|
133 |
+
else:
|
134 |
+
wandb_id = wandb.util.generate_id()
|
135 |
+
resume = 'never'
|
136 |
+
|
137 |
+
wandb.init(id=wandb_id, resume=resume, name=opt['name'], config=opt, project=project, sync_tensorboard=True)
|
138 |
+
|
139 |
+
logger.info(f'Use wandb logger with id={wandb_id}; project={project}.')
|
140 |
+
|
141 |
+
|
142 |
+
def get_root_logger(logger_name='basicsr', log_level=logging.INFO, log_file=None):
|
143 |
+
"""Get the root logger.
|
144 |
+
|
145 |
+
The logger will be initialized if it has not been initialized. By default a
|
146 |
+
StreamHandler will be added. If `log_file` is specified, a FileHandler will
|
147 |
+
also be added.
|
148 |
+
|
149 |
+
Args:
|
150 |
+
logger_name (str): root logger name. Default: 'basicsr'.
|
151 |
+
log_file (str | None): The log filename. If specified, a FileHandler
|
152 |
+
will be added to the root logger.
|
153 |
+
log_level (int): The root logger level. Note that only the process of
|
154 |
+
rank 0 is affected, while other processes will set the level to
|
155 |
+
"Error" and be silent most of the time.
|
156 |
+
|
157 |
+
Returns:
|
158 |
+
logging.Logger: The root logger.
|
159 |
+
"""
|
160 |
+
logger = logging.getLogger(logger_name)
|
161 |
+
# if the logger has been initialized, just return it
|
162 |
+
if logger_name in initialized_logger:
|
163 |
+
return logger
|
164 |
+
|
165 |
+
format_str = '%(asctime)s %(levelname)s: %(message)s'
|
166 |
+
stream_handler = logging.StreamHandler()
|
167 |
+
stream_handler.setFormatter(logging.Formatter(format_str))
|
168 |
+
logger.addHandler(stream_handler)
|
169 |
+
logger.propagate = False
|
170 |
+
rank, _ = get_dist_info()
|
171 |
+
if rank != 0:
|
172 |
+
logger.setLevel('ERROR')
|
173 |
+
elif log_file is not None:
|
174 |
+
logger.setLevel(log_level)
|
175 |
+
# add file handler
|
176 |
+
file_handler = logging.FileHandler(log_file, 'w')
|
177 |
+
file_handler.setFormatter(logging.Formatter(format_str))
|
178 |
+
file_handler.setLevel(log_level)
|
179 |
+
logger.addHandler(file_handler)
|
180 |
+
initialized_logger[logger_name] = True
|
181 |
+
return logger
|
182 |
+
|
183 |
+
|
184 |
+
def get_env_info():
|
185 |
+
"""Get environment information.
|
186 |
+
|
187 |
+
Currently, only log the software version.
|
188 |
+
"""
|
189 |
+
import torch
|
190 |
+
import torchvision
|
191 |
+
|
192 |
+
from basicsr.version import __version__
|
193 |
+
msg = r"""
|
194 |
+
____ _ _____ ____
|
195 |
+
/ __ ) ____ _ _____ (_)_____/ ___/ / __ \
|
196 |
+
/ __ |/ __ `// ___// // ___/\__ \ / /_/ /
|
197 |
+
/ /_/ // /_/ /(__ )/ // /__ ___/ // _, _/
|
198 |
+
/_____/ \__,_//____//_/ \___//____//_/ |_|
|
199 |
+
______ __ __ __ __
|
200 |
+
/ ____/____ ____ ____/ / / / __ __ _____ / /__ / /
|
201 |
+
/ / __ / __ \ / __ \ / __ / / / / / / // ___// //_/ / /
|
202 |
+
/ /_/ // /_/ // /_/ // /_/ / / /___/ /_/ // /__ / /< /_/
|
203 |
+
\____/ \____/ \____/ \____/ /_____/\____/ \___//_/|_| (_)
|
204 |
+
"""
|
205 |
+
msg += ('\nVersion Information: '
|
206 |
+
f'\n\tBasicSR: {__version__}'
|
207 |
+
f'\n\tPyTorch: {torch.__version__}'
|
208 |
+
f'\n\tTorchVision: {torchvision.__version__}')
|
209 |
+
return msg
|
basicsr/utils/matlab_functions.py
ADDED
@@ -0,0 +1,359 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
import numpy as np
|
3 |
+
import torch
|
4 |
+
|
5 |
+
|
6 |
+
def cubic(x):
|
7 |
+
"""cubic function used for calculate_weights_indices."""
|
8 |
+
absx = torch.abs(x)
|
9 |
+
absx2 = absx**2
|
10 |
+
absx3 = absx**3
|
11 |
+
return (1.5 * absx3 - 2.5 * absx2 + 1) * (
|
12 |
+
(absx <= 1).type_as(absx)) + (-0.5 * absx3 + 2.5 * absx2 - 4 * absx + 2) * (((absx > 1) *
|
13 |
+
(absx <= 2)).type_as(absx))
|
14 |
+
|
15 |
+
|
16 |
+
def calculate_weights_indices(in_length, out_length, scale, kernel, kernel_width, antialiasing):
|
17 |
+
"""Calculate weights and indices, used for imresize function.
|
18 |
+
|
19 |
+
Args:
|
20 |
+
in_length (int): Input length.
|
21 |
+
out_length (int): Output length.
|
22 |
+
scale (float): Scale factor.
|
23 |
+
kernel_width (int): Kernel width.
|
24 |
+
antialisaing (bool): Whether to apply anti-aliasing when downsampling.
|
25 |
+
"""
|
26 |
+
|
27 |
+
if (scale < 1) and antialiasing:
|
28 |
+
# Use a modified kernel (larger kernel width) to simultaneously
|
29 |
+
# interpolate and antialias
|
30 |
+
kernel_width = kernel_width / scale
|
31 |
+
|
32 |
+
# Output-space coordinates
|
33 |
+
x = torch.linspace(1, out_length, out_length)
|
34 |
+
|
35 |
+
# Input-space coordinates. Calculate the inverse mapping such that 0.5
|
36 |
+
# in output space maps to 0.5 in input space, and 0.5 + scale in output
|
37 |
+
# space maps to 1.5 in input space.
|
38 |
+
u = x / scale + 0.5 * (1 - 1 / scale)
|
39 |
+
|
40 |
+
# What is the left-most pixel that can be involved in the computation?
|
41 |
+
left = torch.floor(u - kernel_width / 2)
|
42 |
+
|
43 |
+
# What is the maximum number of pixels that can be involved in the
|
44 |
+
# computation? Note: it's OK to use an extra pixel here; if the
|
45 |
+
# corresponding weights are all zero, it will be eliminated at the end
|
46 |
+
# of this function.
|
47 |
+
p = math.ceil(kernel_width) + 2
|
48 |
+
|
49 |
+
# The indices of the input pixels involved in computing the k-th output
|
50 |
+
# pixel are in row k of the indices matrix.
|
51 |
+
indices = left.view(out_length, 1).expand(out_length, p) + torch.linspace(0, p - 1, p).view(1, p).expand(
|
52 |
+
out_length, p)
|
53 |
+
|
54 |
+
# The weights used to compute the k-th output pixel are in row k of the
|
55 |
+
# weights matrix.
|
56 |
+
distance_to_center = u.view(out_length, 1).expand(out_length, p) - indices
|
57 |
+
|
58 |
+
# apply cubic kernel
|
59 |
+
if (scale < 1) and antialiasing:
|
60 |
+
weights = scale * cubic(distance_to_center * scale)
|
61 |
+
else:
|
62 |
+
weights = cubic(distance_to_center)
|
63 |
+
|
64 |
+
# Normalize the weights matrix so that each row sums to 1.
|
65 |
+
weights_sum = torch.sum(weights, 1).view(out_length, 1)
|
66 |
+
weights = weights / weights_sum.expand(out_length, p)
|
67 |
+
|
68 |
+
# If a column in weights is all zero, get rid of it. only consider the
|
69 |
+
# first and last column.
|
70 |
+
weights_zero_tmp = torch.sum((weights == 0), 0)
|
71 |
+
if not math.isclose(weights_zero_tmp[0], 0, rel_tol=1e-6):
|
72 |
+
indices = indices.narrow(1, 1, p - 2)
|
73 |
+
weights = weights.narrow(1, 1, p - 2)
|
74 |
+
if not math.isclose(weights_zero_tmp[-1], 0, rel_tol=1e-6):
|
75 |
+
indices = indices.narrow(1, 0, p - 2)
|
76 |
+
weights = weights.narrow(1, 0, p - 2)
|
77 |
+
weights = weights.contiguous()
|
78 |
+
indices = indices.contiguous()
|
79 |
+
sym_len_s = -indices.min() + 1
|
80 |
+
sym_len_e = indices.max() - in_length
|
81 |
+
indices = indices + sym_len_s - 1
|
82 |
+
return weights, indices, int(sym_len_s), int(sym_len_e)
|
83 |
+
|
84 |
+
|
85 |
+
@torch.no_grad()
|
86 |
+
def imresize(img, scale, antialiasing=True):
|
87 |
+
"""imresize function same as MATLAB.
|
88 |
+
|
89 |
+
It now only supports bicubic.
|
90 |
+
The same scale applies for both height and width.
|
91 |
+
|
92 |
+
Args:
|
93 |
+
img (Tensor | Numpy array):
|
94 |
+
Tensor: Input image with shape (c, h, w), [0, 1] range.
|
95 |
+
Numpy: Input image with shape (h, w, c), [0, 1] range.
|
96 |
+
scale (float): Scale factor. The same scale applies for both height
|
97 |
+
and width.
|
98 |
+
antialisaing (bool): Whether to apply anti-aliasing when downsampling.
|
99 |
+
Default: True.
|
100 |
+
|
101 |
+
Returns:
|
102 |
+
Tensor: Output image with shape (c, h, w), [0, 1] range, w/o round.
|
103 |
+
"""
|
104 |
+
squeeze_flag = False
|
105 |
+
if type(img).__module__ == np.__name__: # numpy type
|
106 |
+
numpy_type = True
|
107 |
+
if img.ndim == 2:
|
108 |
+
img = img[:, :, None]
|
109 |
+
squeeze_flag = True
|
110 |
+
img = torch.from_numpy(img.transpose(2, 0, 1)).float()
|
111 |
+
else:
|
112 |
+
numpy_type = False
|
113 |
+
if img.ndim == 2:
|
114 |
+
img = img.unsqueeze(0)
|
115 |
+
squeeze_flag = True
|
116 |
+
|
117 |
+
in_c, in_h, in_w = img.size()
|
118 |
+
out_h, out_w = math.ceil(in_h * scale), math.ceil(in_w * scale)
|
119 |
+
kernel_width = 4
|
120 |
+
kernel = 'cubic'
|
121 |
+
|
122 |
+
# get weights and indices
|
123 |
+
weights_h, indices_h, sym_len_hs, sym_len_he = calculate_weights_indices(in_h, out_h, scale, kernel, kernel_width,
|
124 |
+
antialiasing)
|
125 |
+
weights_w, indices_w, sym_len_ws, sym_len_we = calculate_weights_indices(in_w, out_w, scale, kernel, kernel_width,
|
126 |
+
antialiasing)
|
127 |
+
# process H dimension
|
128 |
+
# symmetric copying
|
129 |
+
img_aug = torch.FloatTensor(in_c, in_h + sym_len_hs + sym_len_he, in_w)
|
130 |
+
img_aug.narrow(1, sym_len_hs, in_h).copy_(img)
|
131 |
+
|
132 |
+
sym_patch = img[:, :sym_len_hs, :]
|
133 |
+
inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
|
134 |
+
sym_patch_inv = sym_patch.index_select(1, inv_idx)
|
135 |
+
img_aug.narrow(1, 0, sym_len_hs).copy_(sym_patch_inv)
|
136 |
+
|
137 |
+
sym_patch = img[:, -sym_len_he:, :]
|
138 |
+
inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
|
139 |
+
sym_patch_inv = sym_patch.index_select(1, inv_idx)
|
140 |
+
img_aug.narrow(1, sym_len_hs + in_h, sym_len_he).copy_(sym_patch_inv)
|
141 |
+
|
142 |
+
out_1 = torch.FloatTensor(in_c, out_h, in_w)
|
143 |
+
kernel_width = weights_h.size(1)
|
144 |
+
for i in range(out_h):
|
145 |
+
idx = int(indices_h[i][0])
|
146 |
+
for j in range(in_c):
|
147 |
+
out_1[j, i, :] = img_aug[j, idx:idx + kernel_width, :].transpose(0, 1).mv(weights_h[i])
|
148 |
+
|
149 |
+
# process W dimension
|
150 |
+
# symmetric copying
|
151 |
+
out_1_aug = torch.FloatTensor(in_c, out_h, in_w + sym_len_ws + sym_len_we)
|
152 |
+
out_1_aug.narrow(2, sym_len_ws, in_w).copy_(out_1)
|
153 |
+
|
154 |
+
sym_patch = out_1[:, :, :sym_len_ws]
|
155 |
+
inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long()
|
156 |
+
sym_patch_inv = sym_patch.index_select(2, inv_idx)
|
157 |
+
out_1_aug.narrow(2, 0, sym_len_ws).copy_(sym_patch_inv)
|
158 |
+
|
159 |
+
sym_patch = out_1[:, :, -sym_len_we:]
|
160 |
+
inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long()
|
161 |
+
sym_patch_inv = sym_patch.index_select(2, inv_idx)
|
162 |
+
out_1_aug.narrow(2, sym_len_ws + in_w, sym_len_we).copy_(sym_patch_inv)
|
163 |
+
|
164 |
+
out_2 = torch.FloatTensor(in_c, out_h, out_w)
|
165 |
+
kernel_width = weights_w.size(1)
|
166 |
+
for i in range(out_w):
|
167 |
+
idx = int(indices_w[i][0])
|
168 |
+
for j in range(in_c):
|
169 |
+
out_2[j, :, i] = out_1_aug[j, :, idx:idx + kernel_width].mv(weights_w[i])
|
170 |
+
|
171 |
+
if squeeze_flag:
|
172 |
+
out_2 = out_2.squeeze(0)
|
173 |
+
if numpy_type:
|
174 |
+
out_2 = out_2.numpy()
|
175 |
+
if not squeeze_flag:
|
176 |
+
out_2 = out_2.transpose(1, 2, 0)
|
177 |
+
|
178 |
+
return out_2
|
179 |
+
|
180 |
+
|
181 |
+
def rgb2ycbcr(img, y_only=False):
|
182 |
+
"""Convert a RGB image to YCbCr image.
|
183 |
+
|
184 |
+
This function produces the same results as Matlab's `rgb2ycbcr` function.
|
185 |
+
It implements the ITU-R BT.601 conversion for standard-definition
|
186 |
+
television. See more details in
|
187 |
+
https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion.
|
188 |
+
|
189 |
+
It differs from a similar function in cv2.cvtColor: `RGB <-> YCrCb`.
|
190 |
+
In OpenCV, it implements a JPEG conversion. See more details in
|
191 |
+
https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion.
|
192 |
+
|
193 |
+
Args:
|
194 |
+
img (ndarray): The input image. It accepts:
|
195 |
+
1. np.uint8 type with range [0, 255];
|
196 |
+
2. np.float32 type with range [0, 1].
|
197 |
+
y_only (bool): Whether to only return Y channel. Default: False.
|
198 |
+
|
199 |
+
Returns:
|
200 |
+
ndarray: The converted YCbCr image. The output image has the same type
|
201 |
+
and range as input image.
|
202 |
+
"""
|
203 |
+
img_type = img.dtype
|
204 |
+
img = _convert_input_type_range(img)
|
205 |
+
if y_only:
|
206 |
+
out_img = np.dot(img, [65.481, 128.553, 24.966]) + 16.0
|
207 |
+
else:
|
208 |
+
out_img = np.matmul(
|
209 |
+
img, [[65.481, -37.797, 112.0], [128.553, -74.203, -93.786], [24.966, 112.0, -18.214]]) + [16, 128, 128]
|
210 |
+
out_img = _convert_output_type_range(out_img, img_type)
|
211 |
+
return out_img
|
212 |
+
|
213 |
+
|
214 |
+
def bgr2ycbcr(img, y_only=False):
|
215 |
+
"""Convert a BGR image to YCbCr image.
|
216 |
+
|
217 |
+
The bgr version of rgb2ycbcr.
|
218 |
+
It implements the ITU-R BT.601 conversion for standard-definition
|
219 |
+
television. See more details in
|
220 |
+
https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion.
|
221 |
+
|
222 |
+
It differs from a similar function in cv2.cvtColor: `BGR <-> YCrCb`.
|
223 |
+
In OpenCV, it implements a JPEG conversion. See more details in
|
224 |
+
https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion.
|
225 |
+
|
226 |
+
Args:
|
227 |
+
img (ndarray): The input image. It accepts:
|
228 |
+
1. np.uint8 type with range [0, 255];
|
229 |
+
2. np.float32 type with range [0, 1].
|
230 |
+
y_only (bool): Whether to only return Y channel. Default: False.
|
231 |
+
|
232 |
+
Returns:
|
233 |
+
ndarray: The converted YCbCr image. The output image has the same type
|
234 |
+
and range as input image.
|
235 |
+
"""
|
236 |
+
img_type = img.dtype
|
237 |
+
img = _convert_input_type_range(img)
|
238 |
+
if y_only:
|
239 |
+
out_img = np.dot(img, [24.966, 128.553, 65.481]) + 16.0
|
240 |
+
else:
|
241 |
+
out_img = np.matmul(
|
242 |
+
img, [[24.966, 112.0, -18.214], [128.553, -74.203, -93.786], [65.481, -37.797, 112.0]]) + [16, 128, 128]
|
243 |
+
out_img = _convert_output_type_range(out_img, img_type)
|
244 |
+
return out_img
|
245 |
+
|
246 |
+
|
247 |
+
def ycbcr2rgb(img):
|
248 |
+
"""Convert a YCbCr image to RGB image.
|
249 |
+
|
250 |
+
This function produces the same results as Matlab's ycbcr2rgb function.
|
251 |
+
It implements the ITU-R BT.601 conversion for standard-definition
|
252 |
+
television. See more details in
|
253 |
+
https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion.
|
254 |
+
|
255 |
+
It differs from a similar function in cv2.cvtColor: `YCrCb <-> RGB`.
|
256 |
+
In OpenCV, it implements a JPEG conversion. See more details in
|
257 |
+
https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion.
|
258 |
+
|
259 |
+
Args:
|
260 |
+
img (ndarray): The input image. It accepts:
|
261 |
+
1. np.uint8 type with range [0, 255];
|
262 |
+
2. np.float32 type with range [0, 1].
|
263 |
+
|
264 |
+
Returns:
|
265 |
+
ndarray: The converted RGB image. The output image has the same type
|
266 |
+
and range as input image.
|
267 |
+
"""
|
268 |
+
img_type = img.dtype
|
269 |
+
img = _convert_input_type_range(img) * 255
|
270 |
+
out_img = np.matmul(img, [[0.00456621, 0.00456621, 0.00456621], [0, -0.00153632, 0.00791071],
|
271 |
+
[0.00625893, -0.00318811, 0]]) * 255.0 + [-222.921, 135.576, -276.836] # noqa: E126
|
272 |
+
out_img = _convert_output_type_range(out_img, img_type)
|
273 |
+
return out_img
|
274 |
+
|
275 |
+
|
276 |
+
def ycbcr2bgr(img):
|
277 |
+
"""Convert a YCbCr image to BGR image.
|
278 |
+
|
279 |
+
The bgr version of ycbcr2rgb.
|
280 |
+
It implements the ITU-R BT.601 conversion for standard-definition
|
281 |
+
television. See more details in
|
282 |
+
https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion.
|
283 |
+
|
284 |
+
It differs from a similar function in cv2.cvtColor: `YCrCb <-> BGR`.
|
285 |
+
In OpenCV, it implements a JPEG conversion. See more details in
|
286 |
+
https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion.
|
287 |
+
|
288 |
+
Args:
|
289 |
+
img (ndarray): The input image. It accepts:
|
290 |
+
1. np.uint8 type with range [0, 255];
|
291 |
+
2. np.float32 type with range [0, 1].
|
292 |
+
|
293 |
+
Returns:
|
294 |
+
ndarray: The converted BGR image. The output image has the same type
|
295 |
+
and range as input image.
|
296 |
+
"""
|
297 |
+
img_type = img.dtype
|
298 |
+
img = _convert_input_type_range(img) * 255
|
299 |
+
out_img = np.matmul(img, [[0.00456621, 0.00456621, 0.00456621], [0.00791071, -0.00153632, 0],
|
300 |
+
[0, -0.00318811, 0.00625893]]) * 255.0 + [-276.836, 135.576, -222.921] # noqa: E126
|
301 |
+
out_img = _convert_output_type_range(out_img, img_type)
|
302 |
+
return out_img
|
303 |
+
|
304 |
+
|
305 |
+
def _convert_input_type_range(img):
|
306 |
+
"""Convert the type and range of the input image.
|
307 |
+
|
308 |
+
It converts the input image to np.float32 type and range of [0, 1].
|
309 |
+
It is mainly used for pre-processing the input image in colorspace
|
310 |
+
conversion functions such as rgb2ycbcr and ycbcr2rgb.
|
311 |
+
|
312 |
+
Args:
|
313 |
+
img (ndarray): The input image. It accepts:
|
314 |
+
1. np.uint8 type with range [0, 255];
|
315 |
+
2. np.float32 type with range [0, 1].
|
316 |
+
|
317 |
+
Returns:
|
318 |
+
(ndarray): The converted image with type of np.float32 and range of
|
319 |
+
[0, 1].
|
320 |
+
"""
|
321 |
+
img_type = img.dtype
|
322 |
+
img = img.astype(np.float32)
|
323 |
+
if img_type == np.float32:
|
324 |
+
pass
|
325 |
+
elif img_type == np.uint8:
|
326 |
+
img /= 255.
|
327 |
+
else:
|
328 |
+
raise TypeError(f'The img type should be np.float32 or np.uint8, but got {img_type}')
|
329 |
+
return img
|
330 |
+
|
331 |
+
|
332 |
+
def _convert_output_type_range(img, dst_type):
|
333 |
+
"""Convert the type and range of the image according to dst_type.
|
334 |
+
|
335 |
+
It converts the image to desired type and range. If `dst_type` is np.uint8,
|
336 |
+
images will be converted to np.uint8 type with range [0, 255]. If
|
337 |
+
`dst_type` is np.float32, it converts the image to np.float32 type with
|
338 |
+
range [0, 1].
|
339 |
+
It is mainly used for post-processing images in colorspace conversion
|
340 |
+
functions such as rgb2ycbcr and ycbcr2rgb.
|
341 |
+
|
342 |
+
Args:
|
343 |
+
img (ndarray): The image to be converted with np.float32 type and
|
344 |
+
range [0, 255].
|
345 |
+
dst_type (np.uint8 | np.float32): If dst_type is np.uint8, it
|
346 |
+
converts the image to np.uint8 type with range [0, 255]. If
|
347 |
+
dst_type is np.float32, it converts the image to np.float32 type
|
348 |
+
with range [0, 1].
|
349 |
+
|
350 |
+
Returns:
|
351 |
+
(ndarray): The converted image with desired type and range.
|
352 |
+
"""
|
353 |
+
if dst_type not in (np.uint8, np.float32):
|
354 |
+
raise TypeError(f'The dst_type should be np.float32 or np.uint8, but got {dst_type}')
|
355 |
+
if dst_type == np.uint8:
|
356 |
+
img = img.round()
|
357 |
+
else:
|
358 |
+
img /= 255.
|
359 |
+
return img.astype(dst_type)
|