Sulio committed on
Commit
00e6746
·
verified ·
1 Parent(s): 624281e

Upload folder using huggingface_hub

This view is limited to 50 files because it contains too many changes. See raw diff
Files changed (50)
  1. .gitignore +139 -0
  2. LICENSE +201 -0
  3. MODEL_ZOO.md +13 -0
  4. README.md +201 -7
  5. VERSION +1 -0
  6. basicsr/__init__.py +15 -0
  7. basicsr/archs/__init__.py +25 -0
  8. basicsr/archs/ddcolor_arch.py +385 -0
  9. basicsr/archs/ddcolor_arch_utils/__int__.py +0 -0
  10. basicsr/archs/ddcolor_arch_utils/convnext.py +155 -0
  11. basicsr/archs/ddcolor_arch_utils/position_encoding.py +52 -0
  12. basicsr/archs/ddcolor_arch_utils/transformer.py +368 -0
  13. basicsr/archs/ddcolor_arch_utils/transformer_utils.py +192 -0
  14. basicsr/archs/ddcolor_arch_utils/unet.py +208 -0
  15. basicsr/archs/ddcolor_arch_utils/util.py +63 -0
  16. basicsr/archs/discriminator_arch.py +28 -0
  17. basicsr/archs/vgg_arch.py +165 -0
  18. basicsr/data/__init__.py +101 -0
  19. basicsr/data/data_sampler.py +48 -0
  20. basicsr/data/data_util.py +313 -0
  21. basicsr/data/fmix.py +206 -0
  22. basicsr/data/lab_dataset.py +159 -0
  23. basicsr/data/prefetch_dataloader.py +125 -0
  24. basicsr/data/transforms.py +192 -0
  25. basicsr/losses/__init__.py +26 -0
  26. basicsr/losses/loss_util.py +95 -0
  27. basicsr/losses/losses.py +551 -0
  28. basicsr/metrics/__init__.py +20 -0
  29. basicsr/metrics/colorfulness.py +17 -0
  30. basicsr/metrics/custom_fid.py +260 -0
  31. basicsr/metrics/metric_util.py +45 -0
  32. basicsr/metrics/psnr_ssim.py +128 -0
  33. basicsr/models/__init__.py +30 -0
  34. basicsr/models/base_model.py +382 -0
  35. basicsr/models/color_model.py +369 -0
  36. basicsr/models/lr_scheduler.py +96 -0
  37. basicsr/train.py +224 -0
  38. basicsr/utils/__init__.py +37 -0
  39. basicsr/utils/color_enhance.py +9 -0
  40. basicsr/utils/diffjpeg.py +515 -0
  41. basicsr/utils/dist_util.py +82 -0
  42. basicsr/utils/download_util.py +64 -0
  43. basicsr/utils/face_util.py +192 -0
  44. basicsr/utils/file_client.py +167 -0
  45. basicsr/utils/flow_util.py +170 -0
  46. basicsr/utils/img_process_util.py +83 -0
  47. basicsr/utils/img_util.py +227 -0
  48. basicsr/utils/lmdb_util.py +196 -0
  49. basicsr/utils/logger.py +209 -0
  50. basicsr/utils/matlab_functions.py +359 -0
.gitignore ADDED
@@ -0,0 +1,139 @@
1
+ # ignored folders
2
+ datasets/*
3
+ experiments/*
4
+ results/*
5
+ tb_logger/*
6
+ wandb/*
7
+ tmp/*
8
+
9
+ *.DS_Store
10
+ .idea
11
+ .vscode
12
+ .github
13
+
14
+ .onnx
15
+
16
+ # ignored files
17
+ version.py
18
+
19
+ # ignored files with suffix
20
+ *.html
21
+ *.png
22
+ *.jpeg
23
+ *.jpg
24
+ *.gif
25
+ *.pth
26
+ *.zip
27
+ *.npy
28
+
29
+ # template
30
+
31
+ # Byte-compiled / optimized / DLL files
32
+ __pycache__/
33
+ *.py[cod]
34
+ *$py.class
35
+
36
+ # C extensions
37
+ *.so
38
+
39
+ # Distribution / packaging
40
+ .Python
41
+ build/
42
+ develop-eggs/
43
+ dist/
44
+ downloads/
45
+ eggs/
46
+ .eggs/
47
+ lib/
48
+ lib64/
49
+ parts/
50
+ sdist/
51
+ var/
52
+ wheels/
53
+ *.egg-info/
54
+ .installed.cfg
55
+ *.egg
56
+ MANIFEST
57
+
58
+ # PyInstaller
59
+ # Usually these files are written by a python script from a template
60
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
61
+ *.manifest
62
+ *.spec
63
+
64
+ # Installer logs
65
+ pip-log.txt
66
+ pip-delete-this-directory.txt
67
+
68
+ # Unit test / coverage reports
69
+ htmlcov/
70
+ .tox/
71
+ .coverage
72
+ .coverage.*
73
+ .cache
74
+ nosetests.xml
75
+ coverage.xml
76
+ *.cover
77
+ .hypothesis/
78
+ .pytest_cache/
79
+
80
+ # Translations
81
+ *.mo
82
+ *.pot
83
+
84
+ # Django stuff:
85
+ *.log
86
+ local_settings.py
87
+ db.sqlite3
88
+
89
+ # Flask stuff:
90
+ instance/
91
+ .webassets-cache
92
+
93
+ # Scrapy stuff:
94
+ .scrapy
95
+
96
+ # Sphinx documentation
97
+ docs/_build/
98
+
99
+ # PyBuilder
100
+ target/
101
+
102
+ # Jupyter Notebook
103
+ .ipynb_checkpoints
104
+
105
+ # pyenv
106
+ .python-version
107
+
108
+ # celery beat schedule file
109
+ celerybeat-schedule
110
+
111
+ # SageMath parsed files
112
+ *.sage.py
113
+
114
+ # Environments
115
+ .env
116
+ .venv
117
+ env/
118
+ venv/
119
+ ENV/
120
+ env.bak/
121
+ venv.bak/
122
+
123
+ # Spyder project settings
124
+ .spyderproject
125
+ .spyproject
126
+
127
+ # Rope project settings
128
+ .ropeproject
129
+
130
+ # mkdocs documentation
131
+ /site
132
+
133
+ # mypy
134
+ .mypy_cache/
135
+
136
+ # meta file
137
+ data_list/*.txt
138
+
139
+ weights/
LICENSE ADDED
@@ -0,0 +1,201 @@
1
+ Apache License
2
+ Version 2.0, January 2004
3
+ http://www.apache.org/licenses/
4
+
5
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
+
7
+ 1. Definitions.
8
+
9
+ "License" shall mean the terms and conditions for use, reproduction,
10
+ and distribution as defined by Sections 1 through 9 of this document.
11
+
12
+ "Licensor" shall mean the copyright owner or entity authorized by
13
+ the copyright owner that is granting the License.
14
+
15
+ "Legal Entity" shall mean the union of the acting entity and all
16
+ other entities that control, are controlled by, or are under common
17
+ control with that entity. For the purposes of this definition,
18
+ "control" means (i) the power, direct or indirect, to cause the
19
+ direction or management of such entity, whether by contract or
20
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
+ outstanding shares, or (iii) beneficial ownership of such entity.
22
+
23
+ "You" (or "Your") shall mean an individual or Legal Entity
24
+ exercising permissions granted by this License.
25
+
26
+ "Source" form shall mean the preferred form for making modifications,
27
+ including but not limited to software source code, documentation
28
+ source, and configuration files.
29
+
30
+ "Object" form shall mean any form resulting from mechanical
31
+ transformation or translation of a Source form, including but
32
+ not limited to compiled object code, generated documentation,
33
+ and conversions to other media types.
34
+
35
+ "Work" shall mean the work of authorship, whether in Source or
36
+ Object form, made available under the License, as indicated by a
37
+ copyright notice that is included in or attached to the work
38
+ (an example is provided in the Appendix below).
39
+
40
+ "Derivative Works" shall mean any work, whether in Source or Object
41
+ form, that is based on (or derived from) the Work and for which the
42
+ editorial revisions, annotations, elaborations, or other modifications
43
+ represent, as a whole, an original work of authorship. For the purposes
44
+ of this License, Derivative Works shall not include works that remain
45
+ separable from, or merely link (or bind by name) to the interfaces of,
46
+ the Work and Derivative Works thereof.
47
+
48
+ "Contribution" shall mean any work of authorship, including
49
+ the original version of the Work and any modifications or additions
50
+ to that Work or Derivative Works thereof, that is intentionally
51
+ submitted to Licensor for inclusion in the Work by the copyright owner
52
+ or by an individual or Legal Entity authorized to submit on behalf of
53
+ the copyright owner. For the purposes of this definition, "submitted"
54
+ means any form of electronic, verbal, or written communication sent
55
+ to the Licensor or its representatives, including but not limited to
56
+ communication on electronic mailing lists, source code control systems,
57
+ and issue tracking systems that are managed by, or on behalf of, the
58
+ Licensor for the purpose of discussing and improving the Work, but
59
+ excluding communication that is conspicuously marked or otherwise
60
+ designated in writing by the copyright owner as "Not a Contribution."
61
+
62
+ "Contributor" shall mean Licensor and any individual or Legal Entity
63
+ on behalf of whom a Contribution has been received by Licensor and
64
+ subsequently incorporated within the Work.
65
+
66
+ 2. Grant of Copyright License. Subject to the terms and conditions of
67
+ this License, each Contributor hereby grants to You a perpetual,
68
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
+ copyright license to reproduce, prepare Derivative Works of,
70
+ publicly display, publicly perform, sublicense, and distribute the
71
+ Work and such Derivative Works in Source or Object form.
72
+
73
+ 3. Grant of Patent License. Subject to the terms and conditions of
74
+ this License, each Contributor hereby grants to You a perpetual,
75
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
+ (except as stated in this section) patent license to make, have made,
77
+ use, offer to sell, sell, import, and otherwise transfer the Work,
78
+ where such license applies only to those patent claims licensable
79
+ by such Contributor that are necessarily infringed by their
80
+ Contribution(s) alone or by combination of their Contribution(s)
81
+ with the Work to which such Contribution(s) was submitted. If You
82
+ institute patent litigation against any entity (including a
83
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
84
+ or a Contribution incorporated within the Work constitutes direct
85
+ or contributory patent infringement, then any patent licenses
86
+ granted to You under this License for that Work shall terminate
87
+ as of the date such litigation is filed.
88
+
89
+ 4. Redistribution. You may reproduce and distribute copies of the
90
+ Work or Derivative Works thereof in any medium, with or without
91
+ modifications, and in Source or Object form, provided that You
92
+ meet the following conditions:
93
+
94
+ (a) You must give any other recipients of the Work or
95
+ Derivative Works a copy of this License; and
96
+
97
+ (b) You must cause any modified files to carry prominent notices
98
+ stating that You changed the files; and
99
+
100
+ (c) You must retain, in the Source form of any Derivative Works
101
+ that You distribute, all copyright, patent, trademark, and
102
+ attribution notices from the Source form of the Work,
103
+ excluding those notices that do not pertain to any part of
104
+ the Derivative Works; and
105
+
106
+ (d) If the Work includes a "NOTICE" text file as part of its
107
+ distribution, then any Derivative Works that You distribute must
108
+ include a readable copy of the attribution notices contained
109
+ within such NOTICE file, excluding those notices that do not
110
+ pertain to any part of the Derivative Works, in at least one
111
+ of the following places: within a NOTICE text file distributed
112
+ as part of the Derivative Works; within the Source form or
113
+ documentation, if provided along with the Derivative Works; or,
114
+ within a display generated by the Derivative Works, if and
115
+ wherever such third-party notices normally appear. The contents
116
+ of the NOTICE file are for informational purposes only and
117
+ do not modify the License. You may add Your own attribution
118
+ notices within Derivative Works that You distribute, alongside
119
+ or as an addendum to the NOTICE text from the Work, provided
120
+ that such additional attribution notices cannot be construed
121
+ as modifying the License.
122
+
123
+ You may add Your own copyright statement to Your modifications and
124
+ may provide additional or different license terms and conditions
125
+ for use, reproduction, or distribution of Your modifications, or
126
+ for any such Derivative Works as a whole, provided Your use,
127
+ reproduction, and distribution of the Work otherwise complies with
128
+ the conditions stated in this License.
129
+
130
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
131
+ any Contribution intentionally submitted for inclusion in the Work
132
+ by You to the Licensor shall be under the terms and conditions of
133
+ this License, without any additional terms or conditions.
134
+ Notwithstanding the above, nothing herein shall supersede or modify
135
+ the terms of any separate license agreement you may have executed
136
+ with Licensor regarding such Contributions.
137
+
138
+ 6. Trademarks. This License does not grant permission to use the trade
139
+ names, trademarks, service marks, or product names of the Licensor,
140
+ except as required for reasonable and customary use in describing the
141
+ origin of the Work and reproducing the content of the NOTICE file.
142
+
143
+ 7. Disclaimer of Warranty. Unless required by applicable law or
144
+ agreed to in writing, Licensor provides the Work (and each
145
+ Contributor provides its Contributions) on an "AS IS" BASIS,
146
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
+ implied, including, without limitation, any warranties or conditions
148
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
+ PARTICULAR PURPOSE. You are solely responsible for determining the
150
+ appropriateness of using or redistributing the Work and assume any
151
+ risks associated with Your exercise of permissions under this License.
152
+
153
+ 8. Limitation of Liability. In no event and under no legal theory,
154
+ whether in tort (including negligence), contract, or otherwise,
155
+ unless required by applicable law (such as deliberate and grossly
156
+ negligent acts) or agreed to in writing, shall any Contributor be
157
+ liable to You for damages, including any direct, indirect, special,
158
+ incidental, or consequential damages of any character arising as a
159
+ result of this License or out of the use or inability to use the
160
+ Work (including but not limited to damages for loss of goodwill,
161
+ work stoppage, computer failure or malfunction, or any and all
162
+ other commercial damages or losses), even if such Contributor
163
+ has been advised of the possibility of such damages.
164
+
165
+ 9. Accepting Warranty or Additional Liability. While redistributing
166
+ the Work or Derivative Works thereof, You may choose to offer,
167
+ and charge a fee for, acceptance of support, warranty, indemnity,
168
+ or other liability obligations and/or rights consistent with this
169
+ License. However, in accepting such obligations, You may act only
170
+ on Your own behalf and on Your sole responsibility, not on behalf
171
+ of any other Contributor, and only if You agree to indemnify,
172
+ defend, and hold each Contributor harmless for any liability
173
+ incurred by, or claims asserted against, such Contributor by reason
174
+ of your accepting any such warranty or additional liability.
175
+
176
+ END OF TERMS AND CONDITIONS
177
+
178
+ APPENDIX: How to apply the Apache License to your work.
179
+
180
+ To apply the Apache License to your work, attach the following
181
+ boilerplate notice, with the fields enclosed by brackets "[]"
182
+ replaced with your own identifying information. (Don't include
183
+ the brackets!) The text should be enclosed in the appropriate
184
+ comment syntax for the file format. We also recommend that a
185
+ file or class name and description of purpose be included on the
186
+ same "printed page" as the copyright notice for easier
187
+ identification within third-party archives.
188
+
189
+ Copyright [yyyy] [name of copyright owner]
190
+
191
+ Licensed under the Apache License, Version 2.0 (the "License");
192
+ you may not use this file except in compliance with the License.
193
+ You may obtain a copy of the License at
194
+
195
+ http://www.apache.org/licenses/LICENSE-2.0
196
+
197
+ Unless required by applicable law or agreed to in writing, software
198
+ distributed under the License is distributed on an "AS IS" BASIS,
199
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200
+ See the License for the specific language governing permissions and
201
+ limitations under the License.
MODEL_ZOO.md ADDED
@@ -0,0 +1,13 @@
1
+ ## DDColor Model Zoo
2
+
3
+ | Model | Description | Note |
4
+ | ---------------------- | :------------------ | :-----|
5
+ | [ddcolor_paper.pth](https://huggingface.co/piddnad/DDColor-models/resolve/main/ddcolor_paper.pth) | DDColor-L trained on ImageNet | Paper model; use it only if you want to reproduce some of the images in the paper.
6
+ | [ddcolor_modelscope.pth](https://huggingface.co/piddnad/DDColor-models/resolve/main/ddcolor_modelscope.pth) (***default***) | DDColor-L trained on ImageNet | We trained this model using the same data cleaning scheme as [BigColor](https://github.com/KIMGEONUNG/BigColor/issues/2#issuecomment-1196287574), so it achieves the best qualitative results with only a slight degradation in FID. Use this model by default if you want to test images outside of ImageNet. It can also be downloaded easily through ModelScope [in this way](README.md#inference-with-modelscope-library).
7
+ | [ddcolor_artistic.pth](https://huggingface.co/piddnad/DDColor-models/resolve/main/ddcolor_artistic.pth) | DDColor-L trained on ImageNet + private data | We trained this model with an extended dataset containing many high-quality artistic images. We also did not use the colorfulness loss during training, so it tends to produce fewer unnatural color artifacts. Use this model if you want to try different colorization results.
8
+ | [ddcolor_paper_tiny.pth](https://huggingface.co/piddnad/DDColor-models/resolve/main/ddcolor_paper_tiny.pth) | DDColor-T trained on ImageNet | The most lightweight version of the DDColor model, trained with the same scheme as ddcolor_paper.
9
+
10
+ ## Discussions
11
+
12
+ * About Colorfulness Loss (CL): CL encourages more "colorful" results and helps improve CF (colorfulness) scores; however, it sometimes produces unpleasant color blocks (e.g., red color artifacts). If such artifacts appear, I personally recommend removing it during training. A rough sketch of the colorfulness metric itself is given below.
13
+
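For reference, the sketch below assumes the CF metric mentioned above is the standard Hasler–Süsstrunk colorfulness score; the repository ships its own implementation in `basicsr/metrics/colorfulness.py`, which should be treated as authoritative.

```python
import numpy as np

# Hedged sketch of the Hasler-Suesstrunk colorfulness score (an assumption about
# what "CF" refers to); see basicsr/metrics/colorfulness.py for the repo's version.
def colorfulness(img_rgb: np.ndarray) -> float:
    r, g, b = (img_rgb[..., i].astype(np.float64) for i in range(3))
    rg = r - g                      # red-green opponent channel
    yb = 0.5 * (r + g) - b          # yellow-blue opponent channel
    std_root = np.sqrt(rg.std() ** 2 + yb.std() ** 2)
    mean_root = np.sqrt(rg.mean() ** 2 + yb.mean() ** 2)
    return std_root + 0.3 * mean_root

print(colorfulness(np.random.randint(0, 256, (256, 256, 3))))
```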
README.md CHANGED
@@ -1,12 +1,206 @@
1
  ---
2
  title: DDColor
3
- emoji: 🚀
4
- colorFrom: gray
5
- colorTo: pink
6
  sdk: gradio
7
- sdk_version: 5.23.3
8
- app_file: app.py
9
- pinned: false
10
  ---
11
 
12
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
1
  ---
2
  title: DDColor
3
+ app_file: gradio_app.py
 
 
4
  sdk: gradio
5
+ sdk_version: 5.21.0
 
 
6
  ---
7
+ # 🎨 DDColor
8
+ [![arXiv](https://img.shields.io/badge/arXiv-2212.11613-b31b1b.svg)](https://arxiv.org/abs/2212.11613)
9
+ [![HuggingFace](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Models-FF8000)](https://huggingface.co/piddnad/DDColor-models)
10
+ [![ModelScope demo](https://img.shields.io/badge/%F0%9F%91%BE%20ModelScope-Demo-8A2BE2)](https://www.modelscope.cn/models/damo/cv_ddcolor_image-colorization/summary)
11
+ [![Replicate](https://replicate.com/piddnad/ddcolor/badge)](https://replicate.com/piddnad/ddcolor)
12
+ ![visitors](https://visitor-badge.laobi.icu/badge?page_id=piddnad/DDColor)
13
 
14
+ Official PyTorch implementation of the ICCV 2023 paper "DDColor: Towards Photo-Realistic Image Colorization via Dual Decoders".
15
+
16
+ > Xiaoyang Kang, Tao Yang, Wenqi Ouyang, Peiran Ren, Lingzhi Li, Xuansong Xie
17
+ > *DAMO Academy, Alibaba Group*
18
+
19
+ 🪄 DDColor provides vivid and natural colorization for historical black-and-white photos.
20
+
21
+ <p align="center">
22
+ <img src="assets/teaser.png" width="100%">
23
+ </p>
24
+
25
+ 🎲 It can even colorize/recolor landscapes from anime games, transforming your animated scenery into a realistic real-life style! (Image source: Genshin Impact)
26
+
27
+ <p align="center">
28
+ <img src="assets/anime_landscapes.png" width="100%">
29
+ </p>
30
+
31
+
32
+ ## News
33
+ - [2024-01-28] Support inference via 🤗 Hugging Face! Thanks @[Niels](https://github.com/NielsRogge) for the suggestion and example code, and @[Skwara](https://github.com/Skwarson96) for fixing a bug.
34
+ - [2024-01-18] Add Replicate demo and API! Thanks @[Chenxi](https://github.com/chenxwh).
35
+ - [2023-12-13] Release the DDColor-tiny pre-trained model!
36
+ - [2023-09-07] Add the Model Zoo and release three pretrained models!
37
+ - [2023-05-15] Code release for training and inference!
38
+ - [2023-05-05] The online demo is available!
39
+
40
+
41
+ ## Online Demo
42
+ Try our online demos at [ModelScope](https://www.modelscope.cn/models/damo/cv_ddcolor_image-colorization/summary) and [Replicate](https://replicate.com/piddnad/ddcolor).
43
+
44
+
45
+ ## Methods
46
+ *In short:* DDColor uses multi-scale visual features to optimize **learnable color tokens** (i.e. color queries) and achieves state-of-the-art performance on automatic image colorization.
47
+
48
+ <p align="center">
49
+ <img src="assets/network_arch.jpg" width="100%">
50
+ </p>
51
+
52
+
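The actual decoder is added by this commit in `basicsr/archs/ddcolor_arch.py` (`MultiScaleColorDecoder`). The toy sketch below is not that code; it only illustrates, with made-up sizes and names, the core idea of learnable color queries being refined by cross-attention over multi-scale image features.

```python
import torch
import torch.nn as nn

# Toy illustration only (not the repo's implementation): learnable query
# embeddings are refined by cross-attending to image features at several scales.
class ToyColorQueryDecoder(nn.Module):
    def __init__(self, dim=256, num_queries=100, nheads=8, num_scales=3):
        super().__init__()
        self.queries = nn.Embedding(num_queries, dim)          # learnable color tokens
        self.cross_attn = nn.MultiheadAttention(dim, nheads)   # queries attend to pixels
        self.num_scales = num_scales

    def forward(self, feats):  # feats: list of (B, dim, H_i, W_i), one per scale
        b = feats[0].shape[0]
        q = self.queries.weight.unsqueeze(1).repeat(1, b, 1)   # (Q, B, dim)
        for i in range(self.num_scales):
            kv = feats[i].flatten(2).permute(2, 0, 1)           # (H_i*W_i, B, dim)
            q = self.cross_attn(q, kv, kv)[0] + q               # refine queries at this scale
        return q.transpose(0, 1)                                # (B, Q, dim) color embeddings

# Three feature maps at different resolutions -> 100 refined color queries
feats = [torch.randn(2, 256, s, s) for s in (8, 16, 32)]
print(ToyColorQueryDecoder()(feats).shape)  # torch.Size([2, 100, 256])
```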
53
+ ## Installation
54
+ ### Requirements
55
+ - Python >= 3.7
56
+ - PyTorch >= 1.7
57
+
58
+ ### Installation with conda (recommended)
59
+
60
+ ```sh
61
+ conda create -n ddcolor python=3.9
62
+ conda activate ddcolor
63
+ pip install torch==2.2.0 torchvision==0.17.0 torchaudio==2.2.0 --index-url https://download.pytorch.org/whl/cu118
64
+
65
+ pip install -r requirements.txt
66
+
67
+ # Install basicsr, only required for training
68
+ python3 setup.py develop
69
+ ```
70
+
71
+ ## Quick Start
72
+ ### Inference Using Local Script (No `basicsr` Required)
73
+ 1. Download the pretrained model:
74
+
75
+ ```python
76
+ from modelscope.hub.snapshot_download import snapshot_download
77
+
78
+ model_dir = snapshot_download('damo/cv_ddcolor_image-colorization', cache_dir='./modelscope')
79
+ print('model assets saved to %s' % model_dir)
80
+ ```
81
+
82
+ 2. Run inference with
83
+
84
+ ```sh
85
+ python infer.py --model_path ./modelscope/damo/cv_ddcolor_image-colorization/pytorch_model.pt --input ./assets/test_images
86
+ ```
87
+ or
88
+ ```sh
89
+ sh scripts/inference.sh
90
+ ```
91
+
92
+ ### Inference Using Hugging Face
93
+ Load the model via Hugging Face Hub:
94
+
95
+ ```python
96
+ from infer_hf import DDColorHF
97
+
98
+ ddcolor_paper_tiny = DDColorHF.from_pretrained("piddnad/ddcolor_paper_tiny")
99
+ ddcolor_paper = DDColorHF.from_pretrained("piddnad/ddcolor_paper")
100
+ ddcolor_modelscope = DDColorHF.from_pretrained("piddnad/ddcolor_modelscope")
101
+ ddcolor_artistic = DDColorHF.from_pretrained("piddnad/ddcolor_artistic")
102
+ ```
103
+
104
+ Check `infer_hf.py` for the inference details, or perform model inference directly by running:
105
+
106
+ ```sh
107
+ python infer_hf.py --model_name ddcolor_modelscope --input ./assets/test_images
108
+ # model_name: [ddcolor_paper | ddcolor_modelscope | ddcolor_artistic | ddcolor_paper_tiny]
109
+ ```
110
+
111
+ ### Inference Using ModelScope
112
+ 1. Install the `modelscope` package:
113
+
114
+ ```sh
115
+ pip install modelscope
116
+ ```
117
+
118
+ 2. Run inference:
119
+
120
+ ```python
121
+ import cv2
122
+ from modelscope.outputs import OutputKeys
123
+ from modelscope.pipelines import pipeline
124
+ from modelscope.utils.constant import Tasks
125
+
126
+ img_colorization = pipeline(Tasks.image_colorization, model='damo/cv_ddcolor_image-colorization')
127
+ result = img_colorization('https://modelscope.oss-cn-beijing.aliyuncs.com/test/images/audrey_hepburn.jpg')
128
+ cv2.imwrite('result.png', result[OutputKeys.OUTPUT_IMG])
129
+ ```
130
+
131
+ This code automatically downloads the `ddcolor_modelscope` model (see [Model Zoo](#model-zoo)) and performs inference. The model file `pytorch_model.pt` can be found in the local path `~/.cache/modelscope/hub/damo`.
132
+
133
+ ### Gradio Demo
134
+ Install Gradio and the other required libraries:
135
+
136
+ ```sh
137
+ pip install gradio gradio_imageslider timm
138
+ ```
139
+
140
+ Then, you can run the demo with the following command:
141
+
142
+ ```sh
143
+ python gradio_app.py
144
+ ```
145
+
146
+ ## Model Zoo
147
+ We provide several different versions of pretrained models, please check out [Model Zoo](MODEL_ZOO.md).
148
+
149
+
150
+ ## Train
151
+ 1. Dataset Preparation: Download the [ImageNet](https://www.image-net.org/) dataset or create a custom dataset. Use this script to obtain the dataset list file:
152
+
153
+ ```sh
154
+ python data_list/get_meta_file.py
155
+ ```
156
+
157
+ 2. Download the pretrained weights for [ConvNeXt](https://dl.fbaipublicfiles.com/convnext/convnext_large_22k_224.pth) and [InceptionV3](https://download.pytorch.org/models/inception_v3_google-1a9a5a14.pth) and place them in the `pretrain` folder.
158
+
159
+ 3. Specify `meta_info_file` and other options in `options/train/train_ddcolor.yml`.
160
+
161
+ 4. Start training:
162
+
163
+ ```sh
164
+ sh scripts/train.sh
165
+ ```
166
+
167
+ ## ONNX export
168
+ Support for ONNX model exports is available.
169
+
170
+ 1. Install dependencies:
171
+
172
+ ```sh
173
+ pip install onnx==1.16.1 onnxruntime==1.19.2 onnxsim==0.4.36
174
+ ```
175
+
176
+ 2. Usage example:
177
+
178
+ ```sh
179
+ python export.py
180
+ usage: export.py [-h] [--input_size INPUT_SIZE] [--batch_size BATCH_SIZE] --model_path MODEL_PATH [--model_size MODEL_SIZE]
181
+ [--decoder_type DECODER_TYPE] [--export_path EXPORT_PATH] [--opset OPSET]
182
+ ```
183
+
184
+ A demo of ONNX export using the `ddcolor_paper_tiny` model is available [here](notebooks/colorization_pipeline_onnxruntime.ipynb).
185
+
186
+
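As a rough sketch (not taken from the linked notebook), an exported model can be run with `onnxruntime` as follows; the file name, input size, and the omitted pre/post-processing are assumptions, so consult `export.py` and the notebook for the actual pipeline.

```python
import numpy as np
import onnxruntime as ort

# Hedged sketch: "ddcolor.onnx" and the 512x512 input size are assumptions.
sess = ort.InferenceSession("ddcolor.onnx", providers=["CPUExecutionProvider"])
input_name = sess.get_inputs()[0].name
dummy = np.random.rand(1, 3, 512, 512).astype(np.float32)  # NCHW, grayscale replicated to 3 channels
output = sess.run(None, {input_name: dummy})[0]
print(output.shape)
```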
187
+ ## Citation
188
+
189
+ If our work is helpful for your research, please consider citing:
190
+
191
+ ```
192
+ @inproceedings{kang2023ddcolor,
193
+ title={DDColor: Towards Photo-Realistic Image Colorization via Dual Decoders},
194
+ author={Kang, Xiaoyang and Yang, Tao and Ouyang, Wenqi and Ren, Peiran and Li, Lingzhi and Xie, Xuansong},
195
+ booktitle={Proceedings of the IEEE/CVF International Conference on Computer Vision},
196
+ pages={328--338},
197
+ year={2023}
198
+ }
199
+ ```
200
+
201
+ ## Acknowledgments
202
+ We thank the authors of BasicSR for the awesome training pipeline.
203
+
204
+ > Xintao Wang, Ke Yu, Kelvin C.K. Chan, Chao Dong and Chen Change Loy. BasicSR: Open Source Image and Video Restoration Toolbox. https://github.com/xinntao/BasicSR, 2020.
205
+
206
+ Some code is adapted from [ColorFormer](https://github.com/jixiaozhong/ColorFormer), [BigColor](https://github.com/KIMGEONUNG/BigColor), [ConvNeXt](https://github.com/facebookresearch/ConvNeXt), [Mask2Former](https://github.com/facebookresearch/Mask2Former), and [DETR](https://github.com/facebookresearch/detr). Thanks for their excellent work!
VERSION ADDED
@@ -0,0 +1 @@
1
+ 1.3.4.6
basicsr/__init__.py ADDED
@@ -0,0 +1,15 @@
1
+ # https://github.com/xinntao/BasicSR
2
+ # flake8: noqa
3
+ from .archs import *
4
+ from .data import *
5
+ from .losses import *
6
+ from .metrics import *
7
+ from .models import *
8
+ # from .ops import *
9
+ # from .test import *
10
+ from .train import *
11
+ from .utils import *
12
+ try:
13
+ from .version import __gitsha__, __version__
14
+ except:
15
+ pass
basicsr/archs/__init__.py ADDED
@@ -0,0 +1,25 @@
1
+ import importlib
2
+ from copy import deepcopy
3
+ from os import path as osp
4
+
5
+ from basicsr.utils import get_root_logger, scandir
6
+ from basicsr.utils.registry import ARCH_REGISTRY
7
+
8
+ __all__ = ['build_network']
9
+
10
+ # automatically scan and import arch modules for registry
11
+ # scan all the files under the 'archs' folder and collect files ending with
12
+ # '_arch.py'
13
+ arch_folder = osp.dirname(osp.abspath(__file__))
14
+ arch_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(arch_folder) if v.endswith('_arch.py')]
15
+ # import all the arch modules
16
+ _arch_modules = [importlib.import_module(f'basicsr.archs.{file_name}') for file_name in arch_filenames]
17
+
18
+
19
+ def build_network(opt):
20
+ opt = deepcopy(opt)
21
+ network_type = opt.pop('type')
22
+ net = ARCH_REGISTRY.get(network_type)(**opt)
23
+ logger = get_root_logger()
24
+ logger.info(f'Network [{net.__class__.__name__}] is created.')
25
+ return net
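A minimal usage sketch of `build_network`, assuming an `opt` dict shaped like the `network_g` block of a BasicSR-style training YAML; the option values below are illustrative, and everything except `type` is forwarded to the registered arch's constructor.

```python
# Hypothetical usage sketch; option values are illustrative, not the repo's config.
from basicsr.archs import build_network

opt = {
    'type': 'DDColor',             # looked up in ARCH_REGISTRY
    'encoder_name': 'convnext-l',  # remaining keys go to DDColor.__init__
    'num_queries': 100,
}
net_g = build_network(opt)
```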
basicsr/archs/ddcolor_arch.py ADDED
@@ -0,0 +1,385 @@
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ from basicsr.archs.ddcolor_arch_utils.unet import Hook, CustomPixelShuffle_ICNR, UnetBlockWide, NormType, custom_conv_layer
5
+ from basicsr.archs.ddcolor_arch_utils.convnext import ConvNeXt
6
+ from basicsr.archs.ddcolor_arch_utils.transformer_utils import SelfAttentionLayer, CrossAttentionLayer, FFNLayer, MLP
7
+ from basicsr.archs.ddcolor_arch_utils.position_encoding import PositionEmbeddingSine
8
+ from basicsr.archs.ddcolor_arch_utils.transformer import Transformer
9
+ from basicsr.utils.registry import ARCH_REGISTRY
10
+
11
+
12
+ @ARCH_REGISTRY.register()
13
+ class DDColor(nn.Module):
14
+
15
+ def __init__(self,
16
+ encoder_name='convnext-l',
17
+ decoder_name='MultiScaleColorDecoder',
18
+ num_input_channels=3,
19
+ input_size=(256, 256),
20
+ nf=512,
21
+ num_output_channels=3,
22
+ last_norm='Weight',
23
+ do_normalize=False,
24
+ num_queries=256,
25
+ num_scales=3,
26
+ dec_layers=9,
27
+ encoder_from_pretrain=False):
28
+ super().__init__()
29
+
30
+ self.encoder = Encoder(encoder_name, ['norm0', 'norm1', 'norm2', 'norm3'], from_pretrain=encoder_from_pretrain)
31
+ self.encoder.eval()
32
+ test_input = torch.randn(1, num_input_channels, *input_size)
33
+ self.encoder(test_input)
34
+
35
+ self.decoder = Decoder(
36
+ self.encoder.hooks,
37
+ nf=nf,
38
+ last_norm=last_norm,
39
+ num_queries=num_queries,
40
+ num_scales=num_scales,
41
+ dec_layers=dec_layers,
42
+ decoder_name=decoder_name
43
+ )
44
+ self.refine_net = nn.Sequential(custom_conv_layer(num_queries + 3, num_output_channels, ks=1, use_activ=False, norm_type=NormType.Spectral))
45
+
46
+ self.do_normalize = do_normalize
47
+ self.register_buffer('mean', torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
48
+ self.register_buffer('std', torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))
49
+
50
+ def normalize(self, img):
51
+ return (img - self.mean) / self.std
52
+
53
+ def denormalize(self, img):
54
+ return img * self.std + self.mean
55
+
56
+ def forward(self, x):
57
+ if x.shape[1] == 3:
58
+ x = self.normalize(x)
59
+
60
+ self.encoder(x)
61
+ out_feat = self.decoder()
62
+ coarse_input = torch.cat([out_feat, x], dim=1)
63
+ out = self.refine_net(coarse_input)
64
+
65
+ if self.do_normalize:
66
+ out = self.denormalize(out)
67
+ return out
68
+
69
+
70
+ class Decoder(nn.Module):
71
+
72
+ def __init__(self,
73
+ hooks,
74
+ nf=512,
75
+ blur=True,
76
+ last_norm='Weight',
77
+ num_queries=256,
78
+ num_scales=3,
79
+ dec_layers=9,
80
+ decoder_name='MultiScaleColorDecoder'):
81
+ super().__init__()
82
+ self.hooks = hooks
83
+ self.nf = nf
84
+ self.blur = blur
85
+ self.last_norm = getattr(NormType, last_norm)
86
+ self.decoder_name = decoder_name
87
+
88
+ self.layers = self.make_layers()
89
+ embed_dim = nf // 2
90
+
91
+ self.last_shuf = CustomPixelShuffle_ICNR(embed_dim, embed_dim, blur=self.blur, norm_type=self.last_norm, scale=4)
92
+
93
+ if self.decoder_name == 'MultiScaleColorDecoder':
94
+ self.color_decoder = MultiScaleColorDecoder(
95
+ in_channels=[512, 512, 256],
96
+ num_queries=num_queries,
97
+ num_scales=num_scales,
98
+ dec_layers=dec_layers,
99
+ )
100
+ else:
101
+ self.color_decoder = SingleColorDecoder(
102
+ in_channels=hooks[-1].feature.shape[1],
103
+ num_queries=num_queries,
104
+ )
105
+
106
+
107
+ def forward(self):
108
+ encode_feat = self.hooks[-1].feature
109
+ out0 = self.layers[0](encode_feat)
110
+ out1 = self.layers[1](out0)
111
+ out2 = self.layers[2](out1)
112
+ out3 = self.last_shuf(out2)
113
+
114
+ if self.decoder_name == 'MultiScaleColorDecoder':
115
+ out = self.color_decoder([out0, out1, out2], out3)
116
+ else:
117
+ out = self.color_decoder(out3, encode_feat)
118
+
119
+ return out
120
+
121
+ def make_layers(self):
122
+ decoder_layers = []
123
+
124
+ e_in_c = self.hooks[-1].feature.shape[1]
125
+ in_c = e_in_c
126
+
127
+ out_c = self.nf
128
+ setup_hooks = self.hooks[-2::-1]
129
+ for layer_index, hook in enumerate(setup_hooks):
130
+ feature_c = hook.feature.shape[1]
131
+ if layer_index == len(setup_hooks) - 1:
132
+ out_c = out_c // 2
133
+ decoder_layers.append(
134
+ UnetBlockWide(
135
+ in_c, feature_c, out_c, hook, blur=self.blur, self_attention=False, norm_type=NormType.Spectral))
136
+ in_c = out_c
137
+ return nn.Sequential(*decoder_layers)
138
+
139
+
140
+ class Encoder(nn.Module):
141
+
142
+ def __init__(self, encoder_name, hook_names, from_pretrain, **kwargs):
143
+ super().__init__()
144
+
145
+ if encoder_name == 'convnext-t' or encoder_name == 'convnext':
146
+ self.arch = ConvNeXt()
147
+ elif encoder_name == 'convnext-s':
148
+ self.arch = ConvNeXt(depths=[3, 3, 27, 3], dims=[96, 192, 384, 768])
149
+ elif encoder_name == 'convnext-b':
150
+ self.arch = ConvNeXt(depths=[3, 3, 27, 3], dims=[128, 256, 512, 1024])
151
+ elif encoder_name == 'convnext-l':
152
+ self.arch = ConvNeXt(depths=[3, 3, 27, 3], dims=[192, 384, 768, 1536])
153
+ else:
154
+ raise NotImplementedError
155
+
156
+ self.encoder_name = encoder_name
157
+ self.hook_names = hook_names
158
+ self.hooks = self.setup_hooks()
159
+
160
+ if from_pretrain:
161
+ self.load_pretrain_model()
162
+
163
+ def setup_hooks(self):
164
+ hooks = [Hook(self.arch._modules[name]) for name in self.hook_names]
165
+ return hooks
166
+
167
+ def forward(self, x):
168
+ return self.arch(x)
169
+
170
+ def load_pretrain_model(self):
171
+ if self.encoder_name == 'convnext-t' or self.encoder_name == 'convnext':
172
+ self.load('pretrain/convnext_tiny_22k_224.pth')
173
+ elif self.encoder_name == 'convnext-s':
174
+ self.load('pretrain/convnext_small_22k_224.pth')
175
+ elif self.encoder_name == 'convnext-b':
176
+ self.load('pretrain/convnext_base_22k_224.pth')
177
+ elif self.encoder_name == 'convnext-l':
178
+ self.load('pretrain/convnext_large_22k_224.pth')
179
+ else:
180
+ raise NotImplementedError
181
+ print('Loaded pretrained convnext model.')
182
+
183
+ def load(self, path):
184
+ from basicsr.utils import get_root_logger
185
+ logger = get_root_logger()
186
+ if not path:
187
+ logger.info("No checkpoint found. Initializing model from scratch")
188
+ return
189
+ logger.info("[Encoder] Loading from {} ...".format(path))
190
+ checkpoint = torch.load(path, map_location=torch.device("cpu"))
191
+ checkpoint_state_dict = checkpoint['model'] if 'model' in checkpoint.keys() else checkpoint
192
+ incompatible = self.arch.load_state_dict(checkpoint_state_dict, strict=False)
193
+
194
+ if incompatible.missing_keys:
195
+ msg = "Some model parameters or buffers are not found in the checkpoint:\n"
196
+ msg += str(incompatible.missing_keys)
197
+ logger.warning(msg)
198
+ if incompatible.unexpected_keys:
199
+ msg = "The checkpoint state_dict contains keys that are not used by the model:\n"
200
+ msg += str(incompatible.unexpected_keys)
201
+ logger.warning(msg)
202
+
203
+
204
+ class MultiScaleColorDecoder(nn.Module):
205
+
206
+ def __init__(
207
+ self,
208
+ in_channels,
209
+ hidden_dim=256,
210
+ num_queries=100,
211
+ nheads=8,
212
+ dim_feedforward=2048,
213
+ dec_layers=9,
214
+ pre_norm=False,
215
+ color_embed_dim=256,
216
+ enforce_input_project=True,
217
+ num_scales=3
218
+ ):
219
+ super().__init__()
220
+
221
+ # positional encoding
222
+ N_steps = hidden_dim // 2
223
+ self.pe_layer = PositionEmbeddingSine(N_steps, normalize=True)
224
+
225
+ # define Transformer decoder here
226
+ self.num_heads = nheads
227
+ self.num_layers = dec_layers
228
+ self.transformer_self_attention_layers = nn.ModuleList()
229
+ self.transformer_cross_attention_layers = nn.ModuleList()
230
+ self.transformer_ffn_layers = nn.ModuleList()
231
+
232
+ for _ in range(self.num_layers):
233
+ self.transformer_self_attention_layers.append(
234
+ SelfAttentionLayer(
235
+ d_model=hidden_dim,
236
+ nhead=nheads,
237
+ dropout=0.0,
238
+ normalize_before=pre_norm,
239
+ )
240
+ )
241
+ self.transformer_cross_attention_layers.append(
242
+ CrossAttentionLayer(
243
+ d_model=hidden_dim,
244
+ nhead=nheads,
245
+ dropout=0.0,
246
+ normalize_before=pre_norm,
247
+ )
248
+ )
249
+ self.transformer_ffn_layers.append(
250
+ FFNLayer(
251
+ d_model=hidden_dim,
252
+ dim_feedforward=dim_feedforward,
253
+ dropout=0.0,
254
+ normalize_before=pre_norm,
255
+ )
256
+ )
257
+
258
+ self.decoder_norm = nn.LayerNorm(hidden_dim)
259
+
260
+ self.num_queries = num_queries
261
+ # learnable color query features
262
+ self.query_feat = nn.Embedding(num_queries, hidden_dim)
263
+ # learnable color query p.e.
264
+ self.query_embed = nn.Embedding(num_queries, hidden_dim)
265
+
266
+ # level embedding
267
+ self.num_feature_levels = num_scales
268
+ self.level_embed = nn.Embedding(self.num_feature_levels, hidden_dim)
269
+
270
+ # input projections
271
+ self.input_proj = nn.ModuleList()
272
+ for i in range(self.num_feature_levels):
273
+ if in_channels[i] != hidden_dim or enforce_input_project:
274
+ self.input_proj.append(nn.Conv2d(in_channels[i], hidden_dim, kernel_size=1))
275
+ nn.init.kaiming_uniform_(self.input_proj[-1].weight, a=1)
276
+ if self.input_proj[-1].bias is not None:
277
+ nn.init.constant_(self.input_proj[-1].bias, 0)
278
+ else:
279
+ self.input_proj.append(nn.Sequential())
280
+
281
+ # output FFNs
282
+ self.color_embed = MLP(hidden_dim, hidden_dim, color_embed_dim, 3)
283
+
284
+ def forward(self, x, img_features):
285
+ # x is a list of multi-scale feature
286
+ assert len(x) == self.num_feature_levels
287
+ src = []
288
+ pos = []
289
+
290
+ for i in range(self.num_feature_levels):
291
+ pos.append(self.pe_layer(x[i], None).flatten(2))
292
+ src.append(self.input_proj[i](x[i]).flatten(2) + self.level_embed.weight[i][None, :, None])
293
+
294
+ # flatten NxCxHxW to HWxNxC
295
+ pos[-1] = pos[-1].permute(2, 0, 1)
296
+ src[-1] = src[-1].permute(2, 0, 1)
297
+
298
+ _, bs, _ = src[0].shape
299
+
300
+ # QxNxC
301
+ query_embed = self.query_embed.weight.unsqueeze(1).repeat(1, bs, 1)
302
+ output = self.query_feat.weight.unsqueeze(1).repeat(1, bs, 1)
303
+
304
+ for i in range(self.num_layers):
305
+ level_index = i % self.num_feature_levels
306
+ # attention: cross-attention first
307
+ output = self.transformer_cross_attention_layers[i](
308
+ output, src[level_index],
309
+ memory_mask=None,
310
+ memory_key_padding_mask=None,
311
+ pos=pos[level_index], query_pos=query_embed
312
+ )
313
+ output = self.transformer_self_attention_layers[i](
314
+ output, tgt_mask=None,
315
+ tgt_key_padding_mask=None,
316
+ query_pos=query_embed
317
+ )
318
+ # FFN
319
+ output = self.transformer_ffn_layers[i](
320
+ output
321
+ )
322
+
323
+ decoder_output = self.decoder_norm(output)
324
+ decoder_output = decoder_output.transpose(0, 1) # [N, bs, C] -> [bs, N, C]
325
+ color_embed = self.color_embed(decoder_output)
326
+ out = torch.einsum("bqc,bchw->bqhw", color_embed, img_features)
327
+
328
+ return out
329
+
330
+
331
+ class SingleColorDecoder(nn.Module):
332
+
333
+ def __init__(
334
+ self,
335
+ in_channels=768,
336
+ hidden_dim=256,
337
+ num_queries=256, # 100
338
+ nheads=8,
339
+ dropout=0.1,
340
+ dim_feedforward=2048,
341
+ enc_layers=0,
342
+ dec_layers=6,
343
+ pre_norm=False,
344
+ deep_supervision=True,
345
+ enforce_input_project=True,
346
+ ):
347
+
348
+ super().__init__()
349
+
350
+ N_steps = hidden_dim // 2
351
+ self.pe_layer = PositionEmbeddingSine(N_steps, normalize=True)
352
+
353
+ transformer = Transformer(
354
+ d_model=hidden_dim,
355
+ dropout=dropout,
356
+ nhead=nheads,
357
+ dim_feedforward=dim_feedforward,
358
+ num_encoder_layers=enc_layers,
359
+ num_decoder_layers=dec_layers,
360
+ normalize_before=pre_norm,
361
+ return_intermediate_dec=deep_supervision,
362
+ )
363
+ self.num_queries = num_queries
364
+ self.transformer = transformer
365
+
366
+ self.query_embed = nn.Embedding(num_queries, hidden_dim)
367
+
368
+ if in_channels != hidden_dim or enforce_input_project:
369
+ self.input_proj = nn.Conv2d(in_channels, hidden_dim, kernel_size=1)
370
+ nn.init.kaiming_uniform_(self.input_proj.weight, a=1)
371
+ if self.input_proj.bias is not None:
372
+ nn.init.constant_(self.input_proj.bias, 0)
373
+ else:
374
+ self.input_proj = nn.Sequential()
375
+
376
+
377
+ def forward(self, img_features, encode_feat):
378
+ pos = self.pe_layer(encode_feat)
379
+ src = encode_feat
380
+ mask = None
381
+ hs, memory = self.transformer(self.input_proj(src), mask, self.query_embed.weight, pos)
382
+ color_embed = hs[-1]
383
+ color_preds = torch.einsum('bqc,bchw->bqhw', color_embed, img_features)
384
+ return color_preds
385
+
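A minimal usage sketch of the `DDColor` class above, using the constructor defaults shown in the diff; the repository's inference scripts wrap this with their own image pre- and post-processing.

```python
import torch
from basicsr.archs.ddcolor_arch import DDColor

# Sketch only: defaults as in the signature above; no pretrained weights are loaded.
model = DDColor(encoder_name='convnext-l', encoder_from_pretrain=False).eval()
with torch.no_grad():
    out = model(torch.randn(1, 3, 256, 256))  # NCHW input at the default input_size
print(out.shape)  # (1, num_output_channels, 256, 256)
```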
basicsr/archs/ddcolor_arch_utils/__int__.py ADDED
File without changes
basicsr/archs/ddcolor_arch_utils/convnext.py ADDED
@@ -0,0 +1,155 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+
3
+ # All rights reserved.
4
+
5
+ # This source code is licensed under the license found in the
6
+ # LICENSE file in the root directory of this source tree.
7
+
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.nn.functional as F
12
+ from timm.models.layers import trunc_normal_, DropPath
13
+
14
+ class Block(nn.Module):
15
+ r""" ConvNeXt Block. There are two equivalent implementations:
16
+ (1) DwConv -> LayerNorm (channels_first) -> 1x1 Conv -> GELU -> 1x1 Conv; all in (N, C, H, W)
17
+ (2) DwConv -> Permute to (N, H, W, C); LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back
18
+ We use (2) as we find it slightly faster in PyTorch
19
+
20
+ Args:
21
+ dim (int): Number of input channels.
22
+ drop_path (float): Stochastic depth rate. Default: 0.0
23
+ layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6.
24
+ """
25
+ def __init__(self, dim, drop_path=0., layer_scale_init_value=1e-6):
26
+ super().__init__()
27
+ self.dwconv = nn.Conv2d(dim, dim, kernel_size=7, padding=3, groups=dim) # depthwise conv
28
+ self.norm = LayerNorm(dim, eps=1e-6)
29
+ self.pwconv1 = nn.Linear(dim, 4 * dim) # pointwise/1x1 convs, implemented with linear layers
30
+ self.act = nn.GELU()
31
+ self.pwconv2 = nn.Linear(4 * dim, dim)
32
+ self.gamma = nn.Parameter(layer_scale_init_value * torch.ones((dim)),
33
+ requires_grad=True) if layer_scale_init_value > 0 else None
34
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
35
+
36
+ def forward(self, x):
37
+ input = x
38
+ x = self.dwconv(x)
39
+ x = x.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C)
40
+ x = self.norm(x)
41
+ x = self.pwconv1(x)
42
+ x = self.act(x)
43
+ x = self.pwconv2(x)
44
+ if self.gamma is not None:
45
+ x = self.gamma * x
46
+ x = x.permute(0, 3, 1, 2) # (N, H, W, C) -> (N, C, H, W)
47
+
48
+ x = input + self.drop_path(x)
49
+ return x
50
+
51
+ class ConvNeXt(nn.Module):
52
+ r""" ConvNeXt
53
+ A PyTorch impl of : `A ConvNet for the 2020s` -
54
+ https://arxiv.org/pdf/2201.03545.pdf
55
+ Args:
56
+ in_chans (int): Number of input image channels. Default: 3
57
+ num_classes (int): Number of classes for classification head. Default: 1000
58
+ depths (tuple(int)): Number of blocks at each stage. Default: [3, 3, 9, 3]
59
+ dims (int): Feature dimension at each stage. Default: [96, 192, 384, 768]
60
+ drop_path_rate (float): Stochastic depth rate. Default: 0.
61
+ layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6.
62
+ head_init_scale (float): Init scaling value for classifier weights and biases. Default: 1.
63
+ """
64
+ def __init__(self, in_chans=3, num_classes=1000,
65
+ depths=[3, 3, 9, 3], dims=[96, 192, 384, 768], drop_path_rate=0.,
66
+ layer_scale_init_value=1e-6, head_init_scale=1.,
67
+ ):
68
+ super().__init__()
69
+
70
+ self.downsample_layers = nn.ModuleList() # stem and 3 intermediate downsampling conv layers
71
+ stem = nn.Sequential(
72
+ nn.Conv2d(in_chans, dims[0], kernel_size=4, stride=4),
73
+ LayerNorm(dims[0], eps=1e-6, data_format="channels_first")
74
+ )
75
+ self.downsample_layers.append(stem)
76
+ for i in range(3):
77
+ downsample_layer = nn.Sequential(
78
+ LayerNorm(dims[i], eps=1e-6, data_format="channels_first"),
79
+ nn.Conv2d(dims[i], dims[i+1], kernel_size=2, stride=2),
80
+ )
81
+ self.downsample_layers.append(downsample_layer)
82
+
83
+ self.stages = nn.ModuleList() # 4 feature resolution stages, each consisting of multiple residual blocks
84
+ dp_rates=[x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
85
+ cur = 0
86
+ for i in range(4):
87
+ stage = nn.Sequential(
88
+ *[Block(dim=dims[i], drop_path=dp_rates[cur + j],
89
+ layer_scale_init_value=layer_scale_init_value) for j in range(depths[i])]
90
+ )
91
+ self.stages.append(stage)
92
+ cur += depths[i]
93
+
94
+ # add norm layers for each output
95
+ out_indices = (0, 1, 2, 3)
96
+ for i in out_indices:
97
+ layer = LayerNorm(dims[i], eps=1e-6, data_format="channels_first")
98
+ # layer = nn.Identity()
99
+ layer_name = f'norm{i}'
100
+ self.add_module(layer_name, layer)
101
+
102
+ self.norm = nn.LayerNorm(dims[-1], eps=1e-6) # final norm layer
103
+ # self.head_cls = nn.Linear(dims[-1], 4)
104
+
105
+ self.apply(self._init_weights)
106
+ # self.head_cls.weight.data.mul_(head_init_scale)
107
+ # self.head_cls.bias.data.mul_(head_init_scale)
108
+
109
+ def _init_weights(self, m):
110
+ if isinstance(m, (nn.Conv2d, nn.Linear)):
111
+ trunc_normal_(m.weight, std=.02)
112
+ nn.init.constant_(m.bias, 0)
113
+
114
+ def forward_features(self, x):
115
+ for i in range(4):
116
+ x = self.downsample_layers[i](x)
117
+ x = self.stages[i](x)
118
+
119
+ # add extra norm
120
+ norm_layer = getattr(self, f'norm{i}')
121
+ # x = norm_layer(x)
122
+ norm_layer(x)
123
+
124
+ return self.norm(x.mean([-2, -1])) # global average pooling, (N, C, H, W) -> (N, C)
125
+
126
+ def forward(self, x):
127
+ x = self.forward_features(x)
128
+ # x = self.head_cls(x)
129
+ return x
130
+
131
+ class LayerNorm(nn.Module):
132
+ r""" LayerNorm that supports two data formats: channels_last (default) or channels_first.
133
+ The ordering of the dimensions in the inputs. channels_last corresponds to inputs with
134
+ shape (batch_size, height, width, channels) while channels_first corresponds to inputs
135
+ with shape (batch_size, channels, height, width).
136
+ """
137
+ def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"):
138
+ super().__init__()
139
+ self.weight = nn.Parameter(torch.ones(normalized_shape))
140
+ self.bias = nn.Parameter(torch.zeros(normalized_shape))
141
+ self.eps = eps
142
+ self.data_format = data_format
143
+ if self.data_format not in ["channels_last", "channels_first"]:
144
+ raise NotImplementedError
145
+ self.normalized_shape = (normalized_shape, )
146
+
147
+ def forward(self, x):
148
+ if self.data_format == "channels_last": # B H W C
149
+ return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
150
+ elif self.data_format == "channels_first": # B C H W
151
+ u = x.mean(1, keepdim=True)
152
+ s = (x - u).pow(2).mean(1, keepdim=True)
153
+ x = (x - u) / torch.sqrt(s + self.eps)
154
+ x = self.weight[:, None, None] * x + self.bias[:, None, None]
155
+ return x
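A short usage sketch of the `LayerNorm` above in `channels_first` mode, which normalizes each spatial position over the channel dimension of an NCHW feature map:

```python
import torch
from basicsr.archs.ddcolor_arch_utils.convnext import LayerNorm

ln = LayerNorm(64, eps=1e-6, data_format="channels_first")
y = ln(torch.randn(2, 64, 32, 32))  # normalized over dim=1; shape is preserved
print(y.shape)                      # torch.Size([2, 64, 32, 32])
```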
basicsr/archs/ddcolor_arch_utils/position_encoding.py ADDED
@@ -0,0 +1,52 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ # Modified from: https://github.com/facebookresearch/detr/blob/master/models/position_encoding.py
3
+ """
4
+ Various positional encodings for the transformer.
5
+ """
6
+ import math
7
+
8
+ import torch
9
+ from torch import nn
10
+
11
+
12
+ class PositionEmbeddingSine(nn.Module):
13
+ """
14
+ This is a more standard version of the position embedding, very similar to the one
15
+ used by the Attention is all you need paper, generalized to work on images.
16
+ """
17
+
18
+ def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None):
19
+ super().__init__()
20
+ self.num_pos_feats = num_pos_feats
21
+ self.temperature = temperature
22
+ self.normalize = normalize
23
+ if scale is not None and normalize is False:
24
+ raise ValueError("normalize should be True if scale is passed")
25
+ if scale is None:
26
+ scale = 2 * math.pi
27
+ self.scale = scale
28
+
29
+ def forward(self, x, mask=None):
30
+ if mask is None:
31
+ mask = torch.zeros((x.size(0), x.size(2), x.size(3)), device=x.device, dtype=torch.bool)
32
+ not_mask = ~mask
33
+ y_embed = not_mask.cumsum(1, dtype=torch.float32)
34
+ x_embed = not_mask.cumsum(2, dtype=torch.float32)
35
+ if self.normalize:
36
+ eps = 1e-6
37
+ y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
38
+ x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
39
+
40
+ dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
41
+ dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
42
+
43
+ pos_x = x_embed[:, :, :, None] / dim_t
44
+ pos_y = y_embed[:, :, :, None] / dim_t
45
+ pos_x = torch.stack(
46
+ (pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4
47
+ ).flatten(3)
48
+ pos_y = torch.stack(
49
+ (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4
50
+ ).flatten(3)
51
+ pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
52
+ return pos
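A short usage sketch: with `num_pos_feats = hidden_dim // 2` (as in the decoders in `ddcolor_arch.py`), the returned encoding has `hidden_dim` channels and matches the feature map's spatial size.

```python
import torch
from basicsr.archs.ddcolor_arch_utils.position_encoding import PositionEmbeddingSine

pe = PositionEmbeddingSine(num_pos_feats=128, normalize=True)
feat = torch.randn(1, 256, 16, 16)  # (B, C, H, W) feature map
pos = pe(feat)                      # mask defaults to all-visible
print(pos.shape)                    # torch.Size([1, 256, 16, 16])
```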
basicsr/archs/ddcolor_arch_utils/transformer.py ADDED
@@ -0,0 +1,368 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ # Modified from: https://github.com/facebookresearch/detr/blob/master/models/transformer.py
3
+ """
4
+ Transformer class.
5
+ Copy-paste from torch.nn.Transformer with modifications:
6
+ * positional encodings are passed in MHattention
7
+ * extra LN at the end of encoder is removed
8
+ * decoder returns a stack of activations from all decoding layers
9
+ """
10
+ import copy
11
+ from typing import List, Optional
12
+
13
+ import torch
14
+ import torch.nn.functional as F
15
+ from torch import Tensor, nn
16
+
17
+
18
+ class Transformer(nn.Module):
19
+ def __init__(
20
+ self,
21
+ d_model=512,
22
+ nhead=8,
23
+ num_encoder_layers=6,
24
+ num_decoder_layers=6,
25
+ dim_feedforward=2048,
26
+ dropout=0.1,
27
+ activation="relu",
28
+ normalize_before=False,
29
+ return_intermediate_dec=False,
30
+ ):
31
+ super().__init__()
32
+
33
+ encoder_layer = TransformerEncoderLayer(
34
+ d_model, nhead, dim_feedforward, dropout, activation, normalize_before
35
+ )
36
+ encoder_norm = nn.LayerNorm(d_model) if normalize_before else None
37
+ self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm)
38
+
39
+ decoder_layer = TransformerDecoderLayer(
40
+ d_model, nhead, dim_feedforward, dropout, activation, normalize_before
41
+ )
42
+ decoder_norm = nn.LayerNorm(d_model)
43
+ self.decoder = TransformerDecoder(
44
+ decoder_layer,
45
+ num_decoder_layers,
46
+ decoder_norm,
47
+ return_intermediate=return_intermediate_dec,
48
+ )
49
+
50
+ self._reset_parameters()
51
+
52
+ self.d_model = d_model
53
+ self.nhead = nhead
54
+
55
+ def _reset_parameters(self):
56
+ for p in self.parameters():
57
+ if p.dim() > 1:
58
+ nn.init.xavier_uniform_(p)
59
+
60
+ def forward(self, src, mask, query_embed, pos_embed):
61
+ # flatten NxCxHxW to HWxNxC
62
+ bs, c, h, w = src.shape
63
+ src = src.flatten(2).permute(2, 0, 1)
64
+ pos_embed = pos_embed.flatten(2).permute(2, 0, 1)
65
+ query_embed = query_embed.unsqueeze(1).repeat(1, bs, 1)
66
+ if mask is not None:
67
+ mask = mask.flatten(1)
68
+
69
+ tgt = torch.zeros_like(query_embed)
70
+ memory = self.encoder(src, src_key_padding_mask=mask, pos=pos_embed)
71
+ hs = self.decoder(
72
+ tgt, memory, memory_key_padding_mask=mask, pos=pos_embed, query_pos=query_embed
73
+ )
74
+ return hs.transpose(1, 2), memory.permute(1, 2, 0).view(bs, c, h, w)
75
+
76
+
77
+ class TransformerEncoder(nn.Module):
78
+ def __init__(self, encoder_layer, num_layers, norm=None):
79
+ super().__init__()
80
+ self.layers = _get_clones(encoder_layer, num_layers)
81
+ self.num_layers = num_layers
82
+ self.norm = norm
83
+
84
+ def forward(
85
+ self,
86
+ src,
87
+ mask: Optional[Tensor] = None,
88
+ src_key_padding_mask: Optional[Tensor] = None,
89
+ pos: Optional[Tensor] = None,
90
+ ):
91
+ output = src
92
+
93
+ for layer in self.layers:
94
+ output = layer(
95
+ output, src_mask=mask, src_key_padding_mask=src_key_padding_mask, pos=pos
96
+ )
97
+
98
+ if self.norm is not None:
99
+ output = self.norm(output)
100
+
101
+ return output
102
+
103
+
104
+ class TransformerDecoder(nn.Module):
105
+ def __init__(self, decoder_layer, num_layers, norm=None, return_intermediate=False):
106
+ super().__init__()
107
+ self.layers = _get_clones(decoder_layer, num_layers)
108
+ self.num_layers = num_layers
109
+ self.norm = norm
110
+ self.return_intermediate = return_intermediate
111
+
112
+ def forward(
113
+ self,
114
+ tgt,
115
+ memory,
116
+ tgt_mask: Optional[Tensor] = None,
117
+ memory_mask: Optional[Tensor] = None,
118
+ tgt_key_padding_mask: Optional[Tensor] = None,
119
+ memory_key_padding_mask: Optional[Tensor] = None,
120
+ pos: Optional[Tensor] = None,
121
+ query_pos: Optional[Tensor] = None,
122
+ ):
123
+ output = tgt
124
+
125
+ intermediate = []
126
+
127
+ for layer in self.layers:
128
+ output = layer(
129
+ output,
130
+ memory,
131
+ tgt_mask=tgt_mask,
132
+ memory_mask=memory_mask,
133
+ tgt_key_padding_mask=tgt_key_padding_mask,
134
+ memory_key_padding_mask=memory_key_padding_mask,
135
+ pos=pos,
136
+ query_pos=query_pos,
137
+ )
138
+ if self.return_intermediate:
139
+ intermediate.append(self.norm(output))
140
+
141
+ if self.norm is not None:
142
+ output = self.norm(output)
143
+ if self.return_intermediate:
144
+ intermediate.pop()
145
+ intermediate.append(output)
146
+
147
+ if self.return_intermediate:
148
+ return torch.stack(intermediate)
149
+
150
+ return output.unsqueeze(0)
151
+
152
+
153
+ class TransformerEncoderLayer(nn.Module):
154
+ def __init__(
155
+ self,
156
+ d_model,
157
+ nhead,
158
+ dim_feedforward=2048,
159
+ dropout=0.1,
160
+ activation="relu",
161
+ normalize_before=False,
162
+ ):
163
+ super().__init__()
164
+ self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
165
+ # Implementation of Feedforward model
166
+ self.linear1 = nn.Linear(d_model, dim_feedforward)
167
+ self.dropout = nn.Dropout(dropout)
168
+ self.linear2 = nn.Linear(dim_feedforward, d_model)
169
+
170
+ self.norm1 = nn.LayerNorm(d_model)
171
+ self.norm2 = nn.LayerNorm(d_model)
172
+ self.dropout1 = nn.Dropout(dropout)
173
+ self.dropout2 = nn.Dropout(dropout)
174
+
175
+ self.activation = _get_activation_fn(activation)
176
+ self.normalize_before = normalize_before
177
+
178
+ def with_pos_embed(self, tensor, pos: Optional[Tensor]):
179
+ return tensor if pos is None else tensor + pos
180
+
181
+ def forward_post(
182
+ self,
183
+ src,
184
+ src_mask: Optional[Tensor] = None,
185
+ src_key_padding_mask: Optional[Tensor] = None,
186
+ pos: Optional[Tensor] = None,
187
+ ):
188
+ q = k = self.with_pos_embed(src, pos)
189
+ src2 = self.self_attn(
190
+ q, k, value=src, attn_mask=src_mask, key_padding_mask=src_key_padding_mask
191
+ )[0]
192
+ src = src + self.dropout1(src2)
193
+ src = self.norm1(src)
194
+ src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
195
+ src = src + self.dropout2(src2)
196
+ src = self.norm2(src)
197
+ return src
198
+
199
+ def forward_pre(
200
+ self,
201
+ src,
202
+ src_mask: Optional[Tensor] = None,
203
+ src_key_padding_mask: Optional[Tensor] = None,
204
+ pos: Optional[Tensor] = None,
205
+ ):
206
+ src2 = self.norm1(src)
207
+ q = k = self.with_pos_embed(src2, pos)
208
+ src2 = self.self_attn(
209
+ q, k, value=src2, attn_mask=src_mask, key_padding_mask=src_key_padding_mask
210
+ )[0]
211
+ src = src + self.dropout1(src2)
212
+ src2 = self.norm2(src)
213
+ src2 = self.linear2(self.dropout(self.activation(self.linear1(src2))))
214
+ src = src + self.dropout2(src2)
215
+ return src
216
+
217
+ def forward(
218
+ self,
219
+ src,
220
+ src_mask: Optional[Tensor] = None,
221
+ src_key_padding_mask: Optional[Tensor] = None,
222
+ pos: Optional[Tensor] = None,
223
+ ):
224
+ if self.normalize_before:
225
+ return self.forward_pre(src, src_mask, src_key_padding_mask, pos)
226
+ return self.forward_post(src, src_mask, src_key_padding_mask, pos)
227
+
228
+
229
+ class TransformerDecoderLayer(nn.Module):
230
+ def __init__(
231
+ self,
232
+ d_model,
233
+ nhead,
234
+ dim_feedforward=2048,
235
+ dropout=0.1,
236
+ activation="relu",
237
+ normalize_before=False,
238
+ ):
239
+ super().__init__()
240
+ self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
241
+ self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
242
+ # Implementation of Feedforward model
243
+ self.linear1 = nn.Linear(d_model, dim_feedforward)
244
+ self.dropout = nn.Dropout(dropout)
245
+ self.linear2 = nn.Linear(dim_feedforward, d_model)
246
+
247
+ self.norm1 = nn.LayerNorm(d_model)
248
+ self.norm2 = nn.LayerNorm(d_model)
249
+ self.norm3 = nn.LayerNorm(d_model)
250
+ self.dropout1 = nn.Dropout(dropout)
251
+ self.dropout2 = nn.Dropout(dropout)
252
+ self.dropout3 = nn.Dropout(dropout)
253
+
254
+ self.activation = _get_activation_fn(activation)
255
+ self.normalize_before = normalize_before
256
+
257
+ def with_pos_embed(self, tensor, pos: Optional[Tensor]):
258
+ return tensor if pos is None else tensor + pos
259
+
260
+ def forward_post(
261
+ self,
262
+ tgt,
263
+ memory,
264
+ tgt_mask: Optional[Tensor] = None,
265
+ memory_mask: Optional[Tensor] = None,
266
+ tgt_key_padding_mask: Optional[Tensor] = None,
267
+ memory_key_padding_mask: Optional[Tensor] = None,
268
+ pos: Optional[Tensor] = None,
269
+ query_pos: Optional[Tensor] = None,
270
+ ):
271
+ q = k = self.with_pos_embed(tgt, query_pos)
272
+ tgt2 = self.self_attn(
273
+ q, k, value=tgt, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask
274
+ )[0]
275
+ tgt = tgt + self.dropout1(tgt2)
276
+ tgt = self.norm1(tgt)
277
+ tgt2 = self.multihead_attn(
278
+ query=self.with_pos_embed(tgt, query_pos),
279
+ key=self.with_pos_embed(memory, pos),
280
+ value=memory,
281
+ attn_mask=memory_mask,
282
+ key_padding_mask=memory_key_padding_mask,
283
+ )[0]
284
+ tgt = tgt + self.dropout2(tgt2)
285
+ tgt = self.norm2(tgt)
286
+ tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
287
+ tgt = tgt + self.dropout3(tgt2)
288
+ tgt = self.norm3(tgt)
289
+ return tgt
290
+
291
+ def forward_pre(
292
+ self,
293
+ tgt,
294
+ memory,
295
+ tgt_mask: Optional[Tensor] = None,
296
+ memory_mask: Optional[Tensor] = None,
297
+ tgt_key_padding_mask: Optional[Tensor] = None,
298
+ memory_key_padding_mask: Optional[Tensor] = None,
299
+ pos: Optional[Tensor] = None,
300
+ query_pos: Optional[Tensor] = None,
301
+ ):
302
+ tgt2 = self.norm1(tgt)
303
+ q = k = self.with_pos_embed(tgt2, query_pos)
304
+ tgt2 = self.self_attn(
305
+ q, k, value=tgt2, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask
306
+ )[0]
307
+ tgt = tgt + self.dropout1(tgt2)
308
+ tgt2 = self.norm2(tgt)
309
+ tgt2 = self.multihead_attn(
310
+ query=self.with_pos_embed(tgt2, query_pos),
311
+ key=self.with_pos_embed(memory, pos),
312
+ value=memory,
313
+ attn_mask=memory_mask,
314
+ key_padding_mask=memory_key_padding_mask,
315
+ )[0]
316
+ tgt = tgt + self.dropout2(tgt2)
317
+ tgt2 = self.norm3(tgt)
318
+ tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
319
+ tgt = tgt + self.dropout3(tgt2)
320
+ return tgt
321
+
322
+ def forward(
323
+ self,
324
+ tgt,
325
+ memory,
326
+ tgt_mask: Optional[Tensor] = None,
327
+ memory_mask: Optional[Tensor] = None,
328
+ tgt_key_padding_mask: Optional[Tensor] = None,
329
+ memory_key_padding_mask: Optional[Tensor] = None,
330
+ pos: Optional[Tensor] = None,
331
+ query_pos: Optional[Tensor] = None,
332
+ ):
333
+ if self.normalize_before:
334
+ return self.forward_pre(
335
+ tgt,
336
+ memory,
337
+ tgt_mask,
338
+ memory_mask,
339
+ tgt_key_padding_mask,
340
+ memory_key_padding_mask,
341
+ pos,
342
+ query_pos,
343
+ )
344
+ return self.forward_post(
345
+ tgt,
346
+ memory,
347
+ tgt_mask,
348
+ memory_mask,
349
+ tgt_key_padding_mask,
350
+ memory_key_padding_mask,
351
+ pos,
352
+ query_pos,
353
+ )
354
+
355
+
356
+ def _get_clones(module, N):
357
+ return nn.ModuleList([copy.deepcopy(module) for i in range(N)])
358
+
359
+
360
+ def _get_activation_fn(activation):
361
+ """Return an activation function given a string"""
362
+ if activation == "relu":
363
+ return F.relu
364
+ if activation == "gelu":
365
+ return F.gelu
366
+ if activation == "glu":
367
+ return F.glu
368
+ raise RuntimeError(f"activation should be relu/gelu/glu, not {activation}.")
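# --- Usage sketch (illustrative, not part of the committed file) ---
# A minimal forward pass through the encoder-decoder Transformer defined above. It assumes the
# enclosing class is named Transformer (as in the original DETR code this follows), is importable
# from basicsr.archs.ddcolor_arch_utils.transformer, and accepts the constructor arguments used
# in __init__ (d_model, nhead, num_encoder_layers, num_decoder_layers, ...). Shapes follow the
# "flatten NxCxHxW to HWxNxC" convention in forward().
import torch
from basicsr.archs.ddcolor_arch_utils.transformer import Transformer

model = Transformer(d_model=256, nhead=8, num_encoder_layers=2, num_decoder_layers=2)
src = torch.randn(2, 256, 16, 16)      # (N, C, H, W) backbone feature map
pos = torch.randn(2, 256, 16, 16)      # positional encoding, same shape as src
queries = torch.randn(100, 256)        # (num_queries, d_model) learned query embeddings
hs, memory = model(src, mask=None, query_embed=queries, pos_embed=pos)
# hs: (1 or num_decoder_layers, N, num_queries, d_model); memory: (N, C, H, W)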
basicsr/archs/ddcolor_arch_utils/transformer_utils.py ADDED
@@ -0,0 +1,192 @@
1
+ from typing import Optional
2
+ from torch import nn, Tensor
3
+ from torch.nn import functional as F
4
+
5
+ class SelfAttentionLayer(nn.Module):
6
+
7
+ def __init__(self, d_model, nhead, dropout=0.0,
8
+ activation="relu", normalize_before=False):
9
+ super().__init__()
10
+ self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
11
+
12
+ self.norm = nn.LayerNorm(d_model)
13
+ self.dropout = nn.Dropout(dropout)
14
+
15
+ self.activation = _get_activation_fn(activation)
16
+ self.normalize_before = normalize_before
17
+
18
+ self._reset_parameters()
19
+
20
+ def _reset_parameters(self):
21
+ for p in self.parameters():
22
+ if p.dim() > 1:
23
+ nn.init.xavier_uniform_(p)
24
+
25
+ def with_pos_embed(self, tensor, pos: Optional[Tensor]):
26
+ return tensor if pos is None else tensor + pos
27
+
28
+ def forward_post(self, tgt,
29
+ tgt_mask: Optional[Tensor] = None,
30
+ tgt_key_padding_mask: Optional[Tensor] = None,
31
+ query_pos: Optional[Tensor] = None):
32
+ q = k = self.with_pos_embed(tgt, query_pos)
33
+ tgt2 = self.self_attn(q, k, value=tgt, attn_mask=tgt_mask,
34
+ key_padding_mask=tgt_key_padding_mask)[0]
35
+ tgt = tgt + self.dropout(tgt2)
36
+ tgt = self.norm(tgt)
37
+
38
+ return tgt
39
+
40
+ def forward_pre(self, tgt,
41
+ tgt_mask: Optional[Tensor] = None,
42
+ tgt_key_padding_mask: Optional[Tensor] = None,
43
+ query_pos: Optional[Tensor] = None):
44
+ tgt2 = self.norm(tgt)
45
+ q = k = self.with_pos_embed(tgt2, query_pos)
46
+ tgt2 = self.self_attn(q, k, value=tgt2, attn_mask=tgt_mask,
47
+ key_padding_mask=tgt_key_padding_mask)[0]
48
+ tgt = tgt + self.dropout(tgt2)
49
+
50
+ return tgt
51
+
52
+ def forward(self, tgt,
53
+ tgt_mask: Optional[Tensor] = None,
54
+ tgt_key_padding_mask: Optional[Tensor] = None,
55
+ query_pos: Optional[Tensor] = None):
56
+ if self.normalize_before:
57
+ return self.forward_pre(tgt, tgt_mask,
58
+ tgt_key_padding_mask, query_pos)
59
+ return self.forward_post(tgt, tgt_mask,
60
+ tgt_key_padding_mask, query_pos)
61
+
62
+
63
+ class CrossAttentionLayer(nn.Module):
64
+
65
+ def __init__(self, d_model, nhead, dropout=0.0,
66
+ activation="relu", normalize_before=False):
67
+ super().__init__()
68
+ self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
69
+
70
+ self.norm = nn.LayerNorm(d_model)
71
+ self.dropout = nn.Dropout(dropout)
72
+
73
+ self.activation = _get_activation_fn(activation)
74
+ self.normalize_before = normalize_before
75
+
76
+ self._reset_parameters()
77
+
78
+ def _reset_parameters(self):
79
+ for p in self.parameters():
80
+ if p.dim() > 1:
81
+ nn.init.xavier_uniform_(p)
82
+
83
+ def with_pos_embed(self, tensor, pos: Optional[Tensor]):
84
+ return tensor if pos is None else tensor + pos
85
+
86
+ def forward_post(self, tgt, memory,
87
+ memory_mask: Optional[Tensor] = None,
88
+ memory_key_padding_mask: Optional[Tensor] = None,
89
+ pos: Optional[Tensor] = None,
90
+ query_pos: Optional[Tensor] = None):
91
+ tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt, query_pos),
92
+ key=self.with_pos_embed(memory, pos),
93
+ value=memory, attn_mask=memory_mask,
94
+ key_padding_mask=memory_key_padding_mask)[0]
95
+ tgt = tgt + self.dropout(tgt2)
96
+ tgt = self.norm(tgt)
97
+
98
+ return tgt
99
+
100
+ def forward_pre(self, tgt, memory,
101
+ memory_mask: Optional[Tensor] = None,
102
+ memory_key_padding_mask: Optional[Tensor] = None,
103
+ pos: Optional[Tensor] = None,
104
+ query_pos: Optional[Tensor] = None):
105
+ tgt2 = self.norm(tgt)
106
+ tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt2, query_pos),
107
+ key=self.with_pos_embed(memory, pos),
108
+ value=memory, attn_mask=memory_mask,
109
+ key_padding_mask=memory_key_padding_mask)[0]
110
+ tgt = tgt + self.dropout(tgt2)
111
+
112
+ return tgt
113
+
114
+ def forward(self, tgt, memory,
115
+ memory_mask: Optional[Tensor] = None,
116
+ memory_key_padding_mask: Optional[Tensor] = None,
117
+ pos: Optional[Tensor] = None,
118
+ query_pos: Optional[Tensor] = None):
119
+ if self.normalize_before:
120
+ return self.forward_pre(tgt, memory, memory_mask,
121
+ memory_key_padding_mask, pos, query_pos)
122
+ return self.forward_post(tgt, memory, memory_mask,
123
+ memory_key_padding_mask, pos, query_pos)
124
+
125
+
126
+ class FFNLayer(nn.Module):
127
+
128
+ def __init__(self, d_model, dim_feedforward=2048, dropout=0.0,
129
+ activation="relu", normalize_before=False):
130
+ super().__init__()
131
+ # Implementation of Feedforward model
132
+ self.linear1 = nn.Linear(d_model, dim_feedforward)
133
+ self.dropout = nn.Dropout(dropout)
134
+ self.linear2 = nn.Linear(dim_feedforward, d_model)
135
+
136
+ self.norm = nn.LayerNorm(d_model)
137
+
138
+ self.activation = _get_activation_fn(activation)
139
+ self.normalize_before = normalize_before
140
+
141
+ self._reset_parameters()
142
+
143
+ def _reset_parameters(self):
144
+ for p in self.parameters():
145
+ if p.dim() > 1:
146
+ nn.init.xavier_uniform_(p)
147
+
148
+ def with_pos_embed(self, tensor, pos: Optional[Tensor]):
149
+ return tensor if pos is None else tensor + pos
150
+
151
+ def forward_post(self, tgt):
152
+ tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
153
+ tgt = tgt + self.dropout(tgt2)
154
+ tgt = self.norm(tgt)
155
+ return tgt
156
+
157
+ def forward_pre(self, tgt):
158
+ tgt2 = self.norm(tgt)
159
+ tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
160
+ tgt = tgt + self.dropout(tgt2)
161
+ return tgt
162
+
163
+ def forward(self, tgt):
164
+ if self.normalize_before:
165
+ return self.forward_pre(tgt)
166
+ return self.forward_post(tgt)
167
+
168
+
169
+ def _get_activation_fn(activation):
170
+ """Return an activation function given a string"""
171
+ if activation == "relu":
172
+ return F.relu
173
+ if activation == "gelu":
174
+ return F.gelu
175
+ if activation == "glu":
176
+ return F.glu
177
+ raise RuntimeError(f"activation should be relu/gelu/glu, not {activation}.")
178
+
179
+
180
+ class MLP(nn.Module):
181
+ """ Very simple multi-layer perceptron (also called FFN)"""
182
+
183
+ def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
184
+ super().__init__()
185
+ self.num_layers = num_layers
186
+ h = [hidden_dim] * (num_layers - 1)
187
+ self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))
188
+
189
+ def forward(self, x):
190
+ for i, layer in enumerate(self.layers):
191
+ x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
192
+ return x
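# --- Usage sketch (illustrative, not part of the committed file) ---
# The layers above follow the query-based decoder pattern: cross-attention to image features,
# self-attention among queries, then a feed-forward update, all on (num_queries, batch, d_model)
# tensors. The dimensions below are arbitrary example values.
import torch
from basicsr.archs.ddcolor_arch_utils.transformer_utils import (
    SelfAttentionLayer, CrossAttentionLayer, FFNLayer, MLP)

d_model, num_queries, bs, hw = 256, 100, 2, 16 * 16
queries = torch.randn(num_queries, bs, d_model)   # query tokens
memory = torch.randn(hw, bs, d_model)             # flattened image features

queries = CrossAttentionLayer(d_model, nhead=8)(queries, memory)
queries = SelfAttentionLayer(d_model, nhead=8)(queries)
queries = FFNLayer(d_model, dim_feedforward=2048)(queries)
out = MLP(d_model, d_model, output_dim=2, num_layers=3)(queries)  # e.g. a per-query prediction head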
basicsr/archs/ddcolor_arch_utils/unet.py ADDED
@@ -0,0 +1,208 @@
1
+ from enum import Enum
2
+ import torch
3
+ import torch.nn as nn
4
+ from torch.nn import functional as F
5
+ import collections
6
+
7
+
8
+ NormType = Enum('NormType', 'Batch BatchZero Weight Spectral')
9
+
10
+
11
+ class Hook:
12
+ feature = None
13
+
14
+ def __init__(self, module):
15
+ self.hook = module.register_forward_hook(self.hook_fn)
16
+
17
+ def hook_fn(self, module, input, output):
18
+ if isinstance(output, torch.Tensor):
19
+ self.feature = output
20
+ elif isinstance(output, collections.OrderedDict):
21
+ self.feature = output['out']
22
+
23
+ def remove(self):
24
+ self.hook.remove()
25
+
26
+
27
+ class SelfAttention(nn.Module):
28
+ "Self attention layer for nd."
29
+
30
+ def __init__(self, n_channels: int):
31
+ super().__init__()
32
+ self.query = conv1d(n_channels, n_channels // 8)
33
+ self.key = conv1d(n_channels, n_channels // 8)
34
+ self.value = conv1d(n_channels, n_channels)
35
+ self.gamma = nn.Parameter(torch.tensor([0.]))
36
+
37
+ def forward(self, x):
38
+ #Notation from https://arxiv.org/pdf/1805.08318.pdf
39
+ size = x.size()
40
+ x = x.view(*size[:2], -1)
41
+ f, g, h = self.query(x), self.key(x), self.value(x)
42
+ beta = F.softmax(torch.bmm(f.permute(0, 2, 1).contiguous(), g), dim=1)
43
+ o = self.gamma * torch.bmm(h, beta) + x
44
+ return o.view(*size).contiguous()
45
+
46
+
47
+ def batchnorm_2d(nf: int, norm_type: NormType = NormType.Batch):
48
+ "A batchnorm2d layer with `nf` features initialized depending on `norm_type`."
49
+ bn = nn.BatchNorm2d(nf)
50
+ with torch.no_grad():
51
+ bn.bias.fill_(1e-3)
52
+ bn.weight.fill_(0. if norm_type == NormType.BatchZero else 1.)
53
+ return bn
54
+
55
+
56
+ def init_default(m: nn.Module, func=nn.init.kaiming_normal_) -> nn.Module:
57
+ "Initialize `m` weights with `func` and set `bias` to 0."
58
+ if func:
59
+ if hasattr(m, 'weight'): func(m.weight)
60
+ if hasattr(m, 'bias') and hasattr(m.bias, 'data'): m.bias.data.fill_(0.)
61
+ return m
62
+
63
+
64
+ def icnr(x, scale=2, init=nn.init.kaiming_normal_):
65
+ "ICNR init of `x`, with `scale` and `init` function."
66
+ ni, nf, h, w = x.shape
67
+ ni2 = int(ni / (scale**2))
68
+ k = init(torch.zeros([ni2, nf, h, w])).transpose(0, 1)
69
+ k = k.contiguous().view(ni2, nf, -1)
70
+ k = k.repeat(1, 1, scale**2)
71
+ k = k.contiguous().view([nf, ni, h, w]).transpose(0, 1)
72
+ x.data.copy_(k)
73
+
74
+
75
+ def conv1d(ni: int, no: int, ks: int = 1, stride: int = 1, padding: int = 0, bias: bool = False):
76
+ "Create and initialize a `nn.Conv1d` layer with spectral normalization."
77
+ conv = nn.Conv1d(ni, no, ks, stride=stride, padding=padding, bias=bias)
78
+ nn.init.kaiming_normal_(conv.weight)
79
+ if bias: conv.bias.data.zero_()
80
+ return nn.utils.spectral_norm(conv)
81
+
82
+
83
+ def custom_conv_layer(
84
+ ni: int,
85
+ nf: int,
86
+ ks: int = 3,
87
+ stride: int = 1,
88
+ padding: int = None,
89
+ bias: bool = None,
90
+ is_1d: bool = False,
91
+ norm_type=NormType.Batch,
92
+ use_activ: bool = True,
93
+ transpose: bool = False,
94
+ init=nn.init.kaiming_normal_,
95
+ self_attention: bool = False,
96
+ extra_bn: bool = False,
97
+ ):
98
+ "Create a sequence of convolutional (`ni` to `nf`), ReLU (if `use_activ`) and batchnorm (if `bn`) layers."
99
+ if padding is None:
100
+ padding = (ks - 1) // 2 if not transpose else 0
101
+ bn = norm_type in (NormType.Batch, NormType.BatchZero) or extra_bn
102
+ if bias is None:
103
+ bias = not bn
104
+ conv_func = nn.ConvTranspose2d if transpose else nn.Conv1d if is_1d else nn.Conv2d
105
+ conv = init_default(
106
+ conv_func(ni, nf, kernel_size=ks, bias=bias, stride=stride, padding=padding),
107
+ init,
108
+ )
109
+
110
+ if norm_type == NormType.Weight:
111
+ conv = nn.utils.weight_norm(conv)
112
+ elif norm_type == NormType.Spectral:
113
+ conv = nn.utils.spectral_norm(conv)
114
+ layers = [conv]
115
+ if use_activ:
116
+ layers.append(nn.ReLU(True))
117
+ if bn:
118
+ layers.append((nn.BatchNorm1d if is_1d else nn.BatchNorm2d)(nf))
119
+ if self_attention:
120
+ layers.append(SelfAttention(nf))
121
+ return nn.Sequential(*layers)
122
+
123
+
124
+ def conv_layer(ni: int,
125
+ nf: int,
126
+ ks: int = 3,
127
+ stride: int = 1,
128
+ padding: int = None,
129
+ bias: bool = None,
130
+ is_1d: bool = False,
131
+ norm_type=NormType.Batch,
132
+ use_activ: bool = True,
133
+ transpose: bool = False,
134
+ init=nn.init.kaiming_normal_,
135
+ self_attention: bool = False):
136
+ "Create a sequence of convolutional (`ni` to `nf`), ReLU (if `use_activ`) and batchnorm (if `bn`) layers."
137
+ if padding is None: padding = (ks - 1) // 2 if not transpose else 0
138
+ bn = norm_type in (NormType.Batch, NormType.BatchZero)
139
+ if bias is None: bias = not bn
140
+ conv_func = nn.ConvTranspose2d if transpose else nn.Conv1d if is_1d else nn.Conv2d
141
+ conv = init_default(conv_func(ni, nf, kernel_size=ks, bias=bias, stride=stride, padding=padding), init)
142
+ if norm_type == NormType.Weight: conv = nn.utils.weight_norm(conv)
143
+ elif norm_type == NormType.Spectral: conv = nn.utils.spectral_norm(conv)
144
+ layers = [conv]
145
+ if use_activ: layers.append(nn.ReLU(True))
146
+ if bn: layers.append((nn.BatchNorm1d if is_1d else nn.BatchNorm2d)(nf))
147
+ if self_attention: layers.append(SelfAttention(nf))
148
+ return nn.Sequential(*layers)
149
+
150
+
151
+ def _conv(ni: int, nf: int, ks: int = 3, stride: int = 1, **kwargs):
152
+ return conv_layer(ni, nf, ks=ks, stride=stride, norm_type=NormType.Spectral, **kwargs)
153
+
154
+
155
+ class CustomPixelShuffle_ICNR(nn.Module):
156
+ "Upsample by `scale` from `ni` filters to `nf` (default `ni`), using `nn.PixelShuffle`, `icnr` init, and `weight_norm`."
157
+
158
+ def __init__(self,
159
+ ni: int,
160
+ nf: int = None,
161
+ scale: int = 2,
162
+ blur: bool = True,
163
+ norm_type=NormType.Spectral,
164
+ extra_bn=False):
165
+ super().__init__()
166
+ self.conv = custom_conv_layer(
167
+ ni, nf * (scale**2), ks=1, use_activ=False, norm_type=norm_type, extra_bn=extra_bn)
168
+ icnr(self.conv[0].weight)
169
+ self.shuf = nn.PixelShuffle(scale)
170
+ self.do_blur = blur
171
+ # Blurring over (h*w) kernel
172
+ # "Super-Resolution using Convolutional Neural Networks without Any Checkerboard Artifacts"
173
+ # - https://arxiv.org/abs/1806.02658
174
+ self.pad = nn.ReplicationPad2d((1, 0, 1, 0))
175
+ self.blur = nn.AvgPool2d(2, stride=1)
176
+ self.relu = nn.ReLU(True)
177
+
178
+ def forward(self, x):
179
+ x = self.shuf(self.relu(self.conv(x)))
180
+ return self.blur(self.pad(x)) if self.do_blur else x
181
+
182
+
183
+ class UnetBlockWide(nn.Module):
184
+ "A quasi-UNet block, using `PixelShuffle_ICNR upsampling`."
185
+
186
+ def __init__(self,
187
+ up_in_c: int,
188
+ x_in_c: int,
189
+ n_out: int,
190
+ hook,
191
+ blur: bool = False,
192
+ self_attention: bool = False,
193
+ norm_type=NormType.Spectral):
194
+ super().__init__()
195
+
196
+ self.hook = hook
197
+ up_out = n_out
198
+ self.shuf = CustomPixelShuffle_ICNR(up_in_c, up_out, blur=blur, norm_type=norm_type, extra_bn=True)
199
+ self.bn = batchnorm_2d(x_in_c)
200
+ ni = up_out + x_in_c
201
+ self.conv = custom_conv_layer(ni, n_out, norm_type=norm_type, self_attention=self_attention, extra_bn=True)
202
+ self.relu = nn.ReLU()
203
+
204
+ def forward(self, up_in):
205
+ s = self.hook.feature
206
+ up_out = self.shuf(up_in)
207
+ cat_x = self.relu(torch.cat([up_out, self.bn(s)], dim=1))
208
+ return self.conv(cat_x)
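# --- Usage sketch (illustrative, not part of the committed file) ---
# UnetBlockWide pulls its skip connection from a Hook registered on an encoder layer, then
# pixel-shuffle upsamples the decoder feature and fuses the two. A self-contained example:
import torch
import torch.nn as nn
from basicsr.archs.ddcolor_arch_utils.unet import Hook, UnetBlockWide

encoder_layer = nn.Conv2d(3, 64, 3, stride=2, padding=1)
hook = Hook(encoder_layer)                    # caches the layer output on every forward
_ = encoder_layer(torch.randn(1, 3, 64, 64))  # populates hook.feature: (1, 64, 32, 32)

block = UnetBlockWide(up_in_c=128, x_in_c=64, n_out=64, hook=hook)
up_in = torch.randn(1, 128, 16, 16)           # lower-resolution decoder feature
out = block(up_in)                            # (1, 64, 32, 32): shuffle -> concat -> conv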
basicsr/archs/ddcolor_arch_utils/util.py ADDED
@@ -0,0 +1,63 @@
1
+ import numpy as np
2
+ import torch
3
+ from skimage import color
4
+
5
+
6
+ def rgb2lab(img_rgb):
7
+ img_lab = color.rgb2lab(img_rgb)
8
+ return img_lab[:, :, :1], img_lab[:, :, 1:]
9
+
10
+
11
+ def tensor_lab2rgb(labs, illuminant="D65", observer="2"):
12
+ """
13
+ Args:
14
+ labs (Tensor): Lab images with shape (B, 3, H, W).
15
+ Returns:
16
+ Tensor: RGB images with shape (B, 3, H, W), values clamped to [0, 1].
17
+ """
18
+ illuminants = \
19
+ {"A": {'2': (1.098466069456375, 1, 0.3558228003436005),
20
+ '10': (1.111420406956693, 1, 0.3519978321919493)},
21
+ "D50": {'2': (0.9642119944211994, 1, 0.8251882845188288),
22
+ '10': (0.9672062750333777, 1, 0.8142801513128616)},
23
+ "D55": {'2': (0.956797052643698, 1, 0.9214805860173273),
24
+ '10': (0.9579665682254781, 1, 0.9092525159847462)},
25
+ "D65": {'2': (0.95047, 1., 1.08883), # This was: `lab_ref_white`
26
+ '10': (0.94809667673716, 1, 1.0730513595166162)},
27
+ "D75": {'2': (0.9497220898840717, 1, 1.226393520724154),
28
+ '10': (0.9441713925645873, 1, 1.2064272211720228)},
29
+ "E": {'2': (1.0, 1.0, 1.0),
30
+ '10': (1.0, 1.0, 1.0)}}
31
+ xyz_from_rgb = np.array([[0.412453, 0.357580, 0.180423], [0.212671, 0.715160, 0.072169],
32
+ [0.019334, 0.119193, 0.950227]])
33
+
34
+ rgb_from_xyz = np.array([[3.240481340, -0.96925495, 0.055646640], [-1.53715152, 1.875990000, -0.20404134],
35
+ [-0.49853633, 0.041555930, 1.057311070]])
36
+ B, C, H, W = labs.shape
37
+ arrs = labs.permute((0, 2, 3, 1)).contiguous() # (B, 3, H, W) -> (B, H, W, 3)
38
+ L, a, b = arrs[:, :, :, 0:1], arrs[:, :, :, 1:2], arrs[:, :, :, 2:]
39
+ y = (L + 16.) / 116.
40
+ x = (a / 500.) + y
41
+ z = y - (b / 200.)
42
+ invalid = z.data < 0
43
+ z[invalid] = 0
44
+ xyz = torch.cat([x, y, z], dim=3)
45
+ mask = xyz.data > 0.2068966
46
+ mask_xyz = xyz.clone()
47
+ mask_xyz[mask] = torch.pow(xyz[mask], 3.0)
48
+ mask_xyz[~mask] = (xyz[~mask] - 16.0 / 116.) / 7.787
49
+ xyz_ref_white = illuminants[illuminant][observer]
50
+ for i in range(C):
51
+ mask_xyz[:, :, :, i] = mask_xyz[:, :, :, i] * xyz_ref_white[i]
52
+
53
+ rgb_trans = torch.mm(mask_xyz.view(-1, 3), torch.from_numpy(rgb_from_xyz).type_as(xyz)).view(B, H, W, C)
54
+ rgb = rgb_trans.permute((0, 3, 1, 2)).contiguous()
55
+ mask = rgb.data > 0.0031308
56
+ mask_rgb = rgb.clone()
57
+ mask_rgb[mask] = 1.055 * torch.pow(rgb[mask], 1 / 2.4) - 0.055
58
+ mask_rgb[~mask] = rgb[~mask] * 12.92
59
+ neg_mask = mask_rgb.data < 0
60
+ large_mask = mask_rgb.data > 1
61
+ mask_rgb[neg_mask] = 0
62
+ mask_rgb[large_mask] = 1
63
+ return mask_rgb
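# --- Usage sketch (illustrative, not part of the committed file) ---
# rgb2lab works on a single HxWx3 numpy image; tensor_lab2rgb expects a batched (B, 3, H, W)
# Lab tensor and returns RGB values clipped to [0, 1].
import numpy as np
import torch
from basicsr.archs.ddcolor_arch_utils.util import rgb2lab, tensor_lab2rgb

img = np.random.rand(64, 64, 3).astype(np.float32)              # RGB in [0, 1], HWC
img_l, img_ab = rgb2lab(img)                                     # (64, 64, 1) and (64, 64, 2)
lab = np.concatenate([img_l, img_ab], axis=2)                    # (64, 64, 3) Lab image
lab_t = torch.from_numpy(lab).permute(2, 0, 1).unsqueeze(0).float()  # (1, 3, 64, 64)
rgb_t = tensor_lab2rgb(lab_t)                                    # (1, 3, 64, 64) RGB in [0, 1]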
basicsr/archs/discriminator_arch.py ADDED
@@ -0,0 +1,28 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ from torchvision import models
4
+ import numpy as np
5
+
6
+ from basicsr.archs.ddcolor_arch_utils.unet import _conv
7
+ from basicsr.utils.registry import ARCH_REGISTRY
8
+
9
+
10
+ @ARCH_REGISTRY.register()
11
+ class DynamicUNetDiscriminator(nn.Module):
12
+
13
+ def __init__(self, n_channels: int = 3, nf: int = 256, n_blocks: int = 3):
14
+ super().__init__()
15
+ layers = [_conv(n_channels, nf, ks=4, stride=2)]
16
+ for i in range(n_blocks):
17
+ layers += [
18
+ _conv(nf, nf, ks=3, stride=1),
19
+ _conv(nf, nf * 2, ks=4, stride=2, self_attention=(i == 0)),
20
+ ]
21
+ nf *= 2
22
+ layers += [_conv(nf, nf, ks=3, stride=1), _conv(nf, 1, ks=4, bias=False, padding=0, use_activ=False)]
23
+ self.layers = nn.Sequential(*layers)
24
+
25
+ def forward(self, x):
26
+ out = self.layers(x)
27
+ out = out.view(out.size(0), -1)
28
+ return out
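# --- Usage sketch (illustrative, not part of the committed file) ---
# The discriminator maps an RGB image to a flattened patch-level score vector per sample.
import torch
from basicsr.archs.discriminator_arch import DynamicUNetDiscriminator

disc = DynamicUNetDiscriminator(n_channels=3, nf=256, n_blocks=3)
scores = disc(torch.randn(1, 3, 256, 256))   # shape (1, K); K depends on the input resolution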
basicsr/archs/vgg_arch.py ADDED
@@ -0,0 +1,165 @@
1
+ import os
2
+ import torch
3
+ from collections import OrderedDict
4
+ from torch import nn as nn
5
+ from torchvision.models import vgg as vgg
6
+
7
+ from basicsr.utils.registry import ARCH_REGISTRY
8
+
9
+ VGG_PRETRAIN_PATH = {
10
+ 'vgg19': './pretrain/vgg19-dcbb9e9d.pth',
11
+ 'vgg16_bn': './pretrain/vgg16_bn-6c64b313.pth'
12
+ }
13
+
14
+ NAMES = {
15
+ 'vgg11': [
16
+ 'conv1_1', 'relu1_1', 'pool1', 'conv2_1', 'relu2_1', 'pool2', 'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2',
17
+ 'pool3', 'conv4_1', 'relu4_1', 'conv4_2', 'relu4_2', 'pool4', 'conv5_1', 'relu5_1', 'conv5_2', 'relu5_2',
18
+ 'pool5'
19
+ ],
20
+ 'vgg13': [
21
+ 'conv1_1', 'relu1_1', 'conv1_2', 'relu1_2', 'pool1', 'conv2_1', 'relu2_1', 'conv2_2', 'relu2_2', 'pool2',
22
+ 'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2', 'pool3', 'conv4_1', 'relu4_1', 'conv4_2', 'relu4_2', 'pool4',
23
+ 'conv5_1', 'relu5_1', 'conv5_2', 'relu5_2', 'pool5'
24
+ ],
25
+ 'vgg16': [
26
+ 'conv1_1', 'relu1_1', 'conv1_2', 'relu1_2', 'pool1', 'conv2_1', 'relu2_1', 'conv2_2', 'relu2_2', 'pool2',
27
+ 'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2', 'conv3_3', 'relu3_3', 'pool3', 'conv4_1', 'relu4_1', 'conv4_2',
28
+ 'relu4_2', 'conv4_3', 'relu4_3', 'pool4', 'conv5_1', 'relu5_1', 'conv5_2', 'relu5_2', 'conv5_3', 'relu5_3',
29
+ 'pool5'
30
+ ],
31
+ 'vgg19': [
32
+ 'conv1_1', 'relu1_1', 'conv1_2', 'relu1_2', 'pool1', 'conv2_1', 'relu2_1', 'conv2_2', 'relu2_2', 'pool2',
33
+ 'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2', 'conv3_3', 'relu3_3', 'conv3_4', 'relu3_4', 'pool3', 'conv4_1',
34
+ 'relu4_1', 'conv4_2', 'relu4_2', 'conv4_3', 'relu4_3', 'conv4_4', 'relu4_4', 'pool4', 'conv5_1', 'relu5_1',
35
+ 'conv5_2', 'relu5_2', 'conv5_3', 'relu5_3', 'conv5_4', 'relu5_4', 'pool5'
36
+ ]
37
+ }
38
+
39
+
40
+ def insert_bn(names):
41
+ """Insert bn layer after each conv.
42
+
43
+ Args:
44
+ names (list): The list of layer names.
45
+
46
+ Returns:
47
+ list: The list of layer names with bn layers.
48
+ """
49
+ names_bn = []
50
+ for name in names:
51
+ names_bn.append(name)
52
+ if 'conv' in name:
53
+ position = name.replace('conv', '')
54
+ names_bn.append('bn' + position)
55
+ return names_bn
56
+
57
+
58
+ @ARCH_REGISTRY.register()
59
+ class VGGFeatureExtractor(nn.Module):
60
+ """VGG network for feature extraction.
61
+
62
+ In this implementation, we allow users to choose whether use normalization
63
+ in the input feature and the type of vgg network. Note that the pretrained
64
+ path must fit the vgg type.
65
+
66
+ Args:
67
+ layer_name_list (list[str]): Forward function returns the corresponding
68
+ features according to the layer_name_list.
69
+ Example: {'relu1_1', 'relu2_1', 'relu3_1'}.
70
+ vgg_type (str): Set the type of vgg network. Default: 'vgg19'.
71
+ use_input_norm (bool): If True, normalize the input image. Importantly,
72
+ the input feature must in the range [0, 1]. Default: True.
73
+ range_norm (bool): If True, norm images with range [-1, 1] to [0, 1].
74
+ Default: False.
75
+ requires_grad (bool): If true, the parameters of VGG network will be
76
+ optimized. Default: False.
77
+ remove_pooling (bool): If true, the max pooling operations in VGG net
78
+ will be removed. Default: False.
79
+ pooling_stride (int): The stride of max pooling operation. Default: 2.
80
+ """
81
+
82
+ def __init__(self,
83
+ layer_name_list,
84
+ vgg_type='vgg19',
85
+ use_input_norm=True,
86
+ range_norm=False,
87
+ requires_grad=False,
88
+ remove_pooling=False,
89
+ pooling_stride=2):
90
+ super(VGGFeatureExtractor, self).__init__()
91
+
92
+ self.layer_name_list = layer_name_list
93
+ self.use_input_norm = use_input_norm
94
+ self.range_norm = range_norm
95
+
96
+ self.names = NAMES[vgg_type.replace('_bn', '')]
97
+ if 'bn' in vgg_type:
98
+ self.names = insert_bn(self.names)
99
+
100
+ # only borrow layers that will be used to avoid unused params
101
+ max_idx = 0
102
+ for v in layer_name_list:
103
+ idx = self.names.index(v)
104
+ if idx > max_idx:
105
+ max_idx = idx
106
+
107
+ if os.path.exists(VGG_PRETRAIN_PATH[vgg_type]):
108
+ vgg_net = getattr(vgg, vgg_type)(pretrained=False)
109
+ state_dict = torch.load(VGG_PRETRAIN_PATH[vgg_type], map_location=lambda storage, loc: storage)
110
+ vgg_net.load_state_dict(state_dict)
111
+ else:
112
+ vgg_net = getattr(vgg, vgg_type)(pretrained=True)
113
+
114
+ features = vgg_net.features[:max_idx + 1]
115
+
116
+ modified_net = OrderedDict()
117
+ for k, v in zip(self.names, features):
118
+ if 'pool' in k:
119
+ # if remove_pooling is true, pooling operation will be removed
120
+ if remove_pooling:
121
+ continue
122
+ else:
123
+ # in some cases, we may want to change the default stride
124
+ modified_net[k] = nn.MaxPool2d(kernel_size=2, stride=pooling_stride)
125
+ else:
126
+ modified_net[k] = v
127
+
128
+ self.vgg_net = nn.Sequential(modified_net)
129
+
130
+ if not requires_grad:
131
+ self.vgg_net.eval()
132
+ for param in self.parameters():
133
+ param.requires_grad = False
134
+ else:
135
+ self.vgg_net.train()
136
+ for param in self.parameters():
137
+ param.requires_grad = True
138
+
139
+ if self.use_input_norm:
140
+ # the mean is for image with range [0, 1]
141
+ self.register_buffer('mean', torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
142
+ # the std is for image with range [0, 1]
143
+ self.register_buffer('std', torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))
144
+
145
+ def forward(self, x):
146
+ """Forward function.
147
+
148
+ Args:
149
+ x (Tensor): Input tensor with shape (n, c, h, w).
150
+
151
+ Returns:
152
+ Tensor: Forward results.
153
+ """
154
+ if self.range_norm:
155
+ x = (x + 1) / 2
156
+ if self.use_input_norm:
157
+ x = (x - self.mean) / self.std
158
+
159
+ output = {}
160
+ for key, layer in self.vgg_net._modules.items():
161
+ x = layer(x)
162
+ if key in self.layer_name_list:
163
+ output[key] = x.clone()
164
+
165
+ return output
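# --- Usage sketch (illustrative, not part of the committed file) ---
# Extract perceptual features from a few ReLU layers. If the local file in VGG_PRETRAIN_PATH is
# missing, torchvision's pretrained weights are downloaded instead.
import torch
from basicsr.archs.vgg_arch import VGGFeatureExtractor

extractor = VGGFeatureExtractor(layer_name_list=['relu1_1', 'relu2_1', 'relu3_1'], vgg_type='vgg19')
feats = extractor(torch.rand(1, 3, 224, 224))   # dict: layer name -> feature tensor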
basicsr/data/__init__.py ADDED
@@ -0,0 +1,101 @@
1
+ import importlib
2
+ import numpy as np
3
+ import random
4
+ import torch
5
+ import torch.utils.data
6
+ from copy import deepcopy
7
+ from functools import partial
8
+ from os import path as osp
9
+
10
+ from basicsr.data.prefetch_dataloader import PrefetchDataLoader
11
+ from basicsr.utils import get_root_logger, scandir
12
+ from basicsr.utils.dist_util import get_dist_info
13
+ from basicsr.utils.registry import DATASET_REGISTRY
14
+
15
+ __all__ = ['build_dataset', 'build_dataloader']
16
+
17
+ # automatically scan and import dataset modules for registry
18
+ # scan all the files under the data folder with '_dataset' in file names
19
+ data_folder = osp.dirname(osp.abspath(__file__))
20
+ dataset_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(data_folder) if v.endswith('_dataset.py')]
21
+ # import all the dataset modules
22
+ _dataset_modules = [importlib.import_module(f'basicsr.data.{file_name}') for file_name in dataset_filenames]
23
+
24
+
25
+ def build_dataset(dataset_opt):
26
+ """Build dataset from options.
27
+
28
+ Args:
29
+ dataset_opt (dict): Configuration for dataset. It must contain:
30
+ name (str): Dataset name.
31
+ type (str): Dataset type.
32
+ """
33
+ dataset_opt = deepcopy(dataset_opt)
34
+ dataset = DATASET_REGISTRY.get(dataset_opt['type'])(dataset_opt)
35
+ logger = get_root_logger()
36
+ logger.info(f'Dataset [{dataset.__class__.__name__}] - {dataset_opt["name"]} ' 'is built.')
37
+ return dataset
38
+
39
+
40
+ def build_dataloader(dataset, dataset_opt, num_gpu=1, dist=False, sampler=None, seed=None):
41
+ """Build dataloader.
42
+
43
+ Args:
44
+ dataset (torch.utils.data.Dataset): Dataset.
45
+ dataset_opt (dict): Dataset options. It contains the following keys:
46
+ phase (str): 'train' or 'val'.
47
+ num_worker_per_gpu (int): Number of workers for each GPU.
48
+ batch_size_per_gpu (int): Training batch size for each GPU.
49
+ num_gpu (int): Number of GPUs. Used only in the train phase.
50
+ Default: 1.
51
+ dist (bool): Whether in distributed training. Used only in the train
52
+ phase. Default: False.
53
+ sampler (torch.utils.data.sampler): Data sampler. Default: None.
54
+ seed (int | None): Seed. Default: None
55
+ """
56
+ phase = dataset_opt['phase']
57
+ rank, _ = get_dist_info()
58
+ if phase == 'train':
59
+ if dist: # distributed training
60
+ batch_size = dataset_opt['batch_size_per_gpu']
61
+ num_workers = dataset_opt['num_worker_per_gpu']
62
+ else: # non-distributed training
63
+ multiplier = 1 if num_gpu == 0 else num_gpu
64
+ batch_size = dataset_opt['batch_size_per_gpu'] * multiplier
65
+ num_workers = dataset_opt['num_worker_per_gpu'] * multiplier
66
+ dataloader_args = dict(
67
+ dataset=dataset,
68
+ batch_size=batch_size,
69
+ shuffle=False,
70
+ num_workers=num_workers,
71
+ sampler=sampler,
72
+ drop_last=True)
73
+ if sampler is None:
74
+ dataloader_args['shuffle'] = True
75
+ dataloader_args['worker_init_fn'] = partial(
76
+ worker_init_fn, num_workers=num_workers, rank=rank, seed=seed) if seed is not None else None
77
+ elif phase in ['val', 'test']: # validation
78
+ dataloader_args = dict(dataset=dataset, batch_size=1, shuffle=False, num_workers=0)
79
+ else:
80
+ raise ValueError(f'Wrong dataset phase: {phase}. ' "Supported ones are 'train', 'val' and 'test'.")
81
+
82
+ dataloader_args['pin_memory'] = dataset_opt.get('pin_memory', False)
83
+ dataloader_args['persistent_workers'] = dataset_opt.get('persistent_workers', False)
84
+
85
+ prefetch_mode = dataset_opt.get('prefetch_mode')
86
+ if prefetch_mode == 'cpu': # CPUPrefetcher
87
+ num_prefetch_queue = dataset_opt.get('num_prefetch_queue', 1)
88
+ logger = get_root_logger()
89
+ logger.info(f'Use {prefetch_mode} prefetch dataloader: num_prefetch_queue = {num_prefetch_queue}')
90
+ return PrefetchDataLoader(num_prefetch_queue=num_prefetch_queue, **dataloader_args)
91
+ else:
92
+ # prefetch_mode=None: Normal dataloader
93
+ # prefetch_mode='cuda': dataloader for CUDAPrefetcher
94
+ return torch.utils.data.DataLoader(**dataloader_args)
95
+
96
+
97
+ def worker_init_fn(worker_id, num_workers, rank, seed):
98
+ # Set the worker seed to num_workers * rank + worker_id + seed
99
+ worker_seed = num_workers * rank + worker_id + seed
100
+ np.random.seed(worker_seed)
101
+ random.seed(worker_seed)
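# --- Usage sketch (illustrative, not part of the committed file) ---
# Building a dataset/dataloader pair from an options dict. The keys shown follow the docstrings
# above; the dataset type and any dataroot/meta-info paths are placeholders, not values taken
# from this commit's configs, so the build calls are left commented.
from basicsr.data import build_dataset, build_dataloader

dataset_opt = {
    'name': 'demo', 'type': 'LabDataset', 'phase': 'train',
    'num_worker_per_gpu': 4, 'batch_size_per_gpu': 8, 'pin_memory': True,
    # ... plus the dataset-specific keys required by the chosen dataset class
}
# dataset = build_dataset(dataset_opt)
# loader = build_dataloader(dataset, dataset_opt, num_gpu=1, dist=False, sampler=None, seed=0)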
basicsr/data/data_sampler.py ADDED
@@ -0,0 +1,48 @@
1
+ import math
2
+ import torch
3
+ from torch.utils.data.sampler import Sampler
4
+
5
+
6
+ class EnlargedSampler(Sampler):
7
+ """Sampler that restricts data loading to a subset of the dataset.
8
+
9
+ Modified from torch.utils.data.distributed.DistributedSampler
10
+ Support enlarging the dataset for iteration-based training, for saving
11
+ time when restart the dataloader after each epoch
12
+
13
+ Args:
14
+ dataset (torch.utils.data.Dataset): Dataset used for sampling.
15
+ num_replicas (int | None): Number of processes participating in
16
+ the training. It is usually the world_size.
17
+ rank (int | None): Rank of the current process within num_replicas.
18
+ ratio (int): Enlarging ratio. Default: 1.
19
+ """
20
+
21
+ def __init__(self, dataset, num_replicas, rank, ratio=1):
22
+ self.dataset = dataset
23
+ self.num_replicas = num_replicas
24
+ self.rank = rank
25
+ self.epoch = 0
26
+ self.num_samples = math.ceil(len(self.dataset) * ratio / self.num_replicas)
27
+ self.total_size = self.num_samples * self.num_replicas
28
+
29
+ def __iter__(self):
30
+ # deterministically shuffle based on epoch
31
+ g = torch.Generator()
32
+ g.manual_seed(self.epoch)
33
+ indices = torch.randperm(self.total_size, generator=g).tolist()
34
+
35
+ dataset_size = len(self.dataset)
36
+ indices = [v % dataset_size for v in indices]
37
+
38
+ # subsample
39
+ indices = indices[self.rank:self.total_size:self.num_replicas]
40
+ assert len(indices) == self.num_samples
41
+
42
+ return iter(indices)
43
+
44
+ def __len__(self):
45
+ return self.num_samples
46
+
47
+ def set_epoch(self, epoch):
48
+ self.epoch = epoch
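# --- Usage sketch (illustrative, not part of the committed file) ---
# EnlargedSampler repeats a small dataset so that one "epoch" covers ratio x its length,
# split across num_replicas ranks; call set_epoch() to reshuffle deterministically.
import torch
from torch.utils.data import DataLoader, TensorDataset
from basicsr.data.data_sampler import EnlargedSampler

dataset = TensorDataset(torch.arange(10).float())
sampler = EnlargedSampler(dataset, num_replicas=1, rank=0, ratio=4)   # 40 indices per epoch
loader = DataLoader(dataset, batch_size=4, sampler=sampler)
for epoch in range(2):
    sampler.set_epoch(epoch)
    for _batch in loader:
        pass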
basicsr/data/data_util.py ADDED
@@ -0,0 +1,313 @@
1
+ import cv2
2
+ import numpy as np
3
+ import torch
4
+ from os import path as osp
5
+ from torch.nn import functional as F
6
+
7
+ from basicsr.data.transforms import mod_crop
8
+ from basicsr.utils import img2tensor, scandir
9
+
10
+
11
+ def read_img_seq(path, require_mod_crop=False, scale=1, return_imgname=False):
12
+ """Read a sequence of images from a given folder path.
13
+
14
+ Args:
15
+ path (list[str] | str): List of image paths or image folder path.
16
+ require_mod_crop (bool): Require mod crop for each image.
17
+ Default: False.
18
+ scale (int): Scale factor for mod_crop. Default: 1.
19
+ return_imgname (bool): Whether to return image names. Default: False.
20
+
21
+ Returns:
22
+ Tensor: size (t, c, h, w), RGB, [0, 1].
23
+ list[str]: Returned image name list.
24
+ """
25
+ if isinstance(path, list):
26
+ img_paths = path
27
+ else:
28
+ img_paths = sorted(list(scandir(path, full_path=True)))
29
+ imgs = [cv2.imread(v).astype(np.float32) / 255. for v in img_paths]
30
+
31
+ if require_mod_crop:
32
+ imgs = [mod_crop(img, scale) for img in imgs]
33
+ imgs = img2tensor(imgs, bgr2rgb=True, float32=True)
34
+ imgs = torch.stack(imgs, dim=0)
35
+
36
+ if return_imgname:
37
+ imgnames = [osp.splitext(osp.basename(path))[0] for path in img_paths]
38
+ return imgs, imgnames
39
+ else:
40
+ return imgs
41
+
42
+
43
+ def generate_frame_indices(crt_idx, max_frame_num, num_frames, padding='reflection'):
44
+ """Generate an index list for reading `num_frames` frames from a sequence
45
+ of images.
46
+
47
+ Args:
48
+ crt_idx (int): Current center index.
49
+ max_frame_num (int): Max number of the sequence of images (from 1).
50
+ num_frames (int): Reading num_frames frames.
51
+ padding (str): Padding mode, one of
52
+ 'replicate' | 'reflection' | 'reflection_circle' | 'circle'
53
+ Examples: current_idx = 0, num_frames = 5
54
+ The generated frame indices under different padding mode:
55
+ replicate: [0, 0, 0, 1, 2]
56
+ reflection: [2, 1, 0, 1, 2]
57
+ reflection_circle: [4, 3, 0, 1, 2]
58
+ circle: [3, 4, 0, 1, 2]
59
+
60
+ Returns:
61
+ list[int]: A list of indices.
62
+ """
63
+ assert num_frames % 2 == 1, 'num_frames should be an odd number.'
64
+ assert padding in ('replicate', 'reflection', 'reflection_circle', 'circle'), f'Wrong padding mode: {padding}.'
65
+
66
+ max_frame_num = max_frame_num - 1 # start from 0
67
+ num_pad = num_frames // 2
68
+
69
+ indices = []
70
+ for i in range(crt_idx - num_pad, crt_idx + num_pad + 1):
71
+ if i < 0:
72
+ if padding == 'replicate':
73
+ pad_idx = 0
74
+ elif padding == 'reflection':
75
+ pad_idx = -i
76
+ elif padding == 'reflection_circle':
77
+ pad_idx = crt_idx + num_pad - i
78
+ else:
79
+ pad_idx = num_frames + i
80
+ elif i > max_frame_num:
81
+ if padding == 'replicate':
82
+ pad_idx = max_frame_num
83
+ elif padding == 'reflection':
84
+ pad_idx = max_frame_num * 2 - i
85
+ elif padding == 'reflection_circle':
86
+ pad_idx = (crt_idx - num_pad) - (i - max_frame_num)
87
+ else:
88
+ pad_idx = i - num_frames
89
+ else:
90
+ pad_idx = i
91
+ indices.append(pad_idx)
92
+ return indices
93
+
94
+
95
+ def paired_paths_from_lmdb(folders, keys):
96
+ """Generate paired paths from lmdb files.
97
+
98
+ Contents of lmdb. Taking the `lq.lmdb` for example, the file structure is:
99
+
100
+ lq.lmdb
101
+ ├── data.mdb
102
+ ├── lock.mdb
103
+ ├── meta_info.txt
104
+
105
+ The data.mdb and lock.mdb are standard lmdb files and you can refer to
106
+ https://lmdb.readthedocs.io/en/release/ for more details.
107
+
108
+ The meta_info.txt is a specified txt file to record the meta information
109
+ of our datasets. It will be automatically created when preparing
110
+ datasets by our provided dataset tools.
111
+ Each line in the txt file records
112
+ 1)image name (with extension),
113
+ 2)image shape,
114
+ 3)compression level, separated by a white space.
115
+ Example: `baboon.png (120,125,3) 1`
116
+
117
+ We use the image name without extension as the lmdb key.
118
+ Note that we use the same key for the corresponding lq and gt images.
119
+
120
+ Args:
121
+ folders (list[str]): A list of folder paths. The order of the list should
122
+ be [input_folder, gt_folder].
123
+ keys (list[str]): A list of keys identifying folders. The order should
124
+ be consistent with folders, e.g., ['lq', 'gt'].
125
+ Note that this key is different from lmdb keys.
126
+
127
+ Returns:
128
+ list[str]: Returned path list.
129
+ """
130
+ assert len(folders) == 2, ('The len of folders should be 2 with [input_folder, gt_folder]. '
131
+ f'But got {len(folders)}')
132
+ assert len(keys) == 2, f'The len of keys should be 2 with [input_key, gt_key]. But got {len(keys)}'
133
+ input_folder, gt_folder = folders
134
+ input_key, gt_key = keys
135
+
136
+ if not (input_folder.endswith('.lmdb') and gt_folder.endswith('.lmdb')):
137
+ raise ValueError(f'{input_key} folder and {gt_key} folder should both be in lmdb '
138
+ f'formats. But received {input_key}: {input_folder}; '
139
+ f'{gt_key}: {gt_folder}')
140
+ # ensure that the two meta_info files are the same
141
+ with open(osp.join(input_folder, 'meta_info.txt')) as fin:
142
+ input_lmdb_keys = [line.split('.')[0] for line in fin]
143
+ with open(osp.join(gt_folder, 'meta_info.txt')) as fin:
144
+ gt_lmdb_keys = [line.split('.')[0] for line in fin]
145
+ if set(input_lmdb_keys) != set(gt_lmdb_keys):
146
+ raise ValueError(f'Keys in {input_key}_folder and {gt_key}_folder are different.')
147
+ else:
148
+ paths = []
149
+ for lmdb_key in sorted(input_lmdb_keys):
150
+ paths.append(dict([(f'{input_key}_path', lmdb_key), (f'{gt_key}_path', lmdb_key)]))
151
+ return paths
152
+
153
+
154
+ def paired_paths_from_meta_info_file(folders, keys, meta_info_file, filename_tmpl):
155
+ """Generate paired paths from a meta information file.
156
+
157
+ Each line in the meta information file contains the image names and
158
+ image shape (usually for gt), separated by a white space.
159
+
160
+ Example of a meta information file:
161
+ ```
162
+ 0001_s001.png (480,480,3)
163
+ 0001_s002.png (480,480,3)
164
+ ```
165
+
166
+ Args:
167
+ folders (list[str]): A list of folder paths. The order of the list should
168
+ be [input_folder, gt_folder].
169
+ keys (list[str]): A list of keys identifying folders. The order should
170
+ be consistent with folders, e.g., ['lq', 'gt'].
171
+ meta_info_file (str): Path to the meta information file.
172
+ filename_tmpl (str): Template for each filename. Note that the
173
+ template excludes the file extension. Usually the filename_tmpl is
174
+ for files in the input folder.
175
+
176
+ Returns:
177
+ list[str]: Returned path list.
178
+ """
179
+ assert len(folders) == 2, ('The len of folders should be 2 with [input_folder, gt_folder]. '
180
+ f'But got {len(folders)}')
181
+ assert len(keys) == 2, f'The len of keys should be 2 with [input_key, gt_key]. But got {len(keys)}'
182
+ input_folder, gt_folder = folders
183
+ input_key, gt_key = keys
184
+
185
+ with open(meta_info_file, 'r') as fin:
186
+ gt_names = [line.split(' ')[0] for line in fin]
187
+
188
+ paths = []
189
+ for gt_name in gt_names:
190
+ basename, ext = osp.splitext(osp.basename(gt_name))
191
+ input_name = f'{filename_tmpl.format(basename)}{ext}'
192
+ input_path = osp.join(input_folder, input_name)
193
+ gt_path = osp.join(gt_folder, gt_name)
194
+ paths.append(dict([(f'{input_key}_path', input_path), (f'{gt_key}_path', gt_path)]))
195
+ return paths
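# Illustrative note (not part of the committed file): `filename_tmpl` is applied to the gt
# basename without its extension, so with filename_tmpl='{}_x4' a gt entry '0001.png' pairs
# with input '0001_x4.png'. The folder and meta-info paths below are placeholders.
# paired_paths_from_meta_info_file(['datasets/lq', 'datasets/gt'], ['lq', 'gt'],
#                                  'datasets/meta_info.txt', filename_tmpl='{}_x4')
# -> [{'lq_path': 'datasets/lq/0001_x4.png', 'gt_path': 'datasets/gt/0001.png'}, ...]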
196
+
197
+
198
+ def paired_paths_from_folder(folders, keys, filename_tmpl):
199
+ """Generate paired paths from folders.
200
+
201
+ Args:
202
+ folders (list[str]): A list of folder paths. The order of the list should
203
+ be [input_folder, gt_folder].
204
+ keys (list[str]): A list of keys identifying folders. The order should
205
+ be consistent with folders, e.g., ['lq', 'gt'].
206
+ filename_tmpl (str): Template for each filename. Note that the
207
+ template excludes the file extension. Usually the filename_tmpl is
208
+ for files in the input folder.
209
+
210
+ Returns:
211
+ list[str]: Returned path list.
212
+ """
213
+ assert len(folders) == 2, ('The len of folders should be 2 with [input_folder, gt_folder]. '
214
+ f'But got {len(folders)}')
215
+ assert len(keys) == 2, f'The len of keys should be 2 with [input_key, gt_key]. But got {len(keys)}'
216
+ input_folder, gt_folder = folders
217
+ input_key, gt_key = keys
218
+
219
+ input_paths = list(scandir(input_folder))
220
+ gt_paths = list(scandir(gt_folder))
221
+ assert len(input_paths) == len(gt_paths), (f'{input_key} and {gt_key} datasets have different number of images: '
222
+ f'{len(input_paths)}, {len(gt_paths)}.')
223
+ paths = []
224
+ for gt_path in gt_paths:
225
+ basename, ext = osp.splitext(osp.basename(gt_path))
226
+ input_name = f'{filename_tmpl.format(basename)}{ext}'
227
+ input_path = osp.join(input_folder, input_name)
228
+ assert input_name in input_paths, f'{input_name} is not in {input_key}_paths.'
229
+ gt_path = osp.join(gt_folder, gt_path)
230
+ paths.append(dict([(f'{input_key}_path', input_path), (f'{gt_key}_path', gt_path)]))
231
+ return paths
232
+
233
+
234
+ def paths_from_folder(folder):
235
+ """Generate paths from folder.
236
+
237
+ Args:
238
+ folder (str): Folder path.
239
+
240
+ Returns:
241
+ list[str]: Returned path list.
242
+ """
243
+
244
+ paths = list(scandir(folder))
245
+ paths = [osp.join(folder, path) for path in paths]
246
+ return paths
247
+
248
+
249
+ def paths_from_lmdb(folder):
250
+ """Generate paths from lmdb.
251
+
252
+ Args:
253
+ folder (str): Folder path.
254
+
255
+ Returns:
256
+ list[str]: Returned path list.
257
+ """
258
+ if not folder.endswith('.lmdb'):
259
+ raise ValueError(f'Folder {folder} should be in lmdb format.')
260
+ with open(osp.join(folder, 'meta_info.txt')) as fin:
261
+ paths = [line.split('.')[0] for line in fin]
262
+ return paths
263
+
264
+
265
+ def generate_gaussian_kernel(kernel_size=13, sigma=1.6):
266
+ """Generate Gaussian kernel used in `duf_downsample`.
267
+
268
+ Args:
269
+ kernel_size (int): Kernel size. Default: 13.
270
+ sigma (float): Sigma of the Gaussian kernel. Default: 1.6.
271
+
272
+ Returns:
273
+ np.array: The Gaussian kernel.
274
+ """
275
+ from scipy.ndimage import filters as filters
276
+ kernel = np.zeros((kernel_size, kernel_size))
277
+ # set element at the middle to one, a dirac delta
278
+ kernel[kernel_size // 2, kernel_size // 2] = 1
279
+ # gaussian-smooth the dirac, resulting in a gaussian filter
280
+ return filters.gaussian_filter(kernel, sigma)
281
+
282
+
283
+ def duf_downsample(x, kernel_size=13, scale=4):
284
+ """Downsampling with a Gaussian kernel, as used in the official DUF code.
285
+
286
+ Args:
287
+ x (Tensor): Frames to be downsampled, with shape (b, t, c, h, w).
288
+ kernel_size (int): Kernel size. Default: 13.
289
+ scale (int): Downsampling factor. Supported scale: (2, 3, 4).
290
+ Default: 4.
291
+
292
+ Returns:
293
+ Tensor: DUF downsampled frames.
294
+ """
295
+ assert scale in (2, 3, 4), f'Only support scale (2, 3, 4), but got {scale}.'
296
+
297
+ squeeze_flag = False
298
+ if x.ndim == 4:
299
+ squeeze_flag = True
300
+ x = x.unsqueeze(0)
301
+ b, t, c, h, w = x.size()
302
+ x = x.view(-1, 1, h, w)
303
+ pad_w, pad_h = kernel_size // 2 + scale * 2, kernel_size // 2 + scale * 2
304
+ x = F.pad(x, (pad_w, pad_w, pad_h, pad_h), 'reflect')
305
+
306
+ gaussian_filter = generate_gaussian_kernel(kernel_size, 0.4 * scale)
307
+ gaussian_filter = torch.from_numpy(gaussian_filter).type_as(x).unsqueeze(0).unsqueeze(0)
308
+ x = F.conv2d(x, gaussian_filter, stride=scale)
309
+ x = x[:, :, 2:-2, 2:-2]
310
+ x = x.view(b, t, c, x.size(2), x.size(3))
311
+ if squeeze_flag:
312
+ x = x.squeeze(0)
313
+ return x
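# --- Usage sketch (illustrative, not part of the committed file) ---
# DUF-style Gaussian downsampling of a batch of frames.
import torch
from basicsr.data.data_util import duf_downsample

frames = torch.rand(1, 5, 3, 128, 128)                 # (b, t, c, h, w)
lr = duf_downsample(frames, kernel_size=13, scale=4)   # -> (1, 5, 3, 32, 32)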
basicsr/data/fmix.py ADDED
@@ -0,0 +1,206 @@
1
+ '''
2
+ Fmix paper from arxiv: https://arxiv.org/abs/2002.12047
3
+ Fmix code from github : https://github.com/ecs-vlc/FMix
4
+ '''
5
+ import math
6
+ import random
7
+ import numpy as np
8
+ from scipy.stats import beta
9
+
10
+
11
+ def fftfreqnd(h, w=None, z=None):
12
+ """ Get bin values for discrete fourier transform of size (h, w, z)
13
+ :param h: Required, first dimension size
14
+ :param w: Optional, second dimension size
15
+ :param z: Optional, third dimension size
16
+ """
17
+ fz = fx = 0
18
+ fy = np.fft.fftfreq(h)
19
+
20
+ if w is not None:
21
+ fy = np.expand_dims(fy, -1)
22
+
23
+ if w % 2 == 1:
24
+ fx = np.fft.fftfreq(w)[: w // 2 + 2]
25
+ else:
26
+ fx = np.fft.fftfreq(w)[: w // 2 + 1]
27
+
28
+ if z is not None:
29
+ fy = np.expand_dims(fy, -1)
30
+ if z % 2 == 1:
31
+ fz = np.fft.fftfreq(z)[:, None]
32
+ else:
33
+ fz = np.fft.fftfreq(z)[:, None]
34
+
35
+ return np.sqrt(fx * fx + fy * fy + fz * fz)
36
+
37
+
38
+ def get_spectrum(freqs, decay_power, ch, h, w=0, z=0):
39
+ """ Samples a fourier image with given size and frequencies decayed by decay power
40
+ :param freqs: Bin values for the discrete fourier transform
41
+ :param decay_power: Decay power for frequency decay prop 1/f**d
42
+ :param ch: Number of channels for the resulting mask
43
+ :param h: Required, first dimension size
44
+ :param w: Optional, second dimension size
45
+ :param z: Optional, third dimension size
46
+ """
47
+ scale = np.ones(1) / (np.maximum(freqs, np.array([1. / max(w, h, z)])) ** decay_power)
48
+
49
+ param_size = [ch] + list(freqs.shape) + [2]
50
+ param = np.random.randn(*param_size)
51
+
52
+ scale = np.expand_dims(scale, -1)[None, :]
53
+
54
+ return scale * param
55
+
56
+
57
+ def make_low_freq_image(decay, shape, ch=1):
58
+ """ Sample a low frequency image from fourier space
59
+ :param decay_power: Decay power for frequency decay prop 1/f**d
60
+ :param shape: Shape of desired mask, list up to 3 dims
61
+ :param ch: Number of channels for desired mask
62
+ """
63
+ freqs = fftfreqnd(*shape)
64
+ spectrum = get_spectrum(freqs, decay, ch, *shape)#.reshape((1, *shape[:-1], -1))
65
+ spectrum = spectrum[:, 0] + 1j * spectrum[:, 1]
66
+ mask = np.real(np.fft.irfftn(spectrum, shape))
67
+
68
+ if len(shape) == 1:
69
+ mask = mask[:1, :shape[0]]
70
+ if len(shape) == 2:
71
+ mask = mask[:1, :shape[0], :shape[1]]
72
+ if len(shape) == 3:
73
+ mask = mask[:1, :shape[0], :shape[1], :shape[2]]
74
+
75
+ mask = mask
76
+ mask = (mask - mask.min())
77
+ mask = mask / mask.max()
78
+ return mask
79
+
80
+
81
+ def sample_lam(alpha, reformulate=False):
82
+ """ Sample a lambda from symmetric beta distribution with given alpha
83
+ :param alpha: Alpha value for beta distribution
84
+ :param reformulate: If True, uses the reformulation of [1].
85
+ """
86
+ if reformulate:
87
+ lam = beta.rvs(alpha+1, alpha) # rvs(arg1,arg2,loc=期望, scale=标准差, size=生成随机数的个数) 从分布中生成指定个数的随机数
88
+ else:
89
+ lam = beta.rvs(alpha, alpha) # rvs(arg1,arg2,loc=期望, scale=标准差, size=生成随机数的个数) 从分布中生成指定个数的随机数
90
+
91
+ return lam
92
+
93
+
94
+ def binarise_mask(mask, lam, in_shape, max_soft=0.0):
95
+ """ Binarises a given low frequency image such that it has mean lambda.
96
+ :param mask: Low frequency image, usually the result of `make_low_freq_image`
97
+ :param lam: Mean value of final mask
98
+ :param in_shape: Shape of inputs
99
+ :param max_soft: Softening value between 0 and 0.5 which smooths hard edges in the mask.
100
+ :return:
101
+ """
102
+ idx = mask.reshape(-1).argsort()[::-1]
103
+ mask = mask.reshape(-1)
104
+ num = math.ceil(lam * mask.size) if random.random() > 0.5 else math.floor(lam * mask.size)
105
+
106
+ eff_soft = max_soft
107
+ if max_soft > lam or max_soft > (1-lam):
108
+ eff_soft = min(lam, 1-lam)
109
+
110
+ soft = int(mask.size * eff_soft)
111
+ num_low = num - soft
112
+ num_high = num + soft
113
+
114
+ mask[idx[:num_high]] = 1
115
+ mask[idx[num_low:]] = 0
116
+ mask[idx[num_low:num_high]] = np.linspace(1, 0, (num_high - num_low))
117
+
118
+ mask = mask.reshape((1, *in_shape))
119
+ return mask
120
+
121
+
122
+ def sample_mask(alpha, decay_power, shape, max_soft=0.0, reformulate=False):
123
+ """ Samples a mean lambda from beta distribution parametrised by alpha, creates a low frequency image and binarises
124
+ it based on this lambda
125
+ :param alpha: Alpha value for beta distribution from which to sample mean of mask
126
+ :param decay_power: Decay power for frequency decay prop 1/f**d
127
+ :param shape: Shape of desired mask, list up to 3 dims
128
+ :param max_soft: Softening value between 0 and 0.5 which smooths hard edges in the mask.
129
+ :param reformulate: If True, uses the reformulation of [1].
130
+ """
131
+ if isinstance(shape, int):
132
+ shape = (shape,)
133
+
134
+ # Choose lambda
135
+ lam = sample_lam(alpha, reformulate)
136
+
137
+ # Make mask, get mean / std
138
+ mask = make_low_freq_image(decay_power, shape)
139
+ mask = binarise_mask(mask, lam, shape, max_soft)
140
+
141
+ return lam, mask
142
+
143
+
144
+ def sample_and_apply(x, alpha, decay_power, shape, max_soft=0.0, reformulate=False):
145
+ """
146
+ :param x: Image batch on which to apply fmix of shape [b, c, shape*]
147
+ :param alpha: Alpha value for beta distribution from which to sample mean of mask
148
+ :param decay_power: Decay power for frequency decay prop 1/f**d
149
+ :param shape: Shape of desired mask, list up to 3 dims
150
+ :param max_soft: Softening value between 0 and 0.5 which smooths hard edges in the mask.
151
+ :param reformulate: If True, uses the reformulation of [1].
152
+ :return: mixed input, permutation indices, lambda value of mix,
153
+ """
154
+ lam, mask = sample_mask(alpha, decay_power, shape, max_soft, reformulate)
155
+ index = np.random.permutation(x.shape[0])
156
+
157
+ x1, x2 = x * mask, x[index] * (1-mask)
158
+ return x1+x2, index, lam
159
+
160
+
161
+ class FMixBase:
162
+ """ FMix augmentation
163
+ Args:
164
+ decay_power (float): Decay power for frequency decay prop 1/f**d
165
+ alpha (float): Alpha value for beta distribution from which to sample mean of mask
166
+ size ([int] | [int, int] | [int, int, int]): Shape of desired mask, list up to 3 dims
167
+ max_soft (float): Softening value between 0 and 0.5 which smooths hard edges in the mask.
168
+ reformulate (bool): If True, uses the reformulation of [1].
169
+ """
170
+
171
+ def __init__(self, decay_power=3, alpha=1, size=(32, 32), max_soft=0.0, reformulate=False):
172
+ super().__init__()
173
+ self.decay_power = decay_power
174
+ self.reformulate = reformulate
175
+ self.size = size
176
+ self.alpha = alpha
177
+ self.max_soft = max_soft
178
+ self.index = None
179
+ self.lam = None
180
+
181
+ def __call__(self, x):
182
+ raise NotImplementedError
183
+
184
+ def loss(self, *args, **kwargs):
185
+ raise NotImplementedError
186
+
187
+
188
+ if __name__ == '__main__':
189
+ # para = {'alpha':1.,'decay_power':3.,'shape':(10,10),'max_soft':0.0,'reformulate':False}
190
+ # lam, mask = sample_mask(**para)
191
+ # mask = mask.transpose(1, 2, 0)
192
+ # img1 = np.zeros((10, 10, 3))
193
+ # img2 = np.ones((10, 10, 3))
194
+ # img_gt = mask * img1 + (1. - mask) * img2
195
+ # import ipdb; ipdb.set_trace()
196
+
197
+ # test
198
+ import cv2
199
+ i1 = cv2.imread('output/ILSVRC2012_val_00000001.JPEG')
200
+ i2 = cv2.imread('output/ILSVRC2012_val_00000002.JPEG')
201
+ para = {'alpha':1.,'decay_power':3.,'shape':(256, 256),'max_soft':0.0,'reformulate':False}
202
+ lam, mask = sample_mask(**para)
203
+ mask = mask.transpose(1, 2, 0)
204
+ i = mask * i1 + (1. - mask) * i2
205
+ #i = i.astype(np.uint8)
206
+ cv2.imwrite('fmix.jpg', i)
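
A minimal usage sketch for the helpers above (not part of the diff): it assumes `sample_and_apply` is imported from `basicsr.data.fmix` and uses a random NumPy batch purely for illustration.

    import numpy as np
    from basicsr.data.fmix import sample_and_apply

    # hypothetical batch of 8 RGB images, 256x256, values in [0, 1]
    x = np.random.rand(8, 3, 256, 256).astype(np.float32)
    mixed, index, lam = sample_and_apply(x, alpha=1., decay_power=3., shape=(256, 256))
    # each image is blended with x[index] through a low-frequency binary mask;
    # lam is roughly the fraction of pixels kept from the original image
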
basicsr/data/lab_dataset.py ADDED
@@ -0,0 +1,159 @@
1
+ import cv2
2
+ import random
3
+ import time
4
+ import numpy as np
5
+ import torch
6
+ from torch.utils import data as data
7
+
8
+ from basicsr.data.transforms import rgb2lab
9
+ from basicsr.utils import FileClient, get_root_logger, imfrombytes, img2tensor
10
+ from basicsr.utils.registry import DATASET_REGISTRY
11
+ from basicsr.data.fmix import sample_mask
12
+
13
+
14
+ @DATASET_REGISTRY.register()
15
+ class LabDataset(data.Dataset):
16
+ """
17
+ Dataset used for Lab colorization
18
+ """
19
+
20
+ def __init__(self, opt):
21
+ super(LabDataset, self).__init__()
22
+ self.opt = opt
23
+ # file client (io backend)
24
+ self.file_client = None
25
+ self.io_backend_opt = opt['io_backend']
26
+ self.gt_folder = opt['dataroot_gt']
27
+
28
+ meta_info_file = self.opt['meta_info_file']
29
+ assert meta_info_file is not None
30
+ if not isinstance(meta_info_file, list):
31
+ meta_info_file = [meta_info_file]
32
+ self.paths = []
33
+ for meta_info in meta_info_file:
34
+ with open(meta_info, 'r') as fin:
35
+ self.paths.extend([line.strip() for line in fin])
36
+
37
+ self.min_ab, self.max_ab = -128, 128
38
+ self.interval_ab = 4
39
+ self.ab_palette = [i for i in range(self.min_ab, self.max_ab + self.interval_ab, self.interval_ab)]
40
+ # print(self.ab_palette)
41
+
42
+ self.do_fmix = opt['do_fmix']
43
+ self.fmix_params = {'alpha':1.,'decay_power':3.,'shape':(256,256),'max_soft':0.0,'reformulate':False}
44
+ self.fmix_p = opt['fmix_p']
45
+ self.do_cutmix = opt['do_cutmix']
46
+ self.cutmix_params = {'alpha':1.}
47
+ self.cutmix_p = opt['cutmix_p']
48
+
49
+
50
+ def __getitem__(self, index):
51
+ if self.file_client is None:
52
+ self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt)
53
+
54
+ # -------------------------------- Load gt images -------------------------------- #
55
+ # Shape: (h, w, c); channel order: BGR; image range: [0, 1], float32.
56
+ gt_path = self.paths[index]
57
+ gt_size = self.opt['gt_size']
58
+ # avoid errors caused by high latency in reading files
59
+ retry = 3
60
+ while retry > 0:
61
+ try:
62
+ img_bytes = self.file_client.get(gt_path, 'gt')
63
+ except Exception as e:
64
+ logger = get_root_logger()
65
+ logger.warning(f'File client error: {e}, remaining retry times: {retry - 1}')
66
+ # pick another file to read instead
67
+ index = random.randint(0, self.__len__() - 1)
68
+ gt_path = self.paths[index]
69
+ time.sleep(1) # sleep 1s for occasional server congestion
70
+ else:
71
+ break
72
+ finally:
73
+ retry -= 1
74
+ img_gt = imfrombytes(img_bytes, float32=True)
75
+ img_gt = cv2.resize(img_gt, (gt_size, gt_size))  # TODO: is a direct resize the best choice here?
76
+
77
+ # -------------------------------- (Optional) CutMix & FMix -------------------------------- #
78
+ if self.do_fmix and np.random.uniform(0., 1., size=1)[0] > self.fmix_p:
79
+ with torch.no_grad():
80
+ lam, mask = sample_mask(**self.fmix_params)
81
+
82
+ fmix_index = random.randint(0, self.__len__() - 1)
83
+ fmix_img_path = self.paths[fmix_index]
84
+ fmix_img_bytes = self.file_client.get(fmix_img_path, 'gt')
85
+ fmix_img = imfrombytes(fmix_img_bytes, float32=True)
86
+ fmix_img = cv2.resize(fmix_img, (gt_size, gt_size))
87
+
88
+ mask = mask.transpose(1, 2, 0) # (1, 256, 256) -> # (256, 256, 1)
89
+ img_gt = mask * img_gt + (1. - mask) * fmix_img
90
+ img_gt = img_gt.astype(np.float32)
91
+
92
+ if self.do_cutmix and np.random.uniform(0., 1., size=1)[0] > self.cutmix_p:
93
+ with torch.no_grad():
94
+ cmix_index = random.randint(0, self.__len__() - 1)
95
+ cmix_img_path = self.paths[cmix_index]
96
+ cmix_img_bytes = self.file_client.get(cmix_img_path, 'gt')
97
+ cmix_img = imfrombytes(cmix_img_bytes, float32=True)
98
+ cmix_img = cv2.resize(cmix_img, (gt_size, gt_size))
99
+
100
+ lam = np.clip(np.random.beta(self.cutmix_params['alpha'], self.cutmix_params['alpha']), 0.3, 0.4)
101
+ bbx1, bby1, bbx2, bby2 = rand_bbox(cmix_img.shape[:2], lam)
102
+
103
+ img_gt[bbx1:bbx2, bby1:bby2, :] = cmix_img[bbx1:bbx2, bby1:bby2, :]  # img_gt is (h, w, c) here
104
+
105
+
106
+ # ----------------------------- Get gray lq, to tensor ----------------------------- #
107
+ # convert to gray
108
+ img_gt = cv2.cvtColor(img_gt, cv2.COLOR_BGR2RGB)
109
+ img_l, img_ab = rgb2lab(img_gt)
110
+
111
+ target_a, target_b = self.ab2int(img_ab)
112
+
113
+ # numpy to tensor
114
+ img_l, img_ab = img2tensor([img_l, img_ab], bgr2rgb=False, float32=True)
115
+ target_a, target_b = torch.LongTensor(target_a), torch.LongTensor(target_b)
116
+ return_d = {
117
+ 'lq': img_l,
118
+ 'gt': img_ab,
119
+ 'target_a': target_a,
120
+ 'target_b': target_b,
121
+ 'lq_path': gt_path,
122
+ 'gt_path': gt_path
123
+ }
124
+ return return_d
125
+
126
+ def ab2int(self, img_ab):
127
+ img_a, img_b = img_ab[:, :, 0], img_ab[:, :, 1]
128
+ int_a = (img_a - self.min_ab) / self.interval_ab
129
+ int_b = (img_b - self.min_ab) / self.interval_ab
130
+
131
+ return np.round(int_a), np.round(int_b)
132
+
133
+ def __len__(self):
134
+ return len(self.paths)
135
+
136
+
137
+ def rand_bbox(size, lam):
138
+ '''Crop a bbox for CutMix.
139
+ Args:
140
+ size : tuple, image size, e.g. (256, 256)
141
+ lam : float, mixing ratio
142
+ Returns:
143
+ top-left and bottom-right coordinates of the bbox
144
+ int, int, int, int
145
+ '''
146
+ W = size[0] # width of the image to crop from
147
+ H = size[1] # height of the image to crop from
148
+ cut_rat = np.sqrt(1. - lam) # relative size of the bbox to cut
149
+ cut_w = int(W * cut_rat) # bbox width (np.int is removed in recent NumPy; use int)
150
+ cut_h = int(H * cut_rat) # bbox height
151
+
152
+ cx = np.random.randint(W) # sample the bbox center x uniformly
153
+ cy = np.random.randint(H) # sample the bbox center y uniformly
154
+
155
+ bbx1 = np.clip(cx - cut_w // 2, 0, W) # top-left x
156
+ bby1 = np.clip(cy - cut_h // 2, 0, H) # top-left y
157
+ bbx2 = np.clip(cx + cut_w // 2, 0, W) # bottom-right x
158
+ bby2 = np.clip(cy + cut_h // 2, 0, H) # bottom-right y
159
+ return bbx1, bby1, bbx2, bby2
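
A quick sanity check of the ab quantization performed by `LabDataset.ab2int` (a sketch, not part of the diff): with `min_ab=-128` and `interval_ab=4`, an ab value v maps to bin index round((v + 128) / 4), giving 65 bins over [-128, 128].

    import numpy as np

    img_ab = np.array([[[-128.0, 0.0]]])            # one pixel with (a, b) = (-128, 0)
    int_a = np.round((img_ab[:, :, 0] + 128) / 4)   # -> 0  (first bin)
    int_b = np.round((img_ab[:, :, 1] + 128) / 4)   # -> 32 (middle bin)
    print(int_a, int_b)
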
basicsr/data/prefetch_dataloader.py ADDED
@@ -0,0 +1,125 @@
1
+ import queue as Queue
2
+ import threading
3
+ import torch
4
+ from torch.utils.data import DataLoader
5
+
6
+
7
+ class PrefetchGenerator(threading.Thread):
8
+ """A general prefetch generator.
9
+
10
+ Ref:
11
+ https://stackoverflow.com/questions/7323664/python-generator-pre-fetch
12
+
13
+ Args:
14
+ generator: Python generator.
15
+ num_prefetch_queue (int): Number of prefetch queue.
16
+ """
17
+
18
+ def __init__(self, generator, num_prefetch_queue):
19
+ threading.Thread.__init__(self)
20
+ self.queue = Queue.Queue(num_prefetch_queue)
21
+ self.generator = generator
22
+ self.daemon = True
23
+ self.start()
24
+
25
+ def run(self):
26
+ for item in self.generator:
27
+ self.queue.put(item)
28
+ self.queue.put(None)
29
+
30
+ def __next__(self):
31
+ next_item = self.queue.get()
32
+ if next_item is None:
33
+ raise StopIteration
34
+ return next_item
35
+
36
+ def __iter__(self):
37
+ return self
38
+
39
+
40
+ class PrefetchDataLoader(DataLoader):
41
+ """Prefetch version of dataloader.
42
+
43
+ Ref:
44
+ https://github.com/IgorSusmelj/pytorch-styleguide/issues/5#
45
+
46
+ TODO:
47
+ Need to test on single gpu and ddp (multi-gpu). There is a known issue in
48
+ ddp.
49
+
50
+ Args:
51
+ num_prefetch_queue (int): Number of prefetch queue.
52
+ kwargs (dict): Other arguments for dataloader.
53
+ """
54
+
55
+ def __init__(self, num_prefetch_queue, **kwargs):
56
+ self.num_prefetch_queue = num_prefetch_queue
57
+ super(PrefetchDataLoader, self).__init__(**kwargs)
58
+
59
+ def __iter__(self):
60
+ return PrefetchGenerator(super().__iter__(), self.num_prefetch_queue)
61
+
62
+
63
+ class CPUPrefetcher():
64
+ """CPU prefetcher.
65
+
66
+ Args:
67
+ loader: Dataloader.
68
+ """
69
+
70
+ def __init__(self, loader):
71
+ self.ori_loader = loader
72
+ self.loader = iter(loader)
73
+
74
+ def next(self):
75
+ try:
76
+ return next(self.loader)
77
+ except StopIteration:
78
+ return None
79
+
80
+ def reset(self):
81
+ self.loader = iter(self.ori_loader)
82
+
83
+
84
+ class CUDAPrefetcher():
85
+ """CUDA prefetcher.
86
+
87
+ Ref:
88
+ https://github.com/NVIDIA/apex/issues/304#
89
+
90
+ It may consume more GPU memory.
91
+
92
+ Args:
93
+ loader: Dataloader.
94
+ opt (dict): Options.
95
+ """
96
+
97
+ def __init__(self, loader, opt):
98
+ self.ori_loader = loader
99
+ self.loader = iter(loader)
100
+ self.opt = opt
101
+ self.stream = torch.cuda.Stream()
102
+ self.device = torch.device('cuda' if opt['num_gpu'] != 0 else 'cpu')
103
+ self.preload()
104
+
105
+ def preload(self):
106
+ try:
107
+ self.batch = next(self.loader) # self.batch is a dict
108
+ except StopIteration:
109
+ self.batch = None
110
+ return None
111
+ # put tensors to gpu
112
+ with torch.cuda.stream(self.stream):
113
+ for k, v in self.batch.items():
114
+ if torch.is_tensor(v):
115
+ self.batch[k] = self.batch[k].to(device=self.device, non_blocking=True)
116
+
117
+ def next(self):
118
+ torch.cuda.current_stream().wait_stream(self.stream)
119
+ batch = self.batch
120
+ self.preload()
121
+ return batch
122
+
123
+ def reset(self):
124
+ self.loader = iter(self.ori_loader)
125
+ self.preload()
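
A sketch of the intended consumption pattern for `CUDAPrefetcher` (the dataset and `train_step` below are hypothetical; `pin_memory=True` is what lets the non_blocking copies overlap with compute):

    from torch.utils.data import DataLoader

    loader = DataLoader(dataset, batch_size=16, num_workers=4, pin_memory=True)
    prefetcher = CUDAPrefetcher(loader, opt={'num_gpu': 1})

    prefetcher.reset()
    batch = prefetcher.next()
    while batch is not None:
        train_step(batch)             # batch tensors are already on the GPU
        batch = prefetcher.next()
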
basicsr/data/transforms.py ADDED
@@ -0,0 +1,192 @@
1
+ import cv2
2
+ import numpy as np
3
+ import torch
4
+ import random
5
+ from scipy import special
6
+ from skimage import color
7
+
8
+
9
+ def mod_crop(img, scale):
10
+ """Mod crop images, used during testing.
11
+
12
+ Args:
13
+ img (ndarray): Input image.
14
+ scale (int): Scale factor.
15
+
16
+ Returns:
17
+ ndarray: Result image.
18
+ """
19
+ img = img.copy()
20
+ if img.ndim in (2, 3):
21
+ h, w = img.shape[0], img.shape[1]
22
+ h_remainder, w_remainder = h % scale, w % scale
23
+ img = img[:h - h_remainder, :w - w_remainder, ...]
24
+ else:
25
+ raise ValueError(f'Wrong img ndim: {img.ndim}.')
26
+ return img
27
+
28
+
29
+ def paired_random_crop(img_gts, img_lqs, gt_patch_size, scale, gt_path=None):
30
+ """Paired random crop. Support Numpy array and Tensor inputs.
31
+
32
+ It crops lists of lq and gt images with corresponding locations.
33
+
34
+ Args:
35
+ img_gts (list[ndarray] | ndarray | list[Tensor] | Tensor): GT images. Note that all images
36
+ should have the same shape. If the input is an ndarray, it will
37
+ be transformed to a list containing itself.
38
+ img_lqs (list[ndarray] | ndarray): LQ images. Note that all images
39
+ should have the same shape. If the input is an ndarray, it will
40
+ be transformed to a list containing itself.
41
+ gt_patch_size (int): GT patch size.
42
+ scale (int): Scale factor.
43
+ gt_path (str): Path to ground-truth. Default: None.
44
+
45
+ Returns:
46
+ list[ndarray] | ndarray: GT images and LQ images. If returned results
47
+ only have one element, just return ndarray.
48
+ """
49
+
50
+ if not isinstance(img_gts, list):
51
+ img_gts = [img_gts]
52
+ if not isinstance(img_lqs, list):
53
+ img_lqs = [img_lqs]
54
+
55
+ # determine input type: Numpy array or Tensor
56
+ input_type = 'Tensor' if torch.is_tensor(img_gts[0]) else 'Numpy'
57
+
58
+ if input_type == 'Tensor':
59
+ h_lq, w_lq = img_lqs[0].size()[-2:]
60
+ h_gt, w_gt = img_gts[0].size()[-2:]
61
+ else:
62
+ h_lq, w_lq = img_lqs[0].shape[0:2]
63
+ h_gt, w_gt = img_gts[0].shape[0:2]
64
+ lq_patch_size = gt_patch_size // scale
65
+
66
+ if h_gt != h_lq * scale or w_gt != w_lq * scale:
67
+ raise ValueError(f'Scale mismatches. GT ({h_gt}, {w_gt}) is not {scale}x ',
68
+ f'multiplication of LQ ({h_lq}, {w_lq}).')
69
+ if h_lq < lq_patch_size or w_lq < lq_patch_size:
70
+ raise ValueError(f'LQ ({h_lq}, {w_lq}) is smaller than patch size '
71
+ f'({lq_patch_size}, {lq_patch_size}). '
72
+ f'Please remove {gt_path}.')
73
+
74
+ # randomly choose top and left coordinates for lq patch
75
+ top = random.randint(0, h_lq - lq_patch_size)
76
+ left = random.randint(0, w_lq - lq_patch_size)
77
+
78
+ # crop lq patch
79
+ if input_type == 'Tensor':
80
+ img_lqs = [v[:, :, top:top + lq_patch_size, left:left + lq_patch_size] for v in img_lqs]
81
+ else:
82
+ img_lqs = [v[top:top + lq_patch_size, left:left + lq_patch_size, ...] for v in img_lqs]
83
+
84
+ # crop corresponding gt patch
85
+ top_gt, left_gt = int(top * scale), int(left * scale)
86
+ if input_type == 'Tensor':
87
+ img_gts = [v[:, :, top_gt:top_gt + gt_patch_size, left_gt:left_gt + gt_patch_size] for v in img_gts]
88
+ else:
89
+ img_gts = [v[top_gt:top_gt + gt_patch_size, left_gt:left_gt + gt_patch_size, ...] for v in img_gts]
90
+ if len(img_gts) == 1:
91
+ img_gts = img_gts[0]
92
+ if len(img_lqs) == 1:
93
+ img_lqs = img_lqs[0]
94
+ return img_gts, img_lqs
95
+
96
+
97
+ def augment(imgs, hflip=True, rotation=True, flows=None, return_status=False):
98
+ """Augment: horizontal flips OR rotate (0, 90, 180, 270 degrees).
99
+
100
+ We use vertical flip and transpose for rotation implementation.
101
+ All the images in the list use the same augmentation.
102
+
103
+ Args:
104
+ imgs (list[ndarray] | ndarray): Images to be augmented. If the input
105
+ is an ndarray, it will be transformed to a list.
106
+ hflip (bool): Horizontal flip. Default: True.
107
+ rotation (bool): Ratotation. Default: True.
108
+ flows (list[ndarray]): Flows to be augmented. If the input is an
109
+ ndarray, it will be transformed to a list.
110
+ Dimension is (h, w, 2). Default: None.
111
+ return_status (bool): Return the status of flip and rotation.
112
+ Default: False.
113
+
114
+ Returns:
115
+ list[ndarray] | ndarray: Augmented images and flows. If returned
116
+ results only have one element, just return ndarray.
117
+
118
+ """
119
+ hflip = hflip and random.random() < 0.5
120
+ vflip = rotation and random.random() < 0.5
121
+ rot90 = rotation and random.random() < 0.5
122
+
123
+ def _augment(img):
124
+ if hflip: # horizontal
125
+ cv2.flip(img, 1, img)
126
+ if vflip: # vertical
127
+ cv2.flip(img, 0, img)
128
+ if rot90:
129
+ img = img.transpose(1, 0, 2)
130
+ return img
131
+
132
+ def _augment_flow(flow):
133
+ if hflip: # horizontal
134
+ cv2.flip(flow, 1, flow)
135
+ flow[:, :, 0] *= -1
136
+ if vflip: # vertical
137
+ cv2.flip(flow, 0, flow)
138
+ flow[:, :, 1] *= -1
139
+ if rot90:
140
+ flow = flow.transpose(1, 0, 2)
141
+ flow = flow[:, :, [1, 0]]
142
+ return flow
143
+
144
+ if not isinstance(imgs, list):
145
+ imgs = [imgs]
146
+ imgs = [_augment(img) for img in imgs]
147
+ if len(imgs) == 1:
148
+ imgs = imgs[0]
149
+
150
+ if flows is not None:
151
+ if not isinstance(flows, list):
152
+ flows = [flows]
153
+ flows = [_augment_flow(flow) for flow in flows]
154
+ if len(flows) == 1:
155
+ flows = flows[0]
156
+ return imgs, flows
157
+ else:
158
+ if return_status:
159
+ return imgs, (hflip, vflip, rot90)
160
+ else:
161
+ return imgs
162
+
163
+
164
+ def img_rotate(img, angle, center=None, scale=1.0, borderMode=cv2.BORDER_CONSTANT, borderValue=0.):
165
+ """Rotate image.
166
+
167
+ Args:
168
+ img (ndarray): Image to be rotated.
169
+ angle (float): Rotation angle in degrees. Positive values mean
170
+ counter-clockwise rotation.
171
+ center (tuple[int]): Rotation center. If the center is None,
172
+ initialize it as the center of the image. Default: None.
173
+ scale (float): Isotropic scale factor. Default: 1.0.
174
+ """
175
+ (h, w) = img.shape[:2]
176
+
177
+ if center is None:
178
+ center = (w // 2, h // 2)
179
+
180
+ matrix = cv2.getRotationMatrix2D(center, angle, scale)
181
+ rotated_img = cv2.warpAffine(img, matrix, (w, h), borderMode=borderMode, borderValue=borderValue)
182
+ return rotated_img
183
+
184
+
185
+ def rgb2lab(img_rgb):
186
+ img_lab = color.rgb2lab(img_rgb)
187
+ img_l = img_lab[:, :, :1]
188
+ img_ab = img_lab[:, :, 1:]
189
+ return img_l, img_ab
190
+
191
+
192
+
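
A short sketch of how these transforms compose in the colorization pipeline (the random image is a stand-in; `rgb2lab` expects float RGB in [0, 1]):

    import numpy as np
    from basicsr.data.transforms import augment, rgb2lab

    img = np.random.rand(256, 256, 3).astype(np.float32)   # hypothetical RGB image
    img = augment(img, hflip=True, rotation=False)          # random horizontal flip only
    img_l, img_ab = rgb2lab(img)                            # (256, 256, 1) L and (256, 256, 2) ab
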
basicsr/losses/__init__.py ADDED
@@ -0,0 +1,26 @@
1
+ from copy import deepcopy
2
+
3
+ from basicsr.utils import get_root_logger
4
+ from basicsr.utils.registry import LOSS_REGISTRY
5
+ from .losses import (CharbonnierLoss, GANLoss, L1Loss, MSELoss, PerceptualLoss, WeightedTVLoss, g_path_regularize,
6
+ gradient_penalty_loss, r1_penalty)
7
+
8
+ __all__ = [
9
+ 'L1Loss', 'MSELoss', 'CharbonnierLoss', 'WeightedTVLoss', 'PerceptualLoss', 'GANLoss', 'gradient_penalty_loss',
10
+ 'r1_penalty', 'g_path_regularize'
11
+ ]
12
+
13
+
14
+ def build_loss(opt):
15
+ """Build loss from options.
16
+
17
+ Args:
18
+ opt (dict): Configuration. It must contain:
19
+ type (str): Model type.
20
+ """
21
+ opt = deepcopy(opt)
22
+ loss_type = opt.pop('type')
23
+ loss = LOSS_REGISTRY.get(loss_type)(**opt)
24
+ logger = get_root_logger()
25
+ logger.info(f'Loss [{loss.__class__.__name__}] is created.')
26
+ return loss
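
A sketch of how `build_loss` is typically driven from an options dict (the weight and reduction values below are illustrative, not the project's defaults):

    import torch
    from basicsr.losses import build_loss

    pixel_opt = {'type': 'L1Loss', 'loss_weight': 1.0, 'reduction': 'mean'}
    cri_pix = build_loss(pixel_opt)

    pred = torch.rand(2, 2, 64, 64)
    target = torch.rand(2, 2, 64, 64)
    l_pix = cri_pix(pred, target)
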
basicsr/losses/loss_util.py ADDED
@@ -0,0 +1,95 @@
1
+ import functools
2
+ from torch.nn import functional as F
3
+
4
+
5
+ def reduce_loss(loss, reduction):
6
+ """Reduce loss as specified.
7
+
8
+ Args:
9
+ loss (Tensor): Elementwise loss tensor.
10
+ reduction (str): Options are 'none', 'mean' and 'sum'.
11
+
12
+ Returns:
13
+ Tensor: Reduced loss tensor.
14
+ """
15
+ reduction_enum = F._Reduction.get_enum(reduction)
16
+ # none: 0, elementwise_mean:1, sum: 2
17
+ if reduction_enum == 0:
18
+ return loss
19
+ elif reduction_enum == 1:
20
+ return loss.mean()
21
+ else:
22
+ return loss.sum()
23
+
24
+
25
+ def weight_reduce_loss(loss, weight=None, reduction='mean'):
26
+ """Apply element-wise weight and reduce loss.
27
+
28
+ Args:
29
+ loss (Tensor): Element-wise loss.
30
+ weight (Tensor): Element-wise weights. Default: None.
31
+ reduction (str): Same as built-in losses of PyTorch. Options are
32
+ 'none', 'mean' and 'sum'. Default: 'mean'.
33
+
34
+ Returns:
35
+ Tensor: Loss values.
36
+ """
37
+ # if weight is specified, apply element-wise weight
38
+ if weight is not None:
39
+ assert weight.dim() == loss.dim()
40
+ assert weight.size(1) == 1 or weight.size(1) == loss.size(1)
41
+ loss = loss * weight
42
+
43
+ # if weight is not specified or reduction is sum, just reduce the loss
44
+ if weight is None or reduction == 'sum':
45
+ loss = reduce_loss(loss, reduction)
46
+ # if reduction is mean, then compute mean over weight region
47
+ elif reduction == 'mean':
48
+ if weight.size(1) > 1:
49
+ weight = weight.sum()
50
+ else:
51
+ weight = weight.sum() * loss.size(1)
52
+ loss = loss.sum() / weight
53
+
54
+ return loss
55
+
56
+
57
+ def weighted_loss(loss_func):
58
+ """Create a weighted version of a given loss function.
59
+
60
+ To use this decorator, the loss function must have the signature like
61
+ `loss_func(pred, target, **kwargs)`. The function only needs to compute
62
+ element-wise loss without any reduction. This decorator will add weight
63
+ and reduction arguments to the function. The decorated function will have
64
+ the signature like `loss_func(pred, target, weight=None, reduction='mean',
65
+ **kwargs)`.
66
+
67
+ :Example:
68
+
69
+ >>> import torch
70
+ >>> @weighted_loss
71
+ >>> def l1_loss(pred, target):
72
+ >>> return (pred - target).abs()
73
+
74
+ >>> pred = torch.Tensor([0, 2, 3])
75
+ >>> target = torch.Tensor([1, 1, 1])
76
+ >>> weight = torch.Tensor([1, 0, 1])
77
+
78
+ >>> l1_loss(pred, target)
79
+ tensor(1.3333)
80
+ >>> l1_loss(pred, target, weight)
81
+ tensor(1.5000)
82
+ >>> l1_loss(pred, target, reduction='none')
83
+ tensor([1., 1., 2.])
84
+ >>> l1_loss(pred, target, weight, reduction='sum')
85
+ tensor(3.)
86
+ """
87
+
88
+ @functools.wraps(loss_func)
89
+ def wrapper(pred, target, weight=None, reduction='mean', **kwargs):
90
+ # get element-wise loss
91
+ loss = loss_func(pred, target, **kwargs)
92
+ loss = weight_reduce_loss(loss, weight, reduction)
93
+ return loss
94
+
95
+ return wrapper
basicsr/losses/losses.py ADDED
@@ -0,0 +1,551 @@
1
+ import math
2
+ import torch
3
+ from torch import autograd as autograd
4
+ from torch import nn as nn
5
+ from torch.nn import functional as F
6
+
7
+ from basicsr.archs.vgg_arch import VGGFeatureExtractor
8
+ from basicsr.utils.registry import LOSS_REGISTRY
9
+ from .loss_util import weighted_loss
10
+
11
+ _reduction_modes = ['none', 'mean', 'sum']
12
+
13
+
14
+ @weighted_loss
15
+ def l1_loss(pred, target):
16
+ return F.l1_loss(pred, target, reduction='none')
17
+
18
+
19
+ @weighted_loss
20
+ def mse_loss(pred, target):
21
+ return F.mse_loss(pred, target, reduction='none')
22
+
23
+
24
+ @weighted_loss
25
+ def charbonnier_loss(pred, target, eps=1e-12):
26
+ return torch.sqrt((pred - target)**2 + eps)
27
+
28
+
29
+ @LOSS_REGISTRY.register()
30
+ class L1Loss(nn.Module):
31
+ """L1 (mean absolute error, MAE) loss.
32
+
33
+ Args:
34
+ loss_weight (float): Loss weight for L1 loss. Default: 1.0.
35
+ reduction (str): Specifies the reduction to apply to the output.
36
+ Supported choices are 'none' | 'mean' | 'sum'. Default: 'mean'.
37
+ """
38
+
39
+ def __init__(self, loss_weight=1.0, reduction='mean'):
40
+ super(L1Loss, self).__init__()
41
+ if reduction not in ['none', 'mean', 'sum']:
42
+ raise ValueError(f'Unsupported reduction mode: {reduction}. Supported ones are: {_reduction_modes}')
43
+
44
+ self.loss_weight = loss_weight
45
+ self.reduction = reduction
46
+
47
+ def forward(self, pred, target, weight=None, **kwargs):
48
+ """
49
+ Args:
50
+ pred (Tensor): of shape (N, C, H, W). Predicted tensor.
51
+ target (Tensor): of shape (N, C, H, W). Ground truth tensor.
52
+ weight (Tensor, optional): of shape (N, C, H, W). Element-wise
53
+ weights. Default: None.
54
+ """
55
+ return self.loss_weight * l1_loss(pred, target, weight, reduction=self.reduction)
56
+
57
+
58
+ @LOSS_REGISTRY.register()
59
+ class MSELoss(nn.Module):
60
+ """MSE (L2) loss.
61
+
62
+ Args:
63
+ loss_weight (float): Loss weight for MSE loss. Default: 1.0.
64
+ reduction (str): Specifies the reduction to apply to the output.
65
+ Supported choices are 'none' | 'mean' | 'sum'. Default: 'mean'.
66
+ """
67
+
68
+ def __init__(self, loss_weight=1.0, reduction='mean'):
69
+ super(MSELoss, self).__init__()
70
+ if reduction not in ['none', 'mean', 'sum']:
71
+ raise ValueError(f'Unsupported reduction mode: {reduction}. Supported ones are: {_reduction_modes}')
72
+
73
+ self.loss_weight = loss_weight
74
+ self.reduction = reduction
75
+
76
+ def forward(self, pred, target, weight=None, **kwargs):
77
+ """
78
+ Args:
79
+ pred (Tensor): of shape (N, C, H, W). Predicted tensor.
80
+ target (Tensor): of shape (N, C, H, W). Ground truth tensor.
81
+ weight (Tensor, optional): of shape (N, C, H, W). Element-wise
82
+ weights. Default: None.
83
+ """
84
+ return self.loss_weight * mse_loss(pred, target, weight, reduction=self.reduction)
85
+
86
+
87
+ @LOSS_REGISTRY.register()
88
+ class CharbonnierLoss(nn.Module):
89
+ """Charbonnier loss (one variant of Robust L1Loss, a differentiable
90
+ variant of L1Loss).
91
+
92
+ Described in "Deep Laplacian Pyramid Networks for Fast and Accurate
93
+ Super-Resolution".
94
+
95
+ Args:
96
+ loss_weight (float): Loss weight for L1 loss. Default: 1.0.
97
+ reduction (str): Specifies the reduction to apply to the output.
98
+ Supported choices are 'none' | 'mean' | 'sum'. Default: 'mean'.
99
+ eps (float): A value used to control the curvature near zero.
100
+ Default: 1e-12.
101
+ """
102
+
103
+ def __init__(self, loss_weight=1.0, reduction='mean', eps=1e-12):
104
+ super(CharbonnierLoss, self).__init__()
105
+ if reduction not in ['none', 'mean', 'sum']:
106
+ raise ValueError(f'Unsupported reduction mode: {reduction}. Supported ones are: {_reduction_modes}')
107
+
108
+ self.loss_weight = loss_weight
109
+ self.reduction = reduction
110
+ self.eps = eps
111
+
112
+ def forward(self, pred, target, weight=None, **kwargs):
113
+ """
114
+ Args:
115
+ pred (Tensor): of shape (N, C, H, W). Predicted tensor.
116
+ target (Tensor): of shape (N, C, H, W). Ground truth tensor.
117
+ weight (Tensor, optional): of shape (N, C, H, W). Element-wise
118
+ weights. Default: None.
119
+ """
120
+ return self.loss_weight * charbonnier_loss(pred, target, weight, eps=self.eps, reduction=self.reduction)
121
+
122
+
123
+ @LOSS_REGISTRY.register()
124
+ class WeightedTVLoss(L1Loss):
125
+ """Weighted TV loss.
126
+
127
+ Args:
128
+ loss_weight (float): Loss weight. Default: 1.0.
129
+ """
130
+
131
+ def __init__(self, loss_weight=1.0):
132
+ super(WeightedTVLoss, self).__init__(loss_weight=loss_weight)
133
+
134
+ def forward(self, pred, weight=None):
135
+ if weight is None:
136
+ y_weight = None
137
+ x_weight = None
138
+ else:
139
+ y_weight = weight[:, :, :-1, :]
140
+ x_weight = weight[:, :, :, :-1]
141
+
142
+ y_diff = super(WeightedTVLoss, self).forward(pred[:, :, :-1, :], pred[:, :, 1:, :], weight=y_weight)
143
+ x_diff = super(WeightedTVLoss, self).forward(pred[:, :, :, :-1], pred[:, :, :, 1:], weight=x_weight)
144
+
145
+ loss = x_diff + y_diff
146
+
147
+ return loss
148
+
149
+
150
+ @LOSS_REGISTRY.register()
151
+ class PerceptualLoss(nn.Module):
152
+ """Perceptual loss with commonly used style loss.
153
+
154
+ Args:
155
+ layer_weights (dict): The weight for each layer of vgg feature.
156
+ Here is an example: {'conv5_4': 1.}, which means the conv5_4
157
+ feature layer (before relu5_4) will be extracted with weight
158
+ 1.0 in calculating losses.
159
+ vgg_type (str): The type of vgg network used as feature extractor.
160
+ Default: 'vgg19'.
161
+ use_input_norm (bool): If True, normalize the input image in vgg.
162
+ Default: True.
163
+ range_norm (bool): If True, norm images with range [-1, 1] to [0, 1].
164
+ Default: False.
165
+ perceptual_weight (float): If `perceptual_weight > 0`, the perceptual
166
+ loss will be calculated and the loss will be multiplied by the
167
+ weight. Default: 1.0.
168
+ style_weight (float): If `style_weight > 0`, the style loss will be
169
+ calculated and the loss will be multiplied by the weight.
170
+ Default: 0.
171
+ criterion (str): Criterion used for perceptual loss. Default: 'l1'.
172
+ """
173
+
174
+ def __init__(self,
175
+ layer_weights,
176
+ vgg_type='vgg19',
177
+ use_input_norm=True,
178
+ range_norm=False,
179
+ perceptual_weight=1.0,
180
+ style_weight=0.,
181
+ criterion='l1'):
182
+ super(PerceptualLoss, self).__init__()
183
+ self.perceptual_weight = perceptual_weight
184
+ self.style_weight = style_weight
185
+ self.layer_weights = layer_weights
186
+ self.vgg = VGGFeatureExtractor(
187
+ layer_name_list=list(layer_weights.keys()),
188
+ vgg_type=vgg_type,
189
+ use_input_norm=use_input_norm,
190
+ range_norm=range_norm)
191
+
192
+ self.criterion_type = criterion
193
+ if self.criterion_type == 'l1':
194
+ self.criterion = torch.nn.L1Loss()
195
+ elif self.criterion_type == 'l2':
196
+ self.criterion = torch.nn.MSELoss()  # L2 criterion; torch.nn.L2loss does not exist
197
+ elif self.criterion_type == 'fro':
198
+ self.criterion = None
199
+ else:
200
+ raise NotImplementedError(f'{criterion} criterion has not been supported.')
201
+
202
+ def forward(self, x, gt):
203
+ """Forward function.
204
+
205
+ Args:
206
+ x (Tensor): Input tensor with shape (n, c, h, w).
207
+ gt (Tensor): Ground-truth tensor with shape (n, c, h, w).
208
+
209
+ Returns:
210
+ Tensor: Forward results.
211
+ """
212
+ # extract vgg features
213
+ x_features = self.vgg(x)
214
+ gt_features = self.vgg(gt.detach())
215
+
216
+ # calculate perceptual loss
217
+ if self.perceptual_weight > 0:
218
+ percep_loss = 0
219
+ for k in x_features.keys():
220
+ if self.criterion_type == 'fro':
221
+ percep_loss += torch.norm(x_features[k] - gt_features[k], p='fro') * self.layer_weights[k]
222
+ else:
223
+ percep_loss += self.criterion(x_features[k], gt_features[k]) * self.layer_weights[k]
224
+ percep_loss *= self.perceptual_weight
225
+ else:
226
+ percep_loss = None
227
+
228
+ # calculate style loss
229
+ if self.style_weight > 0:
230
+ style_loss = 0
231
+ for k in x_features.keys():
232
+ if self.criterion_type == 'fro':
233
+ style_loss += torch.norm(
234
+ self._gram_mat(x_features[k]) - self._gram_mat(gt_features[k]), p='fro') * self.layer_weights[k]
235
+ else:
236
+ style_loss += self.criterion(self._gram_mat(x_features[k]), self._gram_mat(
237
+ gt_features[k])) * self.layer_weights[k]
238
+ style_loss *= self.style_weight
239
+ else:
240
+ style_loss = None
241
+
242
+ return percep_loss, style_loss
243
+
244
+ def _gram_mat(self, x):
245
+ """Calculate Gram matrix.
246
+
247
+ Args:
248
+ x (torch.Tensor): Tensor with shape of (n, c, h, w).
249
+
250
+ Returns:
251
+ torch.Tensor: Gram matrix.
252
+ """
253
+ n, c, h, w = x.size()
254
+ features = x.view(n, c, w * h)
255
+ features_t = features.transpose(1, 2)
256
+ gram = features.bmm(features_t) / (c * h * w)
257
+ return gram
258
+
259
+
260
+ @LOSS_REGISTRY.register()
261
+ class GANLoss(nn.Module):
262
+ """Define GAN loss.
263
+
264
+ Args:
265
+ gan_type (str): Support 'vanilla', 'lsgan', 'wgan', 'wgan_softplus', 'hinge'.
266
+ real_label_val (float): The value for real label. Default: 1.0.
267
+ fake_label_val (float): The value for fake label. Default: 0.0.
268
+ loss_weight (float): Loss weight. Default: 1.0.
269
+ Note that loss_weight is only for generators; and it is always 1.0
270
+ for discriminators.
271
+ """
272
+
273
+ def __init__(self, gan_type, real_label_val=1.0, fake_label_val=0.0, loss_weight=1.0):
274
+ super(GANLoss, self).__init__()
275
+ self.gan_type = gan_type
276
+ self.loss_weight = loss_weight
277
+ self.real_label_val = real_label_val
278
+ self.fake_label_val = fake_label_val
279
+
280
+ if self.gan_type == 'vanilla':
281
+ self.loss = nn.BCEWithLogitsLoss()
282
+ elif self.gan_type == 'lsgan':
283
+ self.loss = nn.MSELoss()
284
+ elif self.gan_type == 'wgan':
285
+ self.loss = self._wgan_loss
286
+ elif self.gan_type == 'wgan_softplus':
287
+ self.loss = self._wgan_softplus_loss
288
+ elif self.gan_type == 'hinge':
289
+ self.loss = nn.ReLU()
290
+ else:
291
+ raise NotImplementedError(f'GAN type {self.gan_type} is not implemented.')
292
+
293
+ def _wgan_loss(self, input, target):
294
+ """wgan loss.
295
+
296
+ Args:
297
+ input (Tensor): Input tensor.
298
+ target (bool): Target label.
299
+
300
+ Returns:
301
+ Tensor: wgan loss.
302
+ """
303
+ return -input.mean() if target else input.mean()
304
+
305
+ def _wgan_softplus_loss(self, input, target):
306
+ """wgan loss with soft plus. softplus is a smooth approximation to the
307
+ ReLU function.
308
+
309
+ In StyleGAN2, it is called:
310
+ Logistic loss for discriminator;
311
+ Non-saturating loss for generator.
312
+
313
+ Args:
314
+ input (Tensor): Input tensor.
315
+ target (bool): Target label.
316
+
317
+ Returns:
318
+ Tensor: wgan loss.
319
+ """
320
+ return F.softplus(-input).mean() if target else F.softplus(input).mean()
321
+
322
+ def get_target_label(self, input, target_is_real):
323
+ """Get target label.
324
+
325
+ Args:
326
+ input (Tensor): Input tensor.
327
+ target_is_real (bool): Whether the target is real or fake.
328
+
329
+ Returns:
330
+ (bool | Tensor): Target tensor. Return bool for wgan, otherwise,
331
+ return Tensor.
332
+ """
333
+
334
+ if self.gan_type in ['wgan', 'wgan_softplus']:
335
+ return target_is_real
336
+ target_val = (self.real_label_val if target_is_real else self.fake_label_val)
337
+ return input.new_ones(input.size()) * target_val
338
+
339
+ def forward(self, input, target_is_real, is_disc=False):
340
+ """
341
+ Args:
342
+ input (Tensor): The input for the loss module, i.e., the network
343
+ prediction.
344
+ target_is_real (bool): Whether the target is real or fake.
345
+ is_disc (bool): Whether the loss for discriminators or not.
346
+ Default: False.
347
+
348
+ Returns:
349
+ Tensor: GAN loss value.
350
+ """
351
+ target_label = self.get_target_label(input, target_is_real)
352
+ if self.gan_type == 'hinge':
353
+ if is_disc: # for discriminators in hinge-gan
354
+ input = -input if target_is_real else input
355
+ loss = self.loss(1 + input).mean()
356
+ else: # for generators in hinge-gan
357
+ loss = -input.mean()
358
+ else: # other gan types
359
+ loss = self.loss(input, target_label)
360
+
361
+ # loss_weight is always 1.0 for discriminators
362
+ return loss if is_disc else loss * self.loss_weight
363
+
364
+
365
+ @LOSS_REGISTRY.register()
366
+ class MultiScaleGANLoss(GANLoss):
367
+ """
368
+ MultiScaleGANLoss accepts a list of predictions
369
+ """
370
+
371
+ def __init__(self, gan_type, real_label_val=1.0, fake_label_val=0.0, loss_weight=1.0):
372
+ super(MultiScaleGANLoss, self).__init__(gan_type, real_label_val, fake_label_val, loss_weight)
373
+
374
+ def forward(self, input, target_is_real, is_disc=False):
375
+ """
376
+ The input is a list of tensors, or a list of (a list of tensors)
377
+ """
378
+ if isinstance(input, list):
379
+ loss = 0
380
+ for pred_i in input:
381
+ if isinstance(pred_i, list):
382
+ # Only compute GAN loss for the last layer
383
+ # in case of multiscale feature matching
384
+ pred_i = pred_i[-1]
385
+ # Safe operation: 0-dim tensor calling self.mean() does nothing
386
+ loss_tensor = super().forward(pred_i, target_is_real, is_disc).mean()
387
+ loss += loss_tensor
388
+ return loss / len(input)
389
+ else:
390
+ return super().forward(input, target_is_real, is_disc)
391
+
392
+
393
+ def r1_penalty(real_pred, real_img):
394
+ """R1 regularization for discriminator. The core idea is to
395
+ penalize the gradient on real data alone: when the
396
+ generator distribution produces the true data distribution
397
+ and the discriminator is equal to 0 on the data manifold, the
398
+ gradient penalty ensures that the discriminator cannot create
399
+ a non-zero gradient orthogonal to the data manifold without
400
+ suffering a loss in the GAN game.
401
+
402
+ Ref:
403
+ Eq. 9 in Which training methods for GANs do actually converge.
404
+ """
405
+ grad_real = autograd.grad(outputs=real_pred.sum(), inputs=real_img, create_graph=True)[0]
406
+ grad_penalty = grad_real.pow(2).view(grad_real.shape[0], -1).sum(1).mean()
407
+ return grad_penalty
408
+
409
+
410
+ def g_path_regularize(fake_img, latents, mean_path_length, decay=0.01):
411
+ noise = torch.randn_like(fake_img) / math.sqrt(fake_img.shape[2] * fake_img.shape[3])
412
+ grad = autograd.grad(outputs=(fake_img * noise).sum(), inputs=latents, create_graph=True)[0]
413
+ path_lengths = torch.sqrt(grad.pow(2).sum(2).mean(1))
414
+
415
+ path_mean = mean_path_length + decay * (path_lengths.mean() - mean_path_length)
416
+
417
+ path_penalty = (path_lengths - path_mean).pow(2).mean()
418
+
419
+ return path_penalty, path_lengths.detach().mean(), path_mean.detach()
420
+
421
+
422
+ def gradient_penalty_loss(discriminator, real_data, fake_data, weight=None):
423
+ """Calculate gradient penalty for wgan-gp.
424
+
425
+ Args:
426
+ discriminator (nn.Module): Network for the discriminator.
427
+ real_data (Tensor): Real input data.
428
+ fake_data (Tensor): Fake input data.
429
+ weight (Tensor): Weight tensor. Default: None.
430
+
431
+ Returns:
432
+ Tensor: A tensor for gradient penalty.
433
+ """
434
+
435
+ batch_size = real_data.size(0)
436
+ alpha = real_data.new_tensor(torch.rand(batch_size, 1, 1, 1))
437
+
438
+ # interpolate between real_data and fake_data
439
+ interpolates = alpha * real_data + (1. - alpha) * fake_data
440
+ interpolates = autograd.Variable(interpolates, requires_grad=True)
441
+
442
+ disc_interpolates = discriminator(interpolates)
443
+ gradients = autograd.grad(
444
+ outputs=disc_interpolates,
445
+ inputs=interpolates,
446
+ grad_outputs=torch.ones_like(disc_interpolates),
447
+ create_graph=True,
448
+ retain_graph=True,
449
+ only_inputs=True)[0]
450
+
451
+ if weight is not None:
452
+ gradients = gradients * weight
453
+
454
+ gradients_penalty = ((gradients.norm(2, dim=1) - 1)**2).mean()
455
+ if weight is not None:
456
+ gradients_penalty /= torch.mean(weight)
457
+
458
+ return gradients_penalty
459
+
460
+
461
+ @LOSS_REGISTRY.register()
462
+ class GANFeatLoss(nn.Module):
463
+ """Define feature matching loss for gans
464
+
465
+ Args:
466
+ criterion (str): Support 'l1', 'l2', 'charbonnier'.
467
+ loss_weight (float): Loss weight. Default: 1.0.
468
+ reduction (str): Specifies the reduction to apply to the output.
469
+ Supported choices are 'none' | 'mean' | 'sum'. Default: 'mean'.
470
+ """
471
+
472
+ def __init__(self, criterion='l1', loss_weight=1.0, reduction='mean'):
473
+ super(GANFeatLoss, self).__init__()
474
+ if criterion == 'l1':
475
+ self.loss_op = L1Loss(loss_weight, reduction)
476
+ elif criterion == 'l2':
477
+ self.loss_op = MSELoss(loss_weight, reduction)
478
+ elif criterion == 'charbonnier':
479
+ self.loss_op = CharbonnierLoss(loss_weight, reduction)
480
+ else:
481
+ raise ValueError(f'Unsupported loss mode: {criterion}. Supported ones are: l1|l2|charbonnier')
482
+
483
+ self.loss_weight = loss_weight
484
+
485
+ def forward(self, pred_fake, pred_real):
486
+ num_d = len(pred_fake)
487
+ loss = 0
488
+ for i in range(num_d): # for each discriminator
489
+ # last output is the final prediction, exclude it
490
+ num_intermediate_outputs = len(pred_fake[i]) - 1
491
+ for j in range(num_intermediate_outputs): # for each layer output
492
+ unweighted_loss = self.loss_op(pred_fake[i][j], pred_real[i][j].detach())
493
+ loss += unweighted_loss / num_d
494
+ return loss * self.loss_weight
495
+
496
+
497
+ class sobel_loss(nn.Module):
498
+ def __init__(self, weight=1.0):
499
+ super().__init__()
500
+ kernel_x = torch.Tensor([[-1.0, 0.0, 1.0], [-2.0, 0.0, 2.0], [-1.0, 0.0, 1.0]])
501
+ kernel_y = torch.Tensor([[-1.0, -2.0, -1.0], [0.0, 0.0, 0.0], [1.0, 2.0, 1.0]])
502
+ kernel = torch.stack([kernel_x, kernel_y])
503
+ kernel.requires_grad = False
504
+ kernel = kernel.unsqueeze(1)
505
+ self.register_buffer('sobel_kernel', kernel)
506
+ self.weight = weight
507
+
508
+ def forward(self, input_tensor, target_tensor):
509
+ b, c, h, w = input_tensor.size()
510
+ input_tensor = input_tensor.view(b * c, 1, h, w)
511
+ input_edge = F.conv2d(input_tensor, self.sobel_kernel, padding=1)
512
+ input_edge = input_edge.view(b, 2*c, h, w)
513
+
514
+ target_tensor = target_tensor.view(-1, 1, h, w)
515
+ target_edge = F.conv2d(target_tensor, self.sobel_kernel, padding=1)
516
+ target_edge = target_edge.view(b, 2*c, h, w)
517
+
518
+ return self.weight * F.l1_loss(input_edge, target_edge)
519
+
520
+
521
+ @LOSS_REGISTRY.register()
522
+ class ColorfulnessLoss(nn.Module):
523
+ """Colorfulness loss.
524
+
525
+ Args:
526
+ loss_weight (float): Loss weight for Colorfulness loss. Default: 1.0.
527
+
528
+ """
529
+
530
+ def __init__(self, loss_weight=1.0):
531
+ super(ColorfulnessLoss, self).__init__()
532
+
533
+ self.loss_weight = loss_weight
534
+
535
+ def forward(self, pred, **kwargs):
536
+ """
537
+ Args:
538
+ pred (Tensor): of shape (N, C, H, W). Predicted tensor.
539
+ """
540
+ colorfulness_loss = 0
541
+ for i in range(pred.shape[0]):
542
+ (R, G, B) = pred[i][0], pred[i][1], pred[i][2]
543
+ rg = torch.abs(R - G)
544
+ yb = torch.abs(0.5 * (R+G) - B)
545
+ (rbMean, rbStd) = (torch.mean(rg), torch.std(rg))
546
+ (ybMean, ybStd) = (torch.mean(yb), torch.std(yb))
547
+ stdRoot = torch.sqrt((rbStd ** 2) + (ybStd ** 2))
548
+ meanRoot = torch.sqrt((rbMean ** 2) + (ybMean ** 2))
549
+ colorfulness = stdRoot + (0.3 * meanRoot)
550
+ colorfulness_loss += (1 - colorfulness)
551
+ return self.loss_weight * colorfulness_loss
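
A sketch of how `GANLoss` is usually called from the generator and discriminator updates (the prediction tensors are placeholders; the 0.1 weight is illustrative):

    import torch

    cri_gan = GANLoss(gan_type='vanilla', real_label_val=1.0, fake_label_val=0.0, loss_weight=0.1)
    fake_pred = torch.randn(4, 1)
    real_pred = torch.randn(4, 1)

    # generator step: try to fool the discriminator, loss_weight is applied
    l_g_gan = cri_gan(fake_pred, target_is_real=True, is_disc=False)

    # discriminator step: loss_weight is ignored (always 1.0)
    l_d_real = cri_gan(real_pred, target_is_real=True, is_disc=True)
    l_d_fake = cri_gan(fake_pred.detach(), target_is_real=False, is_disc=True)
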
basicsr/metrics/__init__.py ADDED
@@ -0,0 +1,20 @@
1
+ from copy import deepcopy
2
+
3
+ from basicsr.utils.registry import METRIC_REGISTRY
4
+ from .psnr_ssim import calculate_psnr, calculate_ssim
5
+ from .colorfulness import calculate_cf
6
+
7
+ __all__ = ['calculate_psnr', 'calculate_ssim', 'calculate_cf']
8
+
9
+
10
+ def calculate_metric(data, opt):
11
+ """Calculate metric from data and options.
12
+
13
+ Args:
14
+ opt (dict): Configuration. It must contain:
15
+ type (str): Model type.
16
+ """
17
+ opt = deepcopy(opt)
18
+ metric_type = opt.pop('type')
19
+ metric = METRIC_REGISTRY.get(metric_type)(**data, **opt)
20
+ return metric
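
A sketch of a typical call from a validation loop (the random arrays stand in for restored and ground-truth images in [0, 255]):

    import numpy as np
    from basicsr.metrics import calculate_metric

    data = {'img': np.random.randint(0, 256, (64, 64, 3)).astype(np.float64),
            'img2': np.random.randint(0, 256, (64, 64, 3)).astype(np.float64)}
    opt = {'type': 'calculate_psnr', 'crop_border': 0, 'test_y_channel': False}
    psnr = calculate_metric(data, opt)
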
basicsr/metrics/colorfulness.py ADDED
@@ -0,0 +1,17 @@
1
+ import cv2
2
+ import numpy as np
3
+ from basicsr.utils.registry import METRIC_REGISTRY
4
+
5
+
6
+ @METRIC_REGISTRY.register()
7
+ def calculate_cf(img, **kwargs):
8
+ """Calculate Colorfulness.
9
+ """
10
+ (B, G, R) = cv2.split(img.astype('float'))
11
+ rg = np.absolute(R - G)
12
+ yb = np.absolute(0.5 * (R+G) - B)
13
+ (rbMean, rbStd) = (np.mean(rg), np.std(rg))
14
+ (ybMean, ybStd) = (np.mean(yb), np.std(yb))
15
+ stdRoot = np.sqrt((rbStd ** 2) + (ybStd ** 2))
16
+ meanRoot = np.sqrt((rbMean ** 2) + (ybMean ** 2))
17
+ return stdRoot + (0.3 * meanRoot)
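
This is the colorfulness measure of Hasler and Süsstrunk (std plus 0.3 times the mean of the rg/yb opponent statistics). A quick sketch of calling it on an arbitrary BGR image; a pure gray image scores 0:

    import numpy as np
    img = np.random.randint(0, 256, (128, 128, 3), dtype=np.uint8)
    print(calculate_cf(img))
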
basicsr/metrics/custom_fid.py ADDED
@@ -0,0 +1,260 @@
1
+ import numpy as np
2
+ import torch.nn.functional as F
3
+ from scipy import linalg
4
+ import torch.nn as nn
5
+ from torchvision import models
6
+
7
+
8
+ class INCEPTION_V3_FID(nn.Module):
9
+ """pretrained InceptionV3 network returning feature maps"""
10
+ # Index of default block of inception to return,
11
+ # corresponds to output of final average pooling
12
+ DEFAULT_BLOCK_INDEX = 3
13
+
14
+ # Maps feature dimensionality to their output blocks indices
15
+ BLOCK_INDEX_BY_DIM = {
16
+ 64: 0, # First max pooling features
17
+ 192: 1, # Second max pooling featurs
18
+ 768: 2, # Pre-aux classifier features
19
+ 2048: 3 # Final average pooling features
20
+ }
21
+
22
+ def __init__(self,
23
+ incep_state_dict,
24
+ output_blocks=[DEFAULT_BLOCK_INDEX],
25
+ resize_input=True):
26
+ """Build pretrained InceptionV3
27
+ Parameters
28
+ ----------
29
+ output_blocks : list of int
30
+ Indices of blocks to return features of. Possible values are:
31
+ - 0: corresponds to output of first max pooling
32
+ - 1: corresponds to output of second max pooling
33
+ - 2: corresponds to output which is fed to aux classifier
34
+ - 3: corresponds to output of final average pooling
35
+ resize_input : bool
36
+ If true, bilinearly resizes input to width and height 299 before
37
+ feeding input to model. As the network without fully connected
38
+ layers is fully convolutional, it should be able to handle inputs
39
+ of arbitrary size, so resizing might not be strictly needed
40
+ normalize_input : bool
41
+ If true, normalizes the input to the statistics the pretrained
42
+ Inception network expects
43
+ """
44
+ super(INCEPTION_V3_FID, self).__init__()
45
+
46
+ self.resize_input = resize_input
47
+ self.output_blocks = sorted(output_blocks)
48
+ self.last_needed_block = max(output_blocks)
49
+
50
+ assert self.last_needed_block <= 3, \
51
+ 'Last possible output block index is 3'
52
+
53
+ self.blocks = nn.ModuleList()
54
+
55
+ inception = models.inception_v3()
56
+ inception.load_state_dict(incep_state_dict)
57
+ for param in inception.parameters():
58
+ param.requires_grad = False
59
+
60
+ # Block 0: input to maxpool1
61
+ block0 = [
62
+ inception.Conv2d_1a_3x3,
63
+ inception.Conv2d_2a_3x3,
64
+ inception.Conv2d_2b_3x3,
65
+ nn.MaxPool2d(kernel_size=3, stride=2)
66
+ ]
67
+ self.blocks.append(nn.Sequential(*block0))
68
+
69
+ # Block 1: maxpool1 to maxpool2
70
+ if self.last_needed_block >= 1:
71
+ block1 = [
72
+ inception.Conv2d_3b_1x1,
73
+ inception.Conv2d_4a_3x3,
74
+ nn.MaxPool2d(kernel_size=3, stride=2)
75
+ ]
76
+ self.blocks.append(nn.Sequential(*block1))
77
+
78
+ # Block 2: maxpool2 to aux classifier
79
+ if self.last_needed_block >= 2:
80
+ block2 = [
81
+ inception.Mixed_5b,
82
+ inception.Mixed_5c,
83
+ inception.Mixed_5d,
84
+ inception.Mixed_6a,
85
+ inception.Mixed_6b,
86
+ inception.Mixed_6c,
87
+ inception.Mixed_6d,
88
+ inception.Mixed_6e,
89
+ ]
90
+ self.blocks.append(nn.Sequential(*block2))
91
+
92
+ # Block 3: aux classifier to final avgpool
93
+ if self.last_needed_block >= 3:
94
+ block3 = [
95
+ inception.Mixed_7a,
96
+ inception.Mixed_7b,
97
+ inception.Mixed_7c,
98
+ nn.AdaptiveAvgPool2d(output_size=(1, 1))
99
+ ]
100
+ self.blocks.append(nn.Sequential(*block3))
101
+
102
+ def forward(self, inp):
103
+ """Get Inception feature maps
104
+ Parameters
105
+ ----------
106
+ inp : torch.autograd.Variable
107
+ Input tensor of shape Bx3xHxW. Values are expected to be in
108
+ range (0, 1)
109
+ Returns
110
+ -------
111
+ List of torch.autograd.Variable, corresponding to the selected output
112
+ block, sorted ascending by index
113
+ """
114
+ outp = []
115
+ x = inp
116
+
117
+ if self.resize_input:
118
+ x = F.interpolate(x, size=(299, 299), mode='bilinear')
119
+
120
+ x = x.clone()
121
+ # [-1.0, 1.0] --> [0, 1.0]
122
+ x = x * 0.5 + 0.5
123
+ x[:, 0] = x[:, 0] * (0.229 / 0.5) + (0.485 - 0.5) / 0.5
124
+ x[:, 1] = x[:, 1] * (0.224 / 0.5) + (0.456 - 0.5) / 0.5
125
+ x[:, 2] = x[:, 2] * (0.225 / 0.5) + (0.406 - 0.5) / 0.5
126
+
127
+ for idx, block in enumerate(self.blocks):
128
+ x = block(x)
129
+ if idx in self.output_blocks:
130
+ outp.append(x)
131
+
132
+ if idx == self.last_needed_block:
133
+ break
134
+
135
+ return outp
136
+
137
+
138
+ def get_activations(images, model, batch_size, verbose=False):
139
+ """Calculates the activations of the pool_3 layer for all images.
140
+ Params:
141
+ -- images : Numpy array of dimension (n_images, 3, hi, wi). The values
142
+ must lie between 0 and 1.
143
+ -- model : Instance of inception model
144
+ -- batch_size : the images numpy array is split into batches with
145
+ batch size batch_size. A reasonable batch size depends
146
+ on the hardware.
147
+ -- verbose : If set to True and parameter out_step is given, the number
148
+ of calculated batches is reported.
149
+ Returns:
150
+ -- A numpy array of dimension (num images, dims) that contains the
151
+ activations of the given tensor when feeding inception with the
152
+ query tensor.
153
+ """
154
+ model.eval()
155
+
156
+ #d0 = images.shape[0]
157
+ d0 = int(images.size(0))
158
+ if batch_size > d0:
159
+ print(('Warning: batch size is bigger than the data size. '
160
+ 'Setting batch size to data size'))
161
+ batch_size = d0
162
+
163
+ n_batches = d0 // batch_size
164
+ n_used_imgs = n_batches * batch_size
165
+
166
+ pred_arr = np.empty((n_used_imgs, 2048))
167
+ for i in range(n_batches):
168
+ if verbose:
169
+ print('\rPropagating batch %d/%d' % (i + 1, n_batches), end='', flush=True)
170
+ start = i * batch_size
171
+ end = start + batch_size
172
+
173
+ '''batch = torch.from_numpy(images[start:end]).type(torch.FloatTensor)
174
+ batch = Variable(batch, volatile=True)
175
+ if cfg.CUDA:
176
+ batch = batch.cuda()'''
177
+ batch = images[start:end]
178
+
179
+ pred = model(batch)[0]
180
+
181
+ # If model output is not scalar, apply global spatial average pooling.
182
+ # This happens if you choose a dimensionality not equal 2048.
183
+ if pred.shape[2] != 1 or pred.shape[3] != 1:
184
+ pred = F.adaptive_avg_pool2d(pred, output_size=(1, 1))
185
+
186
+ pred_arr[start:end] = pred.cpu().data.numpy().reshape(batch_size, -1)
187
+
188
+ if verbose:
189
+ print(' done')
190
+
191
+ return pred_arr
192
+
193
+
194
+ def calculate_activation_statistics(act):
195
+ """Calculation of the statistics used by the FID.
196
+ Params:
197
+ -- act : Numpy array of dimension (n_images, dim (e.g. 2048)).
198
+ Returns:
199
+ -- mu : The mean over samples of the activations of the pool_3 layer of
200
+ the inception model.
201
+ -- sigma : The covariance matrix of the activations of the pool_3 layer of
202
+ the inception model.
203
+ """
204
+ mu = np.mean(act, axis=0)
205
+ sigma = np.cov(act, rowvar=False)
206
+ return mu, sigma
207
+
208
+
209
+ def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
210
+ """Numpy implementation of the Frechet Distance.
211
+ The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1)
212
+ and X_2 ~ N(mu_2, C_2) is
213
+ d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)).
214
+ Stable version by Dougal J. Sutherland.
215
+ Params:
216
+ -- mu1 : Numpy array containing the activations of a layer of the
217
+ inception net (like returned by the function 'get_predictions')
218
+ for generated samples.
219
+ -- mu2 : The sample mean over activations, precalculated on a
221
+ representative data set.
222
+ -- sigma1: The covariance matrix over activations for generated samples.
223
+ -- sigma2: The covariance matrix over activations, precalculated on a
224
+ representative data set.
224
+ Returns:
225
+ -- : The Frechet Distance.
226
+ """
227
+
228
+ mu1 = np.atleast_1d(mu1)
229
+ mu2 = np.atleast_1d(mu2)
230
+
231
+ sigma1 = np.atleast_2d(sigma1)
232
+ sigma2 = np.atleast_2d(sigma2)
233
+
234
+ assert mu1.shape == mu2.shape, \
235
+ 'Training and test mean vectors have different lengths'
236
+ assert sigma1.shape == sigma2.shape, \
237
+ 'Training and test covariances have different dimensions'
238
+
239
+ diff = mu1 - mu2
240
+
241
+ # Product might be almost singular
242
+ covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
243
+ if not np.isfinite(covmean).all():
244
+ msg = ('fid calculation produces singular product; '
245
+ 'adding %s to diagonal of cov estimates') % eps
246
+ print(msg)
247
+ offset = np.eye(sigma1.shape[0]) * eps
248
+ covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))
249
+
250
+ # Numerical error might give slight imaginary component
251
+ if np.iscomplexobj(covmean):
252
+ if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):
253
+ m = np.max(np.abs(covmean.imag))
254
+ raise ValueError('Imaginary component {}'.format(m))
255
+ covmean = covmean.real
256
+
257
+ tr_covmean = np.trace(covmean)
258
+
259
+ return (diff.dot(diff) + np.trace(sigma1) +
260
+ np.trace(sigma2) - 2 * tr_covmean)
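
A sketch of the end-to-end FID computation with the helpers above; the activation arrays here are random stand-ins (real usage would pass pool_3 features of shape (n, 2048) obtained via `get_activations` with an `INCEPTION_V3_FID` instance):

    import numpy as np

    act_fake = np.random.randn(512, 64)   # toy dimensionality for a quick check
    act_real = np.random.randn(512, 64)

    mu1, sigma1 = calculate_activation_statistics(act_fake)
    mu2, sigma2 = calculate_activation_statistics(act_real)
    fid = calculate_frechet_distance(mu1, sigma1, mu2, sigma2)
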
basicsr/metrics/metric_util.py ADDED
@@ -0,0 +1,45 @@
1
+ import numpy as np
2
+
3
+ from basicsr.utils.matlab_functions import bgr2ycbcr
4
+
5
+
6
+ def reorder_image(img, input_order='HWC'):
7
+ """Reorder images to 'HWC' order.
8
+
9
+ If the input_order is (h, w), return (h, w, 1);
10
+ If the input_order is (c, h, w), return (h, w, c);
11
+ If the input_order is (h, w, c), return as it is.
12
+
13
+ Args:
14
+ img (ndarray): Input image.
15
+ input_order (str): Whether the input order is 'HWC' or 'CHW'.
16
+ If the input image shape is (h, w), input_order will not have
17
+ effects. Default: 'HWC'.
18
+
19
+ Returns:
20
+ ndarray: reordered image.
21
+ """
22
+
23
+ if input_order not in ['HWC', 'CHW']:
24
+ raise ValueError(f'Wrong input_order {input_order}. Supported input_orders are ' "'HWC' and 'CHW'")
25
+ if len(img.shape) == 2:
26
+ img = img[..., None]
27
+ if input_order == 'CHW':
28
+ img = img.transpose(1, 2, 0)
29
+ return img
30
+
31
+
32
+ def to_y_channel(img):
33
+ """Change to Y channel of YCbCr.
34
+
35
+ Args:
36
+ img (ndarray): Images with range [0, 255].
37
+
38
+ Returns:
39
+ (ndarray): Images with range [0, 255] (float type) without round.
40
+ """
41
+ img = img.astype(np.float32) / 255.
42
+ if img.ndim == 3 and img.shape[2] == 3:
43
+ img = bgr2ycbcr(img, y_only=True)
44
+ img = img[..., None]
45
+ return img * 255.
basicsr/metrics/psnr_ssim.py ADDED
@@ -0,0 +1,128 @@
1
+ import cv2
2
+ import numpy as np
3
+
4
+ from basicsr.metrics.metric_util import reorder_image, to_y_channel
5
+ from basicsr.utils.registry import METRIC_REGISTRY
6
+
7
+
8
+ @METRIC_REGISTRY.register()
9
+ def calculate_psnr(img, img2, crop_border, input_order='HWC', test_y_channel=False, **kwargs):
10
+ """Calculate PSNR (Peak Signal-to-Noise Ratio).
11
+
12
+ Ref: https://en.wikipedia.org/wiki/Peak_signal-to-noise_ratio
13
+
14
+ Args:
15
+ img (ndarray): Images with range [0, 255].
16
+ img2 (ndarray): Images with range [0, 255].
17
+ crop_border (int): Cropped pixels in each edge of an image. These
18
+ pixels are not involved in the PSNR calculation.
19
+ input_order (str): Whether the input order is 'HWC' or 'CHW'.
20
+ Default: 'HWC'.
21
+ test_y_channel (bool): Test on Y channel of YCbCr. Default: False.
22
+
23
+ Returns:
24
+ float: psnr result.
25
+ """
26
+
27
+ assert img.shape == img2.shape, (f'Image shapes are different: {img.shape}, {img2.shape}.')
28
+ if input_order not in ['HWC', 'CHW']:
29
+ raise ValueError(f'Wrong input_order {input_order}. Supported input_orders are ' '"HWC" and "CHW"')
30
+ img = reorder_image(img, input_order=input_order)
31
+ img2 = reorder_image(img2, input_order=input_order)
32
+ img = img.astype(np.float64)
33
+ img2 = img2.astype(np.float64)
34
+
35
+ if crop_border != 0:
36
+ img = img[crop_border:-crop_border, crop_border:-crop_border, ...]
37
+ img2 = img2[crop_border:-crop_border, crop_border:-crop_border, ...]
38
+
39
+ if test_y_channel:
40
+ img = to_y_channel(img)
41
+ img2 = to_y_channel(img2)
42
+
43
+ mse = np.mean((img - img2)**2)
44
+ if mse == 0:
45
+ return float('inf')
46
+ return 20. * np.log10(255. / np.sqrt(mse))
47
+
48
+
49
+ def _ssim(img, img2):
50
+ """Calculate SSIM (structural similarity) for one channel images.
51
+
52
+ It is called by func:`calculate_ssim`.
53
+
54
+ Args:
55
+ img (ndarray): Images with range [0, 255] with order 'HWC'.
56
+ img2 (ndarray): Images with range [0, 255] with order 'HWC'.
57
+
58
+ Returns:
59
+ float: ssim result.
60
+ """
61
+
62
+ c1 = (0.01 * 255)**2
63
+ c2 = (0.03 * 255)**2
64
+
65
+ img = img.astype(np.float64)
66
+ img2 = img2.astype(np.float64)
67
+ kernel = cv2.getGaussianKernel(11, 1.5)
68
+ window = np.outer(kernel, kernel.transpose())
69
+
70
+ mu1 = cv2.filter2D(img, -1, window)[5:-5, 5:-5]
71
+ mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5]
72
+ mu1_sq = mu1**2
73
+ mu2_sq = mu2**2
74
+ mu1_mu2 = mu1 * mu2
75
+ sigma1_sq = cv2.filter2D(img**2, -1, window)[5:-5, 5:-5] - mu1_sq
76
+ sigma2_sq = cv2.filter2D(img2**2, -1, window)[5:-5, 5:-5] - mu2_sq
77
+ sigma12 = cv2.filter2D(img * img2, -1, window)[5:-5, 5:-5] - mu1_mu2
78
+
79
+ ssim_map = ((2 * mu1_mu2 + c1) * (2 * sigma12 + c2)) / ((mu1_sq + mu2_sq + c1) * (sigma1_sq + sigma2_sq + c2))
80
+ return ssim_map.mean()
81
+
82
+
83
+ @METRIC_REGISTRY.register()
84
+ def calculate_ssim(img, img2, crop_border, input_order='HWC', test_y_channel=False, **kwargs):
85
+ """Calculate SSIM (structural similarity).
86
+
87
+ Ref:
88
+ Image quality assessment: From error visibility to structural similarity
89
+
90
+ The results are the same as that of the official released MATLAB code in
91
+ https://ece.uwaterloo.ca/~z70wang/research/ssim/.
92
+
93
+ For three-channel images, SSIM is calculated for each channel and then
94
+ averaged.
95
+
96
+ Args:
97
+ img (ndarray): Images with range [0, 255].
98
+ img2 (ndarray): Images with range [0, 255].
99
+ crop_border (int): Cropped pixels in each edge of an image. These
100
+ pixels are not involved in the SSIM calculation.
101
+ input_order (str): Whether the input order is 'HWC' or 'CHW'.
102
+ Default: 'HWC'.
103
+ test_y_channel (bool): Test on Y channel of YCbCr. Default: False.
104
+
105
+ Returns:
106
+ float: ssim result.
107
+ """
108
+
109
+ assert img.shape == img2.shape, (f'Image shapes are different: {img.shape}, {img2.shape}.')
110
+ if input_order not in ['HWC', 'CHW']:
111
+ raise ValueError(f'Wrong input_order {input_order}. Supported input_orders are ' '"HWC" and "CHW"')
112
+ img = reorder_image(img, input_order=input_order)
113
+ img2 = reorder_image(img2, input_order=input_order)
114
+ img = img.astype(np.float64)
115
+ img2 = img2.astype(np.float64)
116
+
117
+ if crop_border != 0:
118
+ img = img[crop_border:-crop_border, crop_border:-crop_border, ...]
119
+ img2 = img2[crop_border:-crop_border, crop_border:-crop_border, ...]
120
+
121
+ if test_y_channel:
122
+ img = to_y_channel(img)
123
+ img2 = to_y_channel(img2)
124
+
125
+ ssims = []
126
+ for i in range(img.shape[2]):
127
+ ssims.append(_ssim(img[..., i], img2[..., i]))
128
+ return np.array(ssims).mean()
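Usage sketch for the two registered metrics above, assuming uint8 BGR images in [0, 255] such as those produced by tensor2img; crop_border and test_y_channel follow the docstrings. The perturbed copy is only there to give non-trivial scores.

import numpy as np

from basicsr.metrics.psnr_ssim import calculate_psnr, calculate_ssim

img = np.random.randint(0, 256, (128, 128, 3), dtype=np.uint8)
img2 = np.clip(img.astype(np.int16) + 5, 0, 255).astype(np.uint8)  # slightly perturbed copy

print(calculate_psnr(img, img2, crop_border=4, test_y_channel=True))
print(calculate_ssim(img, img2, crop_border=4, test_y_channel=True))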
basicsr/models/__init__.py ADDED
@@ -0,0 +1,30 @@
1
+ import importlib
2
+ from copy import deepcopy
3
+ from os import path as osp
4
+
5
+ from basicsr.utils import get_root_logger, scandir
6
+ from basicsr.utils.registry import MODEL_REGISTRY
7
+
8
+ __all__ = ['build_model']
9
+
10
+ # automatically scan and import model modules for registry
11
+ # scan all the files under the 'models' folder and collect files ending with
12
+ # '_model.py'
13
+ model_folder = osp.dirname(osp.abspath(__file__))
14
+ model_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(model_folder) if v.endswith('_model.py')]
15
+ # import all the model modules
16
+ _model_modules = [importlib.import_module(f'basicsr.models.{file_name}') for file_name in model_filenames]
17
+
18
+
19
+ def build_model(opt):
20
+ """Build model from options.
21
+
22
+ Args:
23
+ opt (dict): Configuration. It must contain:
24
+ model_type (str): Model type.
25
+ """
26
+ opt = deepcopy(opt)
27
+ model = MODEL_REGISTRY.get(opt['model_type'])(opt)
28
+ logger = get_root_logger()
29
+ logger.info(f'Model [{model.__class__.__name__}] is created.')
30
+ return model
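A small sketch of the registry mechanism build_model() relies on: importing basicsr.models runs the scandir/importlib loop above, which registers every *_model.py class, and the 'model_type' string from the options YAML then indexes that registry (for example the ColorModel added later in this commit).

import basicsr.models  # triggers the automatic scan-and-import above
from basicsr.utils.registry import MODEL_REGISTRY

model_cls = MODEL_REGISTRY.get('ColorModel')  # registered by color_model.py via @MODEL_REGISTRY.register()
print(model_cls.__name__)                     # -> ColorModel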
basicsr/models/base_model.py ADDED
@@ -0,0 +1,382 @@
1
+ import os
2
+ import time
3
+ import torch
4
+ from collections import OrderedDict
5
+ from copy import deepcopy
6
+ from torch.nn.parallel import DataParallel, DistributedDataParallel
7
+
8
+ from basicsr.models import lr_scheduler as lr_scheduler
9
+ from basicsr.utils import get_root_logger
10
+ from basicsr.utils.dist_util import master_only
11
+
12
+
13
+ class BaseModel():
14
+ """Base model."""
15
+
16
+ def __init__(self, opt):
17
+ self.opt = opt
18
+ self.device = torch.device('cuda' if opt['num_gpu'] != 0 else 'cpu')
19
+ self.is_train = opt['is_train']
20
+ self.schedulers = []
21
+ self.optimizers = []
22
+
23
+ def feed_data(self, data):
24
+ pass
25
+
26
+ def optimize_parameters(self):
27
+ pass
28
+
29
+ def get_current_visuals(self):
30
+ pass
31
+
32
+ def save(self, epoch, current_iter):
33
+ """Save networks and training state."""
34
+ pass
35
+
36
+ def validation(self, dataloader, current_iter, tb_logger, save_img=False):
37
+ """Validation function.
38
+
39
+ Args:
40
+ dataloader (torch.utils.data.DataLoader): Validation dataloader.
41
+ current_iter (int): Current iteration.
42
+ tb_logger (tensorboard logger): Tensorboard logger.
43
+ save_img (bool): Whether to save images. Default: False.
44
+ """
45
+ if self.opt['dist']:
46
+ self.dist_validation(dataloader, current_iter, tb_logger, save_img)
47
+ else:
48
+ self.nondist_validation(dataloader, current_iter, tb_logger, save_img)
49
+
50
+ def _initialize_best_metric_results(self, dataset_name):
51
+ """Initialize the best metric results dict for recording the best metric value and iteration."""
52
+ if hasattr(self, 'best_metric_results') and dataset_name in self.best_metric_results:
53
+ return
54
+ elif not hasattr(self, 'best_metric_results'):
55
+ self.best_metric_results = dict()
56
+
57
+ # add a dataset record
58
+ record = dict()
59
+ for metric, content in self.opt['val']['metrics'].items():
60
+ better = content.get('better', 'higher')
61
+ init_val = float('-inf') if better == 'higher' else float('inf')
62
+ record[metric] = dict(better=better, val=init_val, iter=-1)
63
+ self.best_metric_results[dataset_name] = record
64
+
65
+ def _update_best_metric_result(self, dataset_name, metric, val, current_iter):
66
+ if self.best_metric_results[dataset_name][metric]['better'] == 'higher':
67
+ if val >= self.best_metric_results[dataset_name][metric]['val']:
68
+ self.best_metric_results[dataset_name][metric]['val'] = val
69
+ self.best_metric_results[dataset_name][metric]['iter'] = current_iter
70
+ else:
71
+ if val <= self.best_metric_results[dataset_name][metric]['val']:
72
+ self.best_metric_results[dataset_name][metric]['val'] = val
73
+ self.best_metric_results[dataset_name][metric]['iter'] = current_iter
74
+
75
+ def model_ema(self, decay=0.999):
76
+ net_g = self.get_bare_model(self.net_g)
77
+
78
+ net_g_params = dict(net_g.named_parameters())
79
+ net_g_ema_params = dict(self.net_g_ema.named_parameters())
80
+
81
+ for k in net_g_ema_params.keys():
82
+ net_g_ema_params[k].data.mul_(decay).add_(net_g_params[k].data, alpha=1 - decay)
83
+
84
+ def get_current_log(self):
85
+ return self.log_dict
86
+
87
+ def model_to_device(self, net):
88
+ """Model to device. It also warps models with DistributedDataParallel
89
+ or DataParallel.
90
+
91
+ Args:
92
+ net (nn.Module)
93
+ """
94
+ net = net.to(self.device)
95
+ if self.opt['dist']:
96
+ find_unused_parameters = self.opt.get('find_unused_parameters', False)
97
+ net = DistributedDataParallel(
98
+ net, device_ids=[torch.cuda.current_device()], find_unused_parameters=find_unused_parameters)
99
+ elif self.opt['num_gpu'] > 1:
100
+ net = DataParallel(net)
101
+ return net
102
+
103
+ def get_optimizer(self, optim_type, params, lr, **kwargs):
104
+ if optim_type == 'Adam':
105
+ optimizer = torch.optim.Adam(params, lr, **kwargs)
106
+ elif optim_type == 'AdamW':
107
+ optimizer = torch.optim.AdamW(params, lr, **kwargs)
108
+ else:
109
+ raise NotImplementedError(f'optimizer {optim_type} is not supported yet.')
110
+ return optimizer
111
+
112
+ def setup_schedulers(self):
113
+ """Set up schedulers."""
114
+ train_opt = self.opt['train']
115
+ scheduler_type = train_opt['scheduler'].pop('type')
116
+ if scheduler_type in ['MultiStepLR', 'MultiStepRestartLR']:
117
+ for optimizer in self.optimizers:
118
+ self.schedulers.append(lr_scheduler.MultiStepRestartLR(optimizer, **train_opt['scheduler']))
119
+ elif scheduler_type == 'CosineAnnealingRestartLR':
120
+ for optimizer in self.optimizers:
121
+ self.schedulers.append(lr_scheduler.CosineAnnealingRestartLR(optimizer, **train_opt['scheduler']))
122
+ else:
123
+ raise NotImplementedError(f'Scheduler {scheduler_type} is not implemented yet.')
124
+
125
+ def get_bare_model(self, net):
126
+ """Get bare model, especially under wrapping with
127
+ DistributedDataParallel or DataParallel.
128
+ """
129
+ if isinstance(net, (DataParallel, DistributedDataParallel)):
130
+ net = net.module
131
+ return net
132
+
133
+ @master_only
134
+ def print_network(self, net):
135
+ """Print the str and parameter number of a network.
136
+
137
+ Args:
138
+ net (nn.Module)
139
+ """
140
+ if isinstance(net, (DataParallel, DistributedDataParallel)):
141
+ net_cls_str = f'{net.__class__.__name__} - {net.module.__class__.__name__}'
142
+ else:
143
+ net_cls_str = f'{net.__class__.__name__}'
144
+
145
+ net = self.get_bare_model(net)
146
+ net_str = str(net)
147
+ net_params = sum(map(lambda x: x.numel(), net.parameters()))
148
+
149
+ logger = get_root_logger()
150
+ logger.info(f'Network: {net_cls_str}, with parameters: {net_params:,d}')
151
+ logger.info(net_str)
152
+
153
+ def _set_lr(self, lr_groups_l):
154
+ """Set learning rate for warmup.
155
+
156
+ Args:
157
+ lr_groups_l (list): List for lr_groups, each for an optimizer.
158
+ """
159
+ for optimizer, lr_groups in zip(self.optimizers, lr_groups_l):
160
+ for param_group, lr in zip(optimizer.param_groups, lr_groups):
161
+ param_group['lr'] = lr
162
+
163
+ def _get_init_lr(self):
164
+ """Get the initial lr, which is set by the scheduler.
165
+ """
166
+ init_lr_groups_l = []
167
+ for optimizer in self.optimizers:
168
+ init_lr_groups_l.append([v['initial_lr'] for v in optimizer.param_groups])
169
+ return init_lr_groups_l
170
+
171
+ def update_learning_rate(self, current_iter, warmup_iter=-1):
172
+ """Update learning rate.
173
+
174
+ Args:
175
+ current_iter (int): Current iteration.
176
+ warmup_iter (int): Warmup iter numbers. -1 for no warmup.
177
+ Default: -1.
178
+ """
179
+ if current_iter > 1:
180
+ for scheduler in self.schedulers:
181
+ scheduler.step()
182
+ # set up warm-up learning rate
183
+ if current_iter < warmup_iter:
184
+ # get initial lr for each group
185
+ init_lr_g_l = self._get_init_lr()
186
+ # modify warming-up learning rates
187
+ # currently only support linearly warm up
188
+ warm_up_lr_l = []
189
+ for init_lr_g in init_lr_g_l:
190
+ warm_up_lr_l.append([v / warmup_iter * current_iter for v in init_lr_g])
191
+ # set learning rate
192
+ self._set_lr(warm_up_lr_l)
193
+
194
+ def get_current_learning_rate(self):
195
+ return [param_group['lr'] for param_group in self.optimizers[0].param_groups]
196
+
197
+ @master_only
198
+ def save_network(self, net, net_label, current_iter, param_key='params'):
199
+ """Save networks.
200
+
201
+ Args:
202
+ net (nn.Module | list[nn.Module]): Network(s) to be saved.
203
+ net_label (str): Network label.
204
+ current_iter (int): Current iter number.
205
+ param_key (str | list[str]): The parameter key(s) to save network.
206
+ Default: 'params'.
207
+ """
208
+ if current_iter == -1:
209
+ current_iter = 'latest'
210
+ save_filename = f'{net_label}_{current_iter}.pth'
211
+ save_path = os.path.join(self.opt['path']['models'], save_filename)
212
+
213
+ net = net if isinstance(net, list) else [net]
214
+ param_key = param_key if isinstance(param_key, list) else [param_key]
215
+ assert len(net) == len(param_key), 'The lengths of net and param_key should be the same.'
216
+
217
+ save_dict = {}
218
+ for net_, param_key_ in zip(net, param_key):
219
+ net_ = self.get_bare_model(net_)
220
+ state_dict = net_.state_dict()
221
+ for key, param in state_dict.items():
222
+ if key.startswith('module.'): # remove unnecessary 'module.'
223
+ key = key[7:]
224
+ state_dict[key] = param.cpu()
225
+ save_dict[param_key_] = state_dict
226
+
227
+ # avoid occasional writing errors
228
+ retry = 3
229
+ while retry > 0:
230
+ try:
231
+ torch.save(save_dict, save_path)
232
+ except Exception as e:
233
+ logger = get_root_logger()
234
+ logger.warning(f'Save model error: {e}, remaining retry times: {retry - 1}')
235
+ time.sleep(1)
236
+ else:
237
+ break
238
+ finally:
239
+ retry -= 1
240
+ if retry == 0:
241
+ logger.warning(f'Still cannot save {save_path}. Just ignore it.')
242
+ # raise IOError(f'Cannot save {save_path}.')
243
+
244
+ def _print_different_keys_loading(self, crt_net, load_net, strict=True):
245
+ """Print keys with different name or different size when loading models.
246
+
247
+ 1. Print keys with different names.
248
+ 2. If strict=False, print the same key but with different tensor size.
249
+ It also ignores keys with different sizes (they are not loaded).
250
+
251
+ Args:
252
+ crt_net (torch model): Current network.
253
+ load_net (dict): Loaded network.
254
+ strict (bool): Whether strictly loaded. Default: True.
255
+ """
256
+ crt_net = self.get_bare_model(crt_net)
257
+ crt_net = crt_net.state_dict()
258
+ crt_net_keys = set(crt_net.keys())
259
+ load_net_keys = set(load_net.keys())
260
+
261
+ logger = get_root_logger()
262
+ if crt_net_keys != load_net_keys:
263
+ logger.warning('Current net - loaded net:')
264
+ for v in sorted(list(crt_net_keys - load_net_keys)):
265
+ logger.warning(f' {v}')
266
+ logger.warning('Loaded net - current net:')
267
+ for v in sorted(list(load_net_keys - crt_net_keys)):
268
+ logger.warning(f' {v}')
269
+
270
+ # check the size for the same keys
271
+ if not strict:
272
+ common_keys = crt_net_keys & load_net_keys
273
+ for k in common_keys:
274
+ if crt_net[k].size() != load_net[k].size():
275
+ logger.warning(f'Size different, ignore [{k}]: crt_net: '
276
+ f'{crt_net[k].shape}; load_net: {load_net[k].shape}')
277
+ load_net[k + '.ignore'] = load_net.pop(k)
278
+
279
+ def load_network(self, net, load_path, strict=True, param_key='params'):
280
+ """Load network.
281
+
282
+ Args:
283
+ load_path (str): The path of networks to be loaded.
284
+ net (nn.Module): Network.
285
+ strict (bool): Whether strictly loaded.
286
+ param_key (str): The parameter key of loaded network. If set to
287
+ None, use the root 'path'.
288
+ Default: 'params'.
289
+ """
290
+ logger = get_root_logger()
291
+ net = self.get_bare_model(net)
292
+ load_net = torch.load(load_path, map_location=lambda storage, loc: storage)
293
+ if param_key is not None:
294
+ if param_key not in load_net and 'params' in load_net:
295
+ param_key = 'params'
296
+ logger.info('Loading: params_ema does not exist, use params.')
297
+ load_net = load_net[param_key]
298
+ logger.info(f'Loading {net.__class__.__name__} model from {load_path}, with param key: [{param_key}].')
299
+ # remove unnecessary 'module.'
300
+ for k, v in deepcopy(load_net).items():
301
+ if k.startswith('module.'):
302
+ load_net[k[7:]] = v
303
+ load_net.pop(k)
304
+ self._print_different_keys_loading(net, load_net, strict)
305
+ net.load_state_dict(load_net, strict=strict)
306
+
307
+ @master_only
308
+ def save_training_state(self, epoch, current_iter):
309
+ """Save training states during training, which will be used for
310
+ resuming.
311
+
312
+ Args:
313
+ epoch (int): Current epoch.
314
+ current_iter (int): Current iteration.
315
+ """
316
+ if current_iter != -1:
317
+ state = {'epoch': epoch, 'iter': current_iter, 'optimizers': [], 'schedulers': []}
318
+ for o in self.optimizers:
319
+ state['optimizers'].append(o.state_dict())
320
+ for s in self.schedulers:
321
+ state['schedulers'].append(s.state_dict())
322
+ save_filename = f'{current_iter}.state'
323
+ save_path = os.path.join(self.opt['path']['training_states'], save_filename)
324
+
325
+ # avoid occasional writing errors
326
+ retry = 3
327
+ while retry > 0:
328
+ try:
329
+ torch.save(state, save_path)
330
+ except Exception as e:
331
+ logger = get_root_logger()
332
+ logger.warning(f'Save training state error: {e}, remaining retry times: {retry - 1}')
333
+ time.sleep(1)
334
+ else:
335
+ break
336
+ finally:
337
+ retry -= 1
338
+ if retry == 0:
339
+ logger.warning(f'Still cannot save {save_path}. Just ignore it.')
340
+ # raise IOError(f'Cannot save {save_path}.')
341
+
342
+ def resume_training(self, resume_state):
343
+ """Reload the optimizers and schedulers for resumed training.
344
+
345
+ Args:
346
+ resume_state (dict): Resume state.
347
+ """
348
+ resume_optimizers = resume_state['optimizers']
349
+ resume_schedulers = resume_state['schedulers']
350
+ assert len(resume_optimizers) == len(self.optimizers), 'Wrong lengths of optimizers'
351
+ assert len(resume_schedulers) == len(self.schedulers), 'Wrong lengths of schedulers'
352
+ for i, o in enumerate(resume_optimizers):
353
+ self.optimizers[i].load_state_dict(o)
354
+ for i, s in enumerate(resume_schedulers):
355
+ self.schedulers[i].load_state_dict(s)
356
+
357
+ def reduce_loss_dict(self, loss_dict):
358
+ """reduce loss dict.
359
+
360
+ In distributed training, it averages the losses among different GPUs.
361
+
362
+ Args:
363
+ loss_dict (OrderedDict): Loss dict.
364
+ """
365
+ with torch.no_grad():
366
+ if self.opt['dist']:
367
+ keys = []
368
+ losses = []
369
+ for name, value in loss_dict.items():
370
+ keys.append(name)
371
+ losses.append(value)
372
+ losses = torch.stack(losses, 0)
373
+ torch.distributed.reduce(losses, dst=0)
374
+ if self.opt['rank'] == 0:
375
+ losses /= self.opt['world_size']
376
+ loss_dict = {key: loss for key, loss in zip(keys, losses)}
377
+
378
+ log_dict = OrderedDict()
379
+ for name, value in loss_dict.items():
380
+ log_dict[name] = value.mean().item()
381
+
382
+ return log_dict
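A standalone illustration of the linear warm-up rule in update_learning_rate() above: while current_iter < warmup_iter, each group's lr is the scheduler's initial lr scaled by current_iter / warmup_iter; afterwards the scheduler value is used unchanged. The numbers below are illustrative, not the repo's settings.

def warmup_lr(init_lr, current_iter, warmup_iter):
    # mirrors: v / warmup_iter * current_iter
    return init_lr / warmup_iter * current_iter

for it in (1, 250, 500, 999):
    print(it, warmup_lr(1e-4, it, warmup_iter=1000))
# 1 -> 1e-07, 250 -> 2.5e-05, 500 -> 5e-05, 999 -> 9.99e-05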
basicsr/models/color_model.py ADDED
@@ -0,0 +1,369 @@
1
+ import os
2
+ import torch
3
+ from collections import OrderedDict
4
+ from os import path as osp
5
+ from tqdm import tqdm
6
+ import numpy as np
7
+
8
+ from basicsr.archs import build_network
9
+ from basicsr.losses import build_loss
10
+ from basicsr.metrics import calculate_metric
11
+ from basicsr.utils import get_root_logger, imwrite, tensor2img
12
+ from basicsr.utils.img_util import tensor_lab2rgb
13
+ from basicsr.utils.dist_util import master_only
14
+ from basicsr.utils.registry import MODEL_REGISTRY
15
+ from .base_model import BaseModel
16
+ from basicsr.metrics.custom_fid import INCEPTION_V3_FID, get_activations, calculate_activation_statistics, calculate_frechet_distance
17
+ from basicsr.utils.color_enhance import color_enhacne_blend
18
+
19
+
20
+ @MODEL_REGISTRY.register()
21
+ class ColorModel(BaseModel):
22
+ """Colorization model for single image colorization."""
23
+
24
+ def __init__(self, opt):
25
+ super(ColorModel, self).__init__(opt)
26
+
27
+ # define network net_g
28
+ self.net_g = build_network(opt['network_g'])
29
+ self.net_g = self.model_to_device(self.net_g)
30
+ self.print_network(self.net_g)
31
+
32
+ # load pretrained model for net_g
33
+ load_path = self.opt['path'].get('pretrain_network_g', None)
34
+ if load_path is not None:
35
+ param_key = self.opt['path'].get('param_key_g', 'params')
36
+ self.load_network(self.net_g, load_path, self.opt['path'].get('strict_load_g', True), param_key)
37
+
38
+ if self.is_train:
39
+ self.init_training_settings()
40
+
41
+ def init_training_settings(self):
42
+ train_opt = self.opt['train']
43
+
44
+ self.ema_decay = train_opt.get('ema_decay', 0)
45
+ if self.ema_decay > 0:
46
+ logger = get_root_logger()
47
+ logger.info(f'Use Exponential Moving Average with decay: {self.ema_decay}')
48
+ # define network net_g with Exponential Moving Average (EMA)
49
+ # net_g_ema is used only for testing on one GPU and saving
50
+ # There is no need to wrap with DistributedDataParallel
51
+ self.net_g_ema = build_network(self.opt['network_g']).to(self.device)
52
+ # load pretrained model
53
+ load_path = self.opt['path'].get('pretrain_network_g', None)
54
+ if load_path is not None:
55
+ self.load_network(self.net_g_ema, load_path, self.opt['path'].get('strict_load_g', True), 'params_ema')
56
+ else:
57
+ self.model_ema(0) # copy net_g weight
58
+ self.net_g_ema.eval()
59
+
60
+ # define network net_d
61
+ self.net_d = build_network(self.opt['network_d'])
62
+ self.net_d = self.model_to_device(self.net_d)
63
+ self.print_network(self.net_d)
64
+
65
+ # load pretrained model for net_d
66
+ load_path = self.opt['path'].get('pretrain_network_d', None)
67
+ if load_path is not None:
68
+ param_key = self.opt['path'].get('param_key_d', 'params')
69
+ self.load_network(self.net_d, load_path, self.opt['path'].get('strict_load_d', True), param_key)
70
+
71
+ self.net_g.train()
72
+ self.net_d.train()
73
+
74
+ # define losses
75
+ if train_opt.get('pixel_opt'):
76
+ self.cri_pix = build_loss(train_opt['pixel_opt']).to(self.device)
77
+ else:
78
+ self.cri_pix = None
79
+
80
+ if train_opt.get('perceptual_opt'):
81
+ self.cri_perceptual = build_loss(train_opt['perceptual_opt']).to(self.device)
82
+ else:
83
+ self.cri_perceptual = None
84
+
85
+ if train_opt.get('gan_opt'):
86
+ self.cri_gan = build_loss(train_opt['gan_opt']).to(self.device)
87
+ else:
88
+ self.cri_gan = None
89
+
90
+ if self.cri_pix is None and self.cri_perceptual is None:
91
+ raise ValueError('Both pixel and perceptual losses are None.')
92
+
93
+ if train_opt.get('colorfulness_opt'):
94
+ self.cri_colorfulness = build_loss(train_opt['colorfulness_opt']).to(self.device)
95
+ else:
96
+ self.cri_colorfulness = None
97
+
98
+ # set up optimizers and schedulers
99
+ self.setup_optimizers()
100
+ self.setup_schedulers()
101
+
102
+ # set real dataset cache for fid metric computing
103
+ self.real_mu, self.real_sigma = None, None
104
+ if self.opt['val'].get('metrics') is not None and self.opt['val']['metrics'].get('fid') is not None:
105
+ self._prepare_inception_model_fid()
106
+
107
+ def setup_optimizers(self):
108
+ train_opt = self.opt['train']
109
+ # optim_params_g = []
110
+ # for k, v in self.net_g.named_parameters():
111
+ # if v.requires_grad:
112
+ # optim_params_g.append(v)
113
+ # else:
114
+ # logger = get_root_logger()
115
+ # logger.warning(f'Params {k} will not be optimized.')
116
+ optim_params_g = self.net_g.parameters()
117
+
118
+ # optimizer g
119
+ optim_type = train_opt['optim_g'].pop('type')
120
+ self.optimizer_g = self.get_optimizer(optim_type, optim_params_g, **train_opt['optim_g'])
121
+ self.optimizers.append(self.optimizer_g)
122
+
123
+ # optimizer d
124
+ optim_type = train_opt['optim_d'].pop('type')
125
+ self.optimizer_d = self.get_optimizer(optim_type, self.net_d.parameters(), **train_opt['optim_d'])
126
+ self.optimizers.append(self.optimizer_d)
127
+
128
+ def feed_data(self, data):
129
+ self.lq = data['lq'].to(self.device)
130
+ self.lq_rgb = tensor_lab2rgb(torch.cat([self.lq, torch.zeros_like(self.lq), torch.zeros_like(self.lq)], dim=1))
131
+ if 'gt' in data:
132
+ self.gt = data['gt'].to(self.device)
133
+ self.gt_lab = torch.cat([self.lq, self.gt], dim=1)
134
+ self.gt_rgb = tensor_lab2rgb(self.gt_lab)
135
+
136
+ if self.opt['train'].get('color_enhance', False):
137
+ for i in range(self.gt_rgb.shape[0]):
138
+ self.gt_rgb[i] = color_enhacne_blend(self.gt_rgb[i], factor=self.opt['train'].get('color_enhance_factor'))
139
+
140
+ def optimize_parameters(self, current_iter):
141
+ # optimize net_g
142
+ for p in self.net_d.parameters():
143
+ p.requires_grad = False
144
+ self.optimizer_g.zero_grad()
145
+
146
+ self.output_ab = self.net_g(self.lq_rgb)
147
+ self.output_lab = torch.cat([self.lq, self.output_ab], dim=1)
148
+ self.output_rgb = tensor_lab2rgb(self.output_lab)
149
+
150
+ l_g_total = 0
151
+ loss_dict = OrderedDict()
152
+ # pixel loss
153
+ if self.cri_pix:
154
+ l_g_pix = self.cri_pix(self.output_ab, self.gt)
155
+ l_g_total += l_g_pix
156
+ loss_dict['l_g_pix'] = l_g_pix
157
+
158
+ # perceptual loss
159
+ if self.cri_perceptual:
160
+ l_g_percep, l_g_style = self.cri_perceptual(self.output_rgb, self.gt_rgb)
161
+ if l_g_percep is not None:
162
+ l_g_total += l_g_percep
163
+ loss_dict['l_g_percep'] = l_g_percep
164
+ if l_g_style is not None:
165
+ l_g_total += l_g_style
166
+ loss_dict['l_g_style'] = l_g_style
167
+ # gan loss
168
+ if self.cri_gan:
169
+ fake_g_pred = self.net_d(self.output_rgb)
170
+ l_g_gan = self.cri_gan(fake_g_pred, target_is_real=True, is_disc=False)
171
+ l_g_total += l_g_gan
172
+ loss_dict['l_g_gan'] = l_g_gan
173
+ # colorfulness loss
174
+ if self.cri_colorfulness:
175
+ l_g_color = self.cri_colorfulness(self.output_rgb)
176
+ l_g_total += l_g_color
177
+ loss_dict['l_g_color'] = l_g_color
178
+
179
+ l_g_total.backward()
180
+ self.optimizer_g.step()
181
+
182
+ # optimize net_d
183
+ for p in self.net_d.parameters():
184
+ p.requires_grad = True
185
+ self.optimizer_d.zero_grad()
186
+
187
+ real_d_pred = self.net_d(self.gt_rgb)
188
+ fake_d_pred = self.net_d(self.output_rgb.detach())
189
+ l_d = self.cri_gan(real_d_pred, target_is_real=True, is_disc=True) + self.cri_gan(fake_d_pred, target_is_real=False, is_disc=True)
190
+ loss_dict['l_d'] = l_d
191
+ loss_dict['real_score'] = real_d_pred.detach().mean()
192
+ loss_dict['fake_score'] = fake_d_pred.detach().mean()
193
+
194
+ l_d.backward()
195
+ self.optimizer_d.step()
196
+
197
+ self.log_dict = self.reduce_loss_dict(loss_dict)
198
+
199
+ if self.ema_decay > 0:
200
+ self.model_ema(decay=self.ema_decay)
201
+
202
+ def get_current_visuals(self):
203
+ out_dict = OrderedDict()
204
+ out_dict['lq'] = self.lq_rgb.detach().cpu()
205
+ out_dict['result'] = self.output_rgb.detach().cpu()
206
+ if self.opt['logger'].get('save_snapshot_verbose', False): # only for verbose
207
+ self.output_lab_chroma = torch.cat([torch.ones_like(self.lq) * 50, self.output_ab], dim=1)
208
+ self.output_rgb_chroma = tensor_lab2rgb(self.output_lab_chroma)
209
+ out_dict['result_chroma'] = self.output_rgb_chroma.detach().cpu()
210
+
211
+ if hasattr(self, 'gt'):
212
+ out_dict['gt'] = self.gt_rgb.detach().cpu()
213
+ if self.opt['logger'].get('save_snapshot_verbose', False): # only for verbose
214
+ self.gt_lab_chroma = torch.cat([torch.ones_like(self.lq) * 50, self.gt], dim=1)
215
+ self.gt_rgb_chroma = tensor_lab2rgb(self.gt_lab_chroma)
216
+ out_dict['gt_chroma'] = self.gt_rgb_chroma.detach().cpu()
217
+ return out_dict
218
+
219
+ def test(self):
220
+ if hasattr(self, 'net_g_ema'):
221
+ self.net_g_ema.eval()
222
+ with torch.no_grad():
223
+ self.output_ab = self.net_g_ema(self.lq_rgb)
224
+ self.output_lab = torch.cat([self.lq, self.output_ab], dim=1)
225
+ self.output_rgb = tensor_lab2rgb(self.output_lab)
226
+ else:
227
+ self.net_g.eval()
228
+ with torch.no_grad():
229
+ self.output_ab = self.net_g(self.lq_rgb)
230
+ self.output_lab = torch.cat([self.lq, self.output_ab], dim=1)
231
+ self.output_rgb = tensor_lab2rgb(self.output_lab)
232
+ self.net_g.train()
233
+
234
+ def dist_validation(self, dataloader, current_iter, tb_logger, save_img):
235
+ if self.opt['rank'] == 0:
236
+ self.nondist_validation(dataloader, current_iter, tb_logger, save_img)
237
+
238
+ def nondist_validation(self, dataloader, current_iter, tb_logger, save_img):
239
+ dataset_name = dataloader.dataset.opt['name']
240
+ with_metrics = self.opt['val'].get('metrics') is not None
241
+ use_pbar = self.opt['val'].get('pbar', False)
242
+
243
+ if with_metrics and not hasattr(self, 'metric_results'): # only execute in the first run
244
+ self.metric_results = {metric: 0 for metric in self.opt['val']['metrics'].keys()}
245
+ # initialize the best metric results for each dataset_name (supporting multiple validation datasets)
246
+ if with_metrics:
247
+ self._initialize_best_metric_results(dataset_name)
248
+ # zero self.metric_results
249
+ if with_metrics:
250
+ self.metric_results = {metric: 0 for metric in self.metric_results}
251
+
252
+ metric_data = dict()
253
+ if use_pbar:
254
+ pbar = tqdm(total=len(dataloader), unit='image')
255
+
256
+ if self.opt['val']['metrics'].get('fid') is not None:
257
+ fake_acts_set, acts_set = [], []
258
+
259
+ for idx, val_data in enumerate(dataloader):
260
+ # if idx == 100:
261
+ # break
262
+ img_name = osp.splitext(osp.basename(val_data['lq_path'][0]))[0]
263
+ if hasattr(self, 'gt'):
264
+ del self.gt
265
+ self.feed_data(val_data)
266
+ self.test()
267
+
268
+ visuals = self.get_current_visuals()
269
+ sr_img = tensor2img([visuals['result']])
270
+ metric_data['img'] = sr_img
271
+ if 'gt' in visuals:
272
+ gt_img = tensor2img([visuals['gt']])
273
+ metric_data['img2'] = gt_img
274
+
275
+ torch.cuda.empty_cache()
276
+
277
+ if save_img:
278
+ if self.opt['is_train']:
279
+ save_dir = osp.join(self.opt['path']['visualization'], img_name)
280
+ for key in visuals:
281
+ save_path = os.path.join(save_dir, '{}_{}.png'.format(current_iter, key))
282
+ img = tensor2img(visuals[key])
283
+ imwrite(img, save_path)
284
+ else:
285
+ if self.opt['val']['suffix']:
286
+ save_img_path = osp.join(self.opt['path']['visualization'], dataset_name,
287
+ f'{img_name}_{self.opt["val"]["suffix"]}.png')
288
+ else:
289
+ save_img_path = osp.join(self.opt['path']['visualization'], dataset_name,
290
+ f'{img_name}_{self.opt["name"]}.png')
291
+ imwrite(sr_img, save_img_path)
292
+
293
+ if with_metrics:
294
+ # calculate metrics
295
+ for name, opt_ in self.opt['val']['metrics'].items():
296
+ if name == 'fid':
297
+ pred, gt = visuals['result'].cuda(), visuals['gt'].cuda()
298
+ fake_act = get_activations(pred, self.inception_model_fid, 1)
299
+ fake_acts_set.append(fake_act)
300
+ if self.real_mu is None:
301
+ real_act = get_activations(gt, self.inception_model_fid, 1)
302
+ acts_set.append(real_act)
303
+ else:
304
+ self.metric_results[name] += calculate_metric(metric_data, opt_)
305
+ if use_pbar:
306
+ pbar.update(1)
307
+ pbar.set_description(f'Test {img_name}')
308
+ if use_pbar:
309
+ pbar.close()
310
+
311
+ if with_metrics:
312
+ if self.opt['val']['metrics'].get('fid') is not None:
313
+ if self.real_mu is None:
314
+ acts_set = np.concatenate(acts_set, 0)
315
+ self.real_mu, self.real_sigma = calculate_activation_statistics(acts_set)
316
+ fake_acts_set = np.concatenate(fake_acts_set, 0)
317
+ fake_mu, fake_sigma = calculate_activation_statistics(fake_acts_set)
318
+
319
+ fid_score = calculate_frechet_distance(self.real_mu, self.real_sigma, fake_mu, fake_sigma)
320
+ self.metric_results['fid'] = fid_score
321
+
322
+ for metric in self.metric_results.keys():
323
+ if metric != 'fid':
324
+ self.metric_results[metric] /= (idx + 1)
325
+ # update the best metric result
326
+ self._update_best_metric_result(dataset_name, metric, self.metric_results[metric], current_iter)
327
+
328
+ self._log_validation_metric_values(current_iter, dataset_name, tb_logger)
329
+
330
+ def _log_validation_metric_values(self, current_iter, dataset_name, tb_logger):
331
+ log_str = f'Validation {dataset_name}\n'
332
+ for metric, value in self.metric_results.items():
333
+ log_str += f'\t # {metric}: {value:.4f}'
334
+ if hasattr(self, 'best_metric_results'):
335
+ log_str += (f'\tBest: {self.best_metric_results[dataset_name][metric]["val"]:.4f} @ '
336
+ f'{self.best_metric_results[dataset_name][metric]["iter"]} iter')
337
+ log_str += '\n'
338
+
339
+ logger = get_root_logger()
340
+ logger.info(log_str)
341
+ if tb_logger:
342
+ for metric, value in self.metric_results.items():
343
+ tb_logger.add_scalar(f'metrics/{dataset_name}/{metric}', value, current_iter)
344
+
345
+ def _prepare_inception_model_fid(self, path='pretrain/inception_v3_google-1a9a5a14.pth'):
346
+ incep_state_dict = torch.load(path, map_location='cpu')
347
+ block_idx = INCEPTION_V3_FID.BLOCK_INDEX_BY_DIM[2048]
348
+ self.inception_model_fid = INCEPTION_V3_FID(incep_state_dict, [block_idx])
349
+ self.inception_model_fid.cuda()
350
+ self.inception_model_fid.eval()
351
+
352
+ @master_only
353
+ def save_training_images(self, current_iter):
354
+ visuals = self.get_current_visuals()
355
+ save_dir = osp.join(self.opt['root_path'], 'experiments', self.opt['name'], 'training_images_snapshot')
356
+ os.makedirs(save_dir, exist_ok=True)
357
+
358
+ for key in visuals:
359
+ save_path = os.path.join(save_dir, '{}_{}.png'.format(current_iter, key))
360
+ img = tensor2img(visuals[key])
361
+ imwrite(img, save_path)
362
+
363
+ def save(self, epoch, current_iter):
364
+ if hasattr(self, 'net_g_ema'):
365
+ self.save_network([self.net_g, self.net_g_ema], 'net_g', current_iter, param_key=['params', 'params_ema'])
366
+ else:
367
+ self.save_network(self.net_g, 'net_g', current_iter)
368
+ self.save_network(self.net_d, 'net_d', current_iter)
369
+ self.save_training_state(epoch, current_iter)
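A minimal sketch of the tensor plumbing in ColorModel.feed_data()/test(): the generator sees an RGB rendering of the L channel alone and predicts the two ab chroma channels, which are concatenated back onto L and converted to RGB. Shapes and the L range here are illustrative assumptions (the chroma snapshots above use L = 50 as mid-gray); tensor_lab2rgb is the helper imported from basicsr.utils.img_util.

import torch

from basicsr.utils.img_util import tensor_lab2rgb

lq = torch.rand(1, 1, 256, 256) * 100.0   # stand-in L channel
zeros = torch.zeros_like(lq)
lq_rgb = tensor_lab2rgb(torch.cat([lq, zeros, zeros], dim=1))   # grayscale input fed to net_g

output_ab = torch.zeros(1, 2, 256, 256)   # stand-in for net_g(lq_rgb)
output_rgb = tensor_lab2rgb(torch.cat([lq, output_ab], dim=1))  # re-attach chroma, back to RGB
print(lq_rgb.shape, output_rgb.shape)     # both torch.Size([1, 3, 256, 256])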
basicsr/models/lr_scheduler.py ADDED
@@ -0,0 +1,96 @@
1
+ import math
2
+ from collections import Counter
3
+ from torch.optim.lr_scheduler import _LRScheduler
4
+
5
+
6
+ class MultiStepRestartLR(_LRScheduler):
7
+ """ MultiStep with restarts learning rate scheme.
8
+
9
+ Args:
10
+ optimizer (torch.nn.optimizer): Torch optimizer.
11
+ milestones (list): Iterations that will decrease learning rate.
12
+ gamma (float): Decrease ratio. Default: 0.1.
13
+ restarts (list): Restart iterations. Default: [0].
14
+ restart_weights (list): Restart weights at each restart iteration.
15
+ Default: [1].
16
+ last_epoch (int): Used in _LRScheduler. Default: -1.
17
+ """
18
+
19
+ def __init__(self, optimizer, milestones, gamma=0.1, restarts=(0, ), restart_weights=(1, ), last_epoch=-1):
20
+ self.milestones = Counter(milestones)
21
+ self.gamma = gamma
22
+ self.restarts = restarts
23
+ self.restart_weights = restart_weights
24
+ assert len(self.restarts) == len(self.restart_weights), 'restarts and their weights do not match.'
25
+ super(MultiStepRestartLR, self).__init__(optimizer, last_epoch)
26
+
27
+ def get_lr(self):
28
+ if self.last_epoch in self.restarts:
29
+ weight = self.restart_weights[self.restarts.index(self.last_epoch)]
30
+ return [group['initial_lr'] * weight for group in self.optimizer.param_groups]
31
+ if self.last_epoch not in self.milestones:
32
+ return [group['lr'] for group in self.optimizer.param_groups]
33
+ return [group['lr'] * self.gamma**self.milestones[self.last_epoch] for group in self.optimizer.param_groups]
34
+
35
+
36
+ def get_position_from_periods(iteration, cumulative_period):
37
+ """Get the position from a period list.
38
+
39
+ It will return the index of the right-closest number in the period list.
40
+ For example, the cumulative_period = [100, 200, 300, 400],
41
+ if iteration == 50, return 0;
42
+ if iteration == 210, return 2;
43
+ if iteration == 300, return 2.
44
+
45
+ Args:
46
+ iteration (int): Current iteration.
47
+ cumulative_period (list[int]): Cumulative period list.
48
+
49
+ Returns:
50
+ int: The position of the right-closest number in the period list.
51
+ """
52
+ for i, period in enumerate(cumulative_period):
53
+ if iteration <= period:
54
+ return i
55
+
56
+
57
+ class CosineAnnealingRestartLR(_LRScheduler):
58
+ """ Cosine annealing with restarts learning rate scheme.
59
+
60
+ An example of config:
61
+ periods = [10, 10, 10, 10]
62
+ restart_weights = [1, 0.5, 0.5, 0.5]
63
+ eta_min=1e-7
64
+
65
+ It has four cycles, each has 10 iterations. At 10th, 20th, 30th, the
66
+ scheduler will restart with the weights in restart_weights.
67
+
68
+ Args:
69
+ optimizer (torch.nn.optimizer): Torch optimizer.
70
+ periods (list): Period for each cosine annealing cycle.
71
+ restart_weights (list): Restart weights at each restart iteration.
72
+ Default: [1].
73
+ eta_min (float): The minimum lr. Default: 0.
74
+ last_epoch (int): Used in _LRScheduler. Default: -1.
75
+ """
76
+
77
+ def __init__(self, optimizer, periods, restart_weights=(1, ), eta_min=0, last_epoch=-1):
78
+ self.periods = periods
79
+ self.restart_weights = restart_weights
80
+ self.eta_min = eta_min
81
+ assert (len(self.periods) == len(
82
+ self.restart_weights)), 'periods and restart_weights should have the same length.'
83
+ self.cumulative_period = [sum(self.periods[0:i + 1]) for i in range(0, len(self.periods))]
84
+ super(CosineAnnealingRestartLR, self).__init__(optimizer, last_epoch)
85
+
86
+ def get_lr(self):
87
+ idx = get_position_from_periods(self.last_epoch, self.cumulative_period)
88
+ current_weight = self.restart_weights[idx]
89
+ nearest_restart = 0 if idx == 0 else self.cumulative_period[idx - 1]
90
+ current_period = self.periods[idx]
91
+
92
+ return [
93
+ self.eta_min + current_weight * 0.5 * (base_lr - self.eta_min) *
94
+ (1 + math.cos(math.pi * ((self.last_epoch - nearest_restart) / current_period)))
95
+ for base_lr in self.base_lrs
96
+ ]
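Usage sketch for CosineAnnealingRestartLR, reusing the docstring's example config: four cycles of 10 steps, the first at full weight and the rest at half weight. The linear layer and Adam settings are placeholders.

import torch

from basicsr.models.lr_scheduler import CosineAnnealingRestartLR

net = torch.nn.Linear(4, 4)
optimizer = torch.optim.Adam(net.parameters(), lr=1e-4)
scheduler = CosineAnnealingRestartLR(
    optimizer, periods=[10, 10, 10, 10], restart_weights=[1, 0.5, 0.5, 0.5], eta_min=1e-7)

for step in range(40):
    optimizer.step()        # a real training step would go here
    scheduler.step()
    if (step + 1) % 10 == 0:
        print(step + 1, optimizer.param_groups[0]['lr'])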
basicsr/train.py ADDED
@@ -0,0 +1,224 @@
1
+ import datetime
2
+ import logging
3
+ import math
4
+ import time
5
+ import torch
6
+ import warnings
7
+
8
+ warnings.filterwarnings("ignore")
9
+
10
+ from os import path as osp
11
+
12
+ from basicsr.data import build_dataloader, build_dataset
13
+ from basicsr.data.data_sampler import EnlargedSampler
14
+ from basicsr.data.prefetch_dataloader import CPUPrefetcher, CUDAPrefetcher
15
+ from basicsr.models import build_model
16
+ from basicsr.utils import (AvgTimer, MessageLogger, check_resume, get_env_info, get_root_logger, get_time_str,
17
+ init_tb_logger, init_wandb_logger, make_exp_dirs, mkdir_and_rename, scandir)
18
+ from basicsr.utils.options import copy_opt_file, dict2str, parse_options
19
+
20
+
21
+ def init_tb_loggers(opt):
22
+ # initialize wandb logger before tensorboard logger to allow proper sync
23
+ if (opt['logger'].get('wandb') is not None) and (opt['logger']['wandb'].get('project')
24
+ is not None) and ('debug' not in opt['name']):
25
+ assert opt['logger'].get('use_tb_logger') is True, ('should turn on tensorboard when using wandb')
26
+ init_wandb_logger(opt)
27
+ tb_logger = None
28
+ if opt['logger'].get('use_tb_logger') and 'debug' not in opt['name']:
29
+ tb_logger = init_tb_logger(log_dir=osp.join(opt['root_path'], 'tb_logger', opt['name']))
30
+ return tb_logger
31
+
32
+
33
+ def create_train_val_dataloader(opt, logger):
34
+ # create train and val dataloaders
35
+ train_loader, val_loaders = None, []
36
+ for phase, dataset_opt in opt['datasets'].items():
37
+ if phase == 'train':
38
+ dataset_enlarge_ratio = dataset_opt.get('dataset_enlarge_ratio', 1)
39
+ train_set = build_dataset(dataset_opt)
40
+ train_sampler = EnlargedSampler(train_set, opt['world_size'], opt['rank'], dataset_enlarge_ratio)
41
+ train_loader = build_dataloader(
42
+ train_set,
43
+ dataset_opt,
44
+ num_gpu=opt['num_gpu'],
45
+ dist=opt['dist'],
46
+ sampler=train_sampler,
47
+ seed=opt['manual_seed'])
48
+
49
+ num_iter_per_epoch = math.ceil(
50
+ len(train_set) * dataset_enlarge_ratio / (dataset_opt['batch_size_per_gpu'] * opt['world_size']))
51
+ total_iters = int(opt['train']['total_iter'])
52
+ total_epochs = math.ceil(total_iters / (num_iter_per_epoch))
53
+ logger.info('Training statistics:'
54
+ f'\n\tNumber of train images: {len(train_set)}'
55
+ f'\n\tDataset enlarge ratio: {dataset_enlarge_ratio}'
56
+ f'\n\tBatch size per gpu: {dataset_opt["batch_size_per_gpu"]}'
57
+ f'\n\tWorld size (gpu number): {opt["world_size"]}'
58
+ f'\n\tRequire iter number per epoch: {num_iter_per_epoch}'
59
+ f'\n\tTotal epochs: {total_epochs}; iters: {total_iters}.')
60
+ elif phase.split('_')[0] == 'val':
61
+ val_set = build_dataset(dataset_opt)
62
+ val_loader = build_dataloader(
63
+ val_set, dataset_opt, num_gpu=opt['num_gpu'], dist=opt['dist'], sampler=None, seed=opt['manual_seed'])
64
+ logger.info(f'Number of val images/folders in {dataset_opt["name"]}: {len(val_set)}')
65
+ val_loaders.append(val_loader)
66
+ else:
67
+ raise ValueError(f'Dataset phase {phase} is not recognized.')
68
+
69
+ return train_loader, train_sampler, val_loaders, total_epochs, total_iters
70
+
71
+
72
+ def load_resume_state(opt):
73
+ resume_state_path = None
74
+ if opt['auto_resume']:
75
+ state_path = osp.join(opt['root_path'], 'experiments', opt['name'], 'training_states')
76
+ if osp.isdir(state_path):
77
+ states = list(scandir(state_path, suffix='state', recursive=False, full_path=False))
78
+ if len(states) != 0:
79
+ states = [float(v.split('.state')[0]) for v in states]
80
+ resume_state_path = osp.join(state_path, f'{max(states):.0f}.state')
81
+ opt['path']['resume_state'] = resume_state_path
82
+ else:
83
+ if opt['path'].get('resume_state'):
84
+ resume_state_path = opt['path']['resume_state']
85
+
86
+ if resume_state_path is None:
87
+ resume_state = None
88
+ else:
89
+ device_id = torch.cuda.current_device()
90
+ resume_state = torch.load(resume_state_path, map_location=lambda storage, loc: storage.cuda(device_id))
91
+ check_resume(opt, resume_state['iter'])
92
+ return resume_state
93
+
94
+
95
+ def train_pipeline(root_path):
96
+ # parse options, set distributed setting, set random seed
97
+ opt, args = parse_options(root_path, is_train=True)
98
+ opt['root_path'] = root_path
99
+
100
+ torch.backends.cudnn.benchmark = True
101
+ # torch.backends.cudnn.deterministic = True
102
+
103
+ # load resume states if necessary
104
+ resume_state = load_resume_state(opt)
105
+ # mkdir for experiments and logger
106
+ if resume_state is None:
107
+ make_exp_dirs(opt)
108
+ if opt['logger'].get('use_tb_logger') and 'debug' not in opt['name'] and opt['rank'] == 0:
109
+ mkdir_and_rename(osp.join(opt['root_path'], 'tb_logger', opt['name']))
110
+
111
+ # copy the yml file to the experiment root
112
+ copy_opt_file(args.opt, opt['path']['experiments_root'])
113
+
114
+ # WARNING: should not use get_root_logger in the above codes, including the called functions
115
+ # Otherwise the logger will not be properly initialized
116
+ log_file = osp.join(opt['path']['log'], f"train_{opt['name']}_{get_time_str()}.log")
117
+ logger = get_root_logger(logger_name='basicsr', log_level=logging.INFO, log_file=log_file)
118
+ logger.info(get_env_info())
119
+ logger.info(dict2str(opt))
120
+ # initialize wandb and tb loggers
121
+ tb_logger = init_tb_loggers(opt)
122
+
123
+ # create train and validation dataloaders
124
+ result = create_train_val_dataloader(opt, logger)
125
+ train_loader, train_sampler, val_loaders, total_epochs, total_iters = result
126
+
127
+ # create model
128
+ model = build_model(opt)
129
+ if resume_state: # resume training
130
+ model.resume_training(resume_state) # handle optimizers and schedulers
131
+ logger.info(f"Resuming training from epoch: {resume_state['epoch']}, " f"iter: {resume_state['iter']}.")
132
+ start_epoch = resume_state['epoch']
133
+ current_iter = resume_state['iter']
134
+ else:
135
+ start_epoch = 0
136
+ current_iter = 0
137
+
138
+ # create message logger (formatted outputs)
139
+ msg_logger = MessageLogger(opt, current_iter, tb_logger)
140
+
141
+ # dataloader prefetcher
142
+ prefetch_mode = opt['datasets']['train'].get('prefetch_mode')
143
+ if prefetch_mode is None or prefetch_mode == 'cpu':
144
+ prefetcher = CPUPrefetcher(train_loader)
145
+ elif prefetch_mode == 'cuda':
146
+ prefetcher = CUDAPrefetcher(train_loader, opt)
147
+ logger.info(f'Use {prefetch_mode} prefetch dataloader')
148
+ if opt['datasets']['train'].get('pin_memory') is not True:
149
+ raise ValueError('Please set pin_memory=True for CUDAPrefetcher.')
150
+ else:
151
+ raise ValueError(f"Wrong prefetch_mode {prefetch_mode}. Supported ones are: None, 'cuda', 'cpu'.")
152
+
153
+ # training
154
+ logger.info(f'Start training from epoch: {start_epoch}, iter: {current_iter}')
155
+ data_timer, iter_timer = AvgTimer(), AvgTimer()
156
+ start_time = time.time()
157
+
158
+ for epoch in range(start_epoch, total_epochs + 1):
159
+ train_sampler.set_epoch(epoch)
160
+ prefetcher.reset()
161
+ train_data = prefetcher.next()
162
+
163
+ while train_data is not None:
164
+ data_timer.record()
165
+
166
+ current_iter += 1
167
+ if current_iter > total_iters:
168
+ break
169
+ # update learning rate
170
+ model.update_learning_rate(current_iter, warmup_iter=opt['train'].get('warmup_iter', -1))
171
+ # training
172
+ model.feed_data(train_data)
173
+ model.optimize_parameters(current_iter)
174
+ iter_timer.record()
175
+ if current_iter == 1:
176
+ # reset start time in msg_logger for more accurate eta_time
177
+ # not work in resume mode
178
+ msg_logger.reset_start_time()
179
+ # log
180
+ if current_iter % opt['logger']['print_freq'] == 0:
181
+ log_vars = {'epoch': epoch, 'iter': current_iter}
182
+ log_vars.update({'lrs': model.get_current_learning_rate()})
183
+ log_vars.update({'time': iter_timer.get_avg_time(), 'data_time': data_timer.get_avg_time()})
184
+ log_vars.update(model.get_current_log())
185
+ msg_logger(log_vars)
186
+
187
+ # save training images snapshot save_snapshot_freq
188
+ if opt['logger'][
189
+ 'save_snapshot_freq'] is not None and current_iter % opt['logger']['save_snapshot_freq'] == 0:
190
+ model.save_training_images(current_iter)
191
+
192
+ # save models and training states
193
+ if current_iter % opt['logger']['save_checkpoint_freq'] == 0:
194
+ logger.info('Saving models and training states.')
195
+ model.save(epoch, current_iter)
196
+
197
+ # validation
198
+ if opt.get('val') is not None and (current_iter % opt['val']['val_freq'] == 0):
199
+ if len(val_loaders) > 1:
200
+ logger.warning('Multiple validation datasets are *only* supported by SRModel.')
201
+ for val_loader in val_loaders:
202
+ model.validation(val_loader, current_iter, tb_logger, opt['val']['save_img'])
203
+
204
+ data_timer.start()
205
+ iter_timer.start()
206
+ train_data = prefetcher.next()
207
+ # end of iter
208
+
209
+ # end of epoch
210
+
211
+ consumed_time = str(datetime.timedelta(seconds=int(time.time() - start_time)))
212
+ logger.info(f'End of training. Time consumed: {consumed_time}')
213
+ logger.info('Save the latest model.')
214
+ model.save(epoch=-1, current_iter=-1) # -1 stands for the latest
215
+ if opt.get('val') is not None:
216
+ for val_loader in val_loaders:
217
+ model.validation(val_loader, current_iter, tb_logger, opt['val']['save_img'])
218
+ if tb_logger:
219
+ tb_logger.close()
220
+
221
+
222
+ if __name__ == '__main__':
223
+ root_path = osp.abspath(osp.join(__file__, osp.pardir, osp.pardir))
224
+ train_pipeline(root_path)
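A worked example of the iteration bookkeeping in create_train_val_dataloader() above, with illustrative numbers (not the repo's actual training settings): 100k images, batch 16 per GPU on 4 GPUs, no dataset enlargement, 200k total iterations.

import math

num_images, enlarge_ratio = 100_000, 1
batch_per_gpu, world_size = 16, 4
total_iters = 200_000

num_iter_per_epoch = math.ceil(num_images * enlarge_ratio / (batch_per_gpu * world_size))
total_epochs = math.ceil(total_iters / num_iter_per_epoch)
print(num_iter_per_epoch, total_epochs)   # 1563 128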
basicsr/utils/__init__.py ADDED
@@ -0,0 +1,37 @@
1
+ from .diffjpeg import DiffJPEG
2
+ from .file_client import FileClient
3
+ from .img_process_util import USMSharp, usm_sharp
4
+ from .img_util import crop_border, imfrombytes, img2tensor, imwrite, tensor2img
5
+ from .logger import AvgTimer, MessageLogger, get_env_info, get_root_logger, init_tb_logger, init_wandb_logger
6
+ from .misc import check_resume, get_time_str, make_exp_dirs, mkdir_and_rename, scandir, set_random_seed, sizeof_fmt
7
+
8
+ __all__ = [
9
+ # file_client.py
10
+ 'FileClient',
11
+ # img_util.py
12
+ 'img2tensor',
13
+ 'tensor2img',
14
+ 'imfrombytes',
15
+ 'imwrite',
16
+ 'crop_border',
17
+ # logger.py
18
+ 'MessageLogger',
19
+ 'AvgTimer',
20
+ 'init_tb_logger',
21
+ 'init_wandb_logger',
22
+ 'get_root_logger',
23
+ 'get_env_info',
24
+ # misc.py
25
+ 'set_random_seed',
26
+ 'get_time_str',
27
+ 'mkdir_and_rename',
28
+ 'make_exp_dirs',
29
+ 'scandir',
30
+ 'check_resume',
31
+ 'sizeof_fmt',
32
+ # diffjpeg
33
+ 'DiffJPEG',
34
+ # img_process_util
35
+ 'USMSharp',
36
+ 'usm_sharp'
37
+ ]
basicsr/utils/color_enhance.py ADDED
@@ -0,0 +1,9 @@
1
+ from torchvision.transforms import ToTensor, Grayscale
2
+
3
+
4
+ def color_enhacne_blend(x, factor=1.2):
5
+ x_g = Grayscale(3)(x)
6
+ out = x_g * (1.0 - factor) + x * factor
7
+ out[out < 0] = 0
8
+ out[out > 1] = 1
9
+ return out
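Usage sketch for color_enhacne_blend() above: blending an image with its own grayscale copy using factor > 1 extrapolates away from gray (a mild saturation boost), and the result is then clamped back to [0, 1]. Assumes a CHW float tensor in [0, 1].

import torch

from basicsr.utils.color_enhance import color_enhacne_blend

img = torch.rand(3, 64, 64)                      # RGB tensor in [0, 1]
boosted = color_enhacne_blend(img, factor=1.2)
print(boosted.shape, float(boosted.min()), float(boosted.max()))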
basicsr/utils/diffjpeg.py ADDED
@@ -0,0 +1,515 @@
1
+ """
2
+ Modified from https://github.com/mlomnitz/DiffJPEG
3
+
4
+ For images not divisible by 8
5
+ https://dsp.stackexchange.com/questions/35339/jpeg-dct-padding/35343#35343
6
+ """
7
+ import itertools
8
+ import numpy as np
9
+ import torch
10
+ import torch.nn as nn
11
+ from torch.nn import functional as F
12
+
13
+ # ------------------------ utils ------------------------#
14
+ y_table = np.array(
15
+ [[16, 11, 10, 16, 24, 40, 51, 61], [12, 12, 14, 19, 26, 58, 60, 55], [14, 13, 16, 24, 40, 57, 69, 56],
16
+ [14, 17, 22, 29, 51, 87, 80, 62], [18, 22, 37, 56, 68, 109, 103, 77], [24, 35, 55, 64, 81, 104, 113, 92],
17
+ [49, 64, 78, 87, 103, 121, 120, 101], [72, 92, 95, 98, 112, 100, 103, 99]],
18
+ dtype=np.float32).T
19
+ y_table = nn.Parameter(torch.from_numpy(y_table))
20
+ c_table = np.empty((8, 8), dtype=np.float32)
21
+ c_table.fill(99)
22
+ c_table[:4, :4] = np.array([[17, 18, 24, 47], [18, 21, 26, 66], [24, 26, 56, 99], [47, 66, 99, 99]]).T
23
+ c_table = nn.Parameter(torch.from_numpy(c_table))
24
+
25
+
26
+ def diff_round(x):
27
+ """ Differentiable rounding function
28
+ """
29
+ return torch.round(x) + (x - torch.round(x))**3
30
+
31
+
32
+ def quality_to_factor(quality):
33
+ """ Calculate factor corresponding to quality
34
+
35
+ Args:
36
+ quality(float): Quality for jpeg compression.
37
+
38
+ Returns:
39
+ float: Compression factor.
40
+ """
41
+ if quality < 50:
42
+ quality = 5000. / quality
43
+ else:
44
+ quality = 200. - quality * 2
45
+ return quality / 100.
46
+
47
+
48
+ # ------------------------ compression ------------------------#
49
+ class RGB2YCbCrJpeg(nn.Module):
50
+ """ Converts RGB image to YCbCr
51
+ """
52
+
53
+ def __init__(self):
54
+ super(RGB2YCbCrJpeg, self).__init__()
55
+ matrix = np.array([[0.299, 0.587, 0.114], [-0.168736, -0.331264, 0.5], [0.5, -0.418688, -0.081312]],
56
+ dtype=np.float32).T
57
+ self.shift = nn.Parameter(torch.tensor([0., 128., 128.]))
58
+ self.matrix = nn.Parameter(torch.from_numpy(matrix))
59
+
60
+ def forward(self, image):
61
+ """
62
+ Args:
63
+ image(Tensor): batch x 3 x height x width
64
+
65
+ Returns:
66
+ Tensor: batch x height x width x 3
67
+ """
68
+ image = image.permute(0, 2, 3, 1)
69
+ result = torch.tensordot(image, self.matrix, dims=1) + self.shift
70
+ return result.view(image.shape)
71
+
72
+
73
+ class ChromaSubsampling(nn.Module):
74
+ """ Chroma subsampling on CbCr channels
75
+ """
76
+
77
+ def __init__(self):
78
+ super(ChromaSubsampling, self).__init__()
79
+
80
+ def forward(self, image):
81
+ """
82
+ Args:
83
+ image(tensor): batch x height x width x 3
84
+
85
+ Returns:
86
+ y(tensor): batch x height x width
87
+ cb(tensor): batch x height/2 x width/2
88
+ cr(tensor): batch x height/2 x width/2
89
+ """
90
+ image_2 = image.permute(0, 3, 1, 2).clone()
91
+ cb = F.avg_pool2d(image_2[:, 1, :, :].unsqueeze(1), kernel_size=2, stride=(2, 2), count_include_pad=False)
92
+ cr = F.avg_pool2d(image_2[:, 2, :, :].unsqueeze(1), kernel_size=2, stride=(2, 2), count_include_pad=False)
93
+ cb = cb.permute(0, 2, 3, 1)
94
+ cr = cr.permute(0, 2, 3, 1)
95
+ return image[:, :, :, 0], cb.squeeze(3), cr.squeeze(3)
96
+
97
+
98
+ class BlockSplitting(nn.Module):
99
+ """ Splitting image into patches
100
+ """
101
+
102
+ def __init__(self):
103
+ super(BlockSplitting, self).__init__()
104
+ self.k = 8
105
+
106
+ def forward(self, image):
107
+ """
108
+ Args:
109
+ image(tensor): batch x height x width
110
+
111
+ Returns:
112
+ Tensor: batch x h*w/64 x h x w
113
+ """
114
+ height, _ = image.shape[1:3]
115
+ batch_size = image.shape[0]
116
+ image_reshaped = image.view(batch_size, height // self.k, self.k, -1, self.k)
117
+ image_transposed = image_reshaped.permute(0, 1, 3, 2, 4)
118
+ return image_transposed.contiguous().view(batch_size, -1, self.k, self.k)
119
+
120
+
121
+ class DCT8x8(nn.Module):
122
+ """ Discrete Cosine Transformation
123
+ """
124
+
125
+ def __init__(self):
126
+ super(DCT8x8, self).__init__()
127
+ tensor = np.zeros((8, 8, 8, 8), dtype=np.float32)
128
+ for x, y, u, v in itertools.product(range(8), repeat=4):
129
+ tensor[x, y, u, v] = np.cos((2 * x + 1) * u * np.pi / 16) * np.cos((2 * y + 1) * v * np.pi / 16)
130
+ alpha = np.array([1. / np.sqrt(2)] + [1] * 7)
131
+ self.tensor = nn.Parameter(torch.from_numpy(tensor).float())
132
+ self.scale = nn.Parameter(torch.from_numpy(np.outer(alpha, alpha) * 0.25).float())
133
+
134
+ def forward(self, image):
135
+ """
136
+ Args:
137
+ image(tensor): batch x height x width
138
+
139
+ Returns:
140
+ Tensor: batch x height x width
141
+ """
142
+ image = image - 128
143
+ result = self.scale * torch.tensordot(image, self.tensor, dims=2)
144
+ result.view(image.shape)
145
+ return result
146
+
147
+
148
+ class YQuantize(nn.Module):
149
+ """ JPEG Quantization for Y channel
150
+
151
+ Args:
152
+ rounding(function): rounding function to use
153
+ """
154
+
155
+ def __init__(self, rounding):
156
+ super(YQuantize, self).__init__()
157
+ self.rounding = rounding
158
+ self.y_table = y_table
159
+
160
+ def forward(self, image, factor=1):
161
+ """
162
+ Args:
163
+ image(tensor): batch x height x width
164
+
165
+ Returns:
166
+ Tensor: batch x height x width
167
+ """
168
+ if isinstance(factor, (int, float)):
169
+ image = image.float() / (self.y_table * factor)
170
+ else:
171
+ b = factor.size(0)
172
+ table = self.y_table.expand(b, 1, 8, 8) * factor.view(b, 1, 1, 1)
173
+ image = image.float() / table
174
+ image = self.rounding(image)
175
+ return image
176
+
177
+
178
+ class CQuantize(nn.Module):
179
+ """ JPEG Quantization for CbCr channels
180
+
181
+ Args:
182
+ rounding(function): rounding function to use
183
+ """
184
+
185
+ def __init__(self, rounding):
186
+ super(CQuantize, self).__init__()
187
+ self.rounding = rounding
188
+ self.c_table = c_table
189
+
190
+ def forward(self, image, factor=1):
191
+ """
192
+ Args:
193
+ image(tensor): batch x height x width
194
+
195
+ Returns:
196
+ Tensor: batch x height x width
197
+ """
198
+ if isinstance(factor, (int, float)):
199
+ image = image.float() / (self.c_table * factor)
200
+ else:
201
+ b = factor.size(0)
202
+ table = self.c_table.expand(b, 1, 8, 8) * factor.view(b, 1, 1, 1)
203
+ image = image.float() / table
204
+ image = self.rounding(image)
205
+ return image
206
+
207
+
208
+ class CompressJpeg(nn.Module):
209
+ """Full JPEG compression algorithm
210
+
211
+ Args:
212
+ rounding(function): rounding function to use
213
+ """
214
+
215
+ def __init__(self, rounding=torch.round):
216
+ super(CompressJpeg, self).__init__()
217
+ self.l1 = nn.Sequential(RGB2YCbCrJpeg(), ChromaSubsampling())
218
+ self.l2 = nn.Sequential(BlockSplitting(), DCT8x8())
219
+ self.c_quantize = CQuantize(rounding=rounding)
220
+ self.y_quantize = YQuantize(rounding=rounding)
221
+
222
+ def forward(self, image, factor=1):
223
+ """
224
+ Args:
225
+ image(tensor): batch x 3 x height x width
226
+
227
+ Returns:
228
+ tuple(Tensor): Quantized y, cb and cr blocks, each of shape batch x n_blocks x 8 x 8.
229
+ """
230
+ y, cb, cr = self.l1(image * 255)
231
+ components = {'y': y, 'cb': cb, 'cr': cr}
232
+ for k in components.keys():
233
+ comp = self.l2(components[k])
234
+ if k in ('cb', 'cr'):
235
+ comp = self.c_quantize(comp, factor=factor)
236
+ else:
237
+ comp = self.y_quantize(comp, factor=factor)
238
+
239
+ components[k] = comp
240
+
241
+ return components['y'], components['cb'], components['cr']
242
+
243
+
244
+ # ------------------------ decompression ------------------------#
245
+
246
+
247
+ class YDequantize(nn.Module):
248
+ """Dequantize Y channel
249
+ """
250
+
251
+ def __init__(self):
252
+ super(YDequantize, self).__init__()
253
+ self.y_table = y_table
254
+
255
+ def forward(self, image, factor=1):
256
+ """
257
+ Args:
258
+ image(tensor): batch x height x width
259
+
260
+ Returns:
261
+ Tensor: batch x height x width
262
+ """
263
+ if isinstance(factor, (int, float)):
264
+ out = image * (self.y_table * factor)
265
+ else:
266
+ b = factor.size(0)
267
+ table = self.y_table.expand(b, 1, 8, 8) * factor.view(b, 1, 1, 1)
268
+ out = image * table
269
+ return out
270
+
271
+
272
+ class CDequantize(nn.Module):
273
+ """Dequantize CbCr channel
274
+ """
275
+
276
+ def __init__(self):
277
+ super(CDequantize, self).__init__()
278
+ self.c_table = c_table
279
+
280
+ def forward(self, image, factor=1):
281
+ """
282
+ Args:
283
+ image(tensor): batch x height x width
284
+
285
+ Returns:
286
+ Tensor: batch x height x width
287
+ """
288
+ if isinstance(factor, (int, float)):
289
+ out = image * (self.c_table * factor)
290
+ else:
291
+ b = factor.size(0)
292
+ table = self.c_table.expand(b, 1, 8, 8) * factor.view(b, 1, 1, 1)
293
+ out = image * table
294
+ return out
295
+
296
+
297
+ class iDCT8x8(nn.Module):
298
+ """Inverse discrete Cosine Transformation
299
+ """
300
+
301
+ def __init__(self):
302
+ super(iDCT8x8, self).__init__()
303
+ alpha = np.array([1. / np.sqrt(2)] + [1] * 7)
304
+ self.alpha = nn.Parameter(torch.from_numpy(np.outer(alpha, alpha)).float())
305
+ tensor = np.zeros((8, 8, 8, 8), dtype=np.float32)
306
+ for x, y, u, v in itertools.product(range(8), repeat=4):
307
+ tensor[x, y, u, v] = np.cos((2 * u + 1) * x * np.pi / 16) * np.cos((2 * v + 1) * y * np.pi / 16)
308
+ self.tensor = nn.Parameter(torch.from_numpy(tensor).float())
309
+
310
+ def forward(self, image):
311
+ """
312
+ Args:
313
+ image(tensor): batch x height x width
314
+
315
+ Returns:
316
+ Tensor: batch x height x width
317
+ """
318
+ image = image * self.alpha
319
+ result = 0.25 * torch.tensordot(image, self.tensor, dims=2) + 128
320
+ result = result.view(image.shape)
321
+ return result
322
+
323
+
324
+ class BlockMerging(nn.Module):
325
+ """Merge patches into image
326
+ """
327
+
328
+ def __init__(self):
329
+ super(BlockMerging, self).__init__()
330
+
331
+ def forward(self, patches, height, width):
332
+ """
333
+ Args:
334
+ patches(tensor): batch x height*width/64 x 8 x 8
335
+ height(int)
336
+ width(int)
337
+
338
+ Returns:
339
+ Tensor: batch x height x width
340
+ """
341
+ k = 8
342
+ batch_size = patches.shape[0]
343
+ image_reshaped = patches.view(batch_size, height // k, width // k, k, k)
344
+ image_transposed = image_reshaped.permute(0, 1, 3, 2, 4)
345
+ return image_transposed.contiguous().view(batch_size, height, width)
346
+
347
+
348
+ class ChromaUpsampling(nn.Module):
349
+ """Upsample chroma layers
350
+ """
351
+
352
+ def __init__(self):
353
+ super(ChromaUpsampling, self).__init__()
354
+
355
+ def forward(self, y, cb, cr):
356
+ """
357
+ Args:
358
+ y(tensor): y channel image
359
+ cb(tensor): cb channel
360
+ cr(tensor): cr channel
361
+
362
+ Returns:
363
+ Tensor: batch x height x width x 3
364
+ """
365
+
366
+ def repeat(x, k=2):
367
+ height, width = x.shape[1:3]
368
+ x = x.unsqueeze(-1)
369
+ x = x.repeat(1, 1, k, k)
370
+ x = x.view(-1, height * k, width * k)
371
+ return x
372
+
373
+ cb = repeat(cb)
374
+ cr = repeat(cr)
375
+ return torch.cat([y.unsqueeze(3), cb.unsqueeze(3), cr.unsqueeze(3)], dim=3)
376
+
377
+
378
+ class YCbCr2RGBJpeg(nn.Module):
379
+ """Converts YCbCr image to RGB JPEG
380
+ """
381
+
382
+ def __init__(self):
383
+ super(YCbCr2RGBJpeg, self).__init__()
384
+
385
+ matrix = np.array([[1., 0., 1.402], [1, -0.344136, -0.714136], [1, 1.772, 0]], dtype=np.float32).T
386
+ self.shift = nn.Parameter(torch.tensor([0, -128., -128.]))
387
+ self.matrix = nn.Parameter(torch.from_numpy(matrix))
388
+
389
+ def forward(self, image):
390
+ """
391
+ Args:
392
+ image(tensor): batch x height x width x 3
393
+
394
+ Returns:
395
+ Tensor: batch x 3 x height x width
396
+ """
397
+ result = torch.tensordot(image + self.shift, self.matrix, dims=1)
398
+ return result.view(image.shape).permute(0, 3, 1, 2)
399
+
400
+
401
+ class DeCompressJpeg(nn.Module):
402
+ """Full JPEG decompression algorithm
403
+
404
+ Args:
405
+ rounding(function): rounding function to use
406
+ """
407
+
408
+ def __init__(self, rounding=torch.round):
409
+ super(DeCompressJpeg, self).__init__()
410
+ self.c_dequantize = CDequantize()
411
+ self.y_dequantize = YDequantize()
412
+ self.idct = iDCT8x8()
413
+ self.merging = BlockMerging()
414
+ self.chroma = ChromaUpsampling()
415
+ self.colors = YCbCr2RGBJpeg()
416
+
417
+ def forward(self, y, cb, cr, imgh, imgw, factor=1):
418
+ """
419
+ Args:
420
+ y, cb, cr (tensor): quantized DCT blocks, batch x n_blocks x 8 x 8
421
+ imgh(int)
422
+ imgw(int)
423
+ factor(float)
424
+
425
+ Returns:
426
+ Tensor: batch x 3 x height x width
427
+ """
428
+ components = {'y': y, 'cb': cb, 'cr': cr}
429
+ for k in components.keys():
430
+ if k in ('cb', 'cr'):
431
+ comp = self.c_dequantize(components[k], factor=factor)
432
+ height, width = int(imgh / 2), int(imgw / 2)
433
+ else:
434
+ comp = self.y_dequantize(components[k], factor=factor)
435
+ height, width = imgh, imgw
436
+ comp = self.idct(comp)
437
+ components[k] = self.merging(comp, height, width)
438
+ #
439
+ image = self.chroma(components['y'], components['cb'], components['cr'])
440
+ image = self.colors(image)
441
+
442
+ image = torch.min(255 * torch.ones_like(image), torch.max(torch.zeros_like(image), image))
443
+ return image / 255
444
+
445
+
446
+ # ------------------------ main DiffJPEG ------------------------ #
447
+
448
+
449
+ class DiffJPEG(nn.Module):
450
+ """Differentiable JPEG compression; the result differs slightly from cv2's JPEG.
451
+ DiffJPEG supports batch processing.
452
+
453
+ Args:
454
+ differentiable(bool): If True, uses a custom differentiable rounding function; if False, uses standard torch.round
455
+ """
456
+
457
+ def __init__(self, differentiable=True):
458
+ super(DiffJPEG, self).__init__()
459
+ if differentiable:
460
+ rounding = diff_round
461
+ else:
462
+ rounding = torch.round
463
+
464
+ self.compress = CompressJpeg(rounding=rounding)
465
+ self.decompress = DeCompressJpeg(rounding=rounding)
466
+
467
+ def forward(self, x, quality):
468
+ """
469
+ Args:
470
+ x (Tensor): Input image, bchw, rgb, [0, 1]
471
+ quality(float): Quality factor for jpeg compression scheme.
472
+ """
473
+ factor = quality
474
+ if isinstance(factor, (int, float)):
475
+ factor = quality_to_factor(factor)
476
+ else:
477
+ for i in range(factor.size(0)):
478
+ factor[i] = quality_to_factor(factor[i])
479
+ h, w = x.size()[-2:]
480
+ h_pad, w_pad = 0, 0
481
+ # pad to a multiple of 16: chroma subsampling halves H/W and the DCT works on 8x8 blocks
482
+ if h % 16 != 0:
483
+ h_pad = 16 - h % 16
484
+ if w % 16 != 0:
485
+ w_pad = 16 - w % 16
486
+ x = F.pad(x, (0, w_pad, 0, h_pad), mode='constant', value=0)
487
+
488
+ y, cb, cr = self.compress(x, factor=factor)
489
+ recovered = self.decompress(y, cb, cr, (h + h_pad), (w + w_pad), factor=factor)
490
+ recovered = recovered[:, :, 0:h, 0:w]
491
+ return recovered
492
+
493
+
494
+ if __name__ == '__main__':
495
+ import cv2
496
+
497
+ from basicsr.utils import img2tensor, tensor2img
498
+
499
+ img_gt = cv2.imread('test.png') / 255.
500
+
501
+ # -------------- cv2 -------------- #
502
+ encode_param = [int(cv2.IMWRITE_JPEG_QUALITY), 20]
503
+ _, encimg = cv2.imencode('.jpg', img_gt * 255., encode_param)
504
+ img_lq = np.float32(cv2.imdecode(encimg, 1))
505
+ cv2.imwrite('cv2_JPEG_20.png', img_lq)
506
+
507
+ # -------------- DiffJPEG -------------- #
508
+ jpeger = DiffJPEG(differentiable=False).cuda()
509
+ img_gt = img2tensor(img_gt)
510
+ img_gt = torch.stack([img_gt, img_gt]).cuda()
511
+ quality = img_gt.new_tensor([20, 40])
512
+ out = jpeger(img_gt, quality=quality)
513
+
514
+ cv2.imwrite('pt_JPEG_20.png', tensor2img(out[0]))
515
+ cv2.imwrite('pt_JPEG_40.png', tensor2img(out[1]))
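A minimal sketch of using the DiffJPEG module above as a differentiable JPEG degradation (e.g. inside a training step). It assumes the repo root is on PYTHONPATH so that basicsr.utils.diffjpeg is importable; the tensor shapes and quality values are illustrative only.

    import torch
    from basicsr.utils.diffjpeg import DiffJPEG

    jpeger = DiffJPEG(differentiable=True)                  # differentiable rounding, so gradients can flow
    imgs = torch.rand(4, 3, 64, 64, requires_grad=True)     # b x c x h x w, RGB, range [0, 1]
    quality = imgs.new_tensor([30., 50., 70., 90.])         # one quality factor per image in the batch
    degraded = jpeger(imgs, quality=quality)                # same shape as imgs
    degraded.mean().backward()                              # gradients reach imgs through the JPEG pipeline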
basicsr/utils/dist_util.py ADDED
@@ -0,0 +1,82 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Modified from https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/dist_utils.py # noqa: E501
2
+ import functools
3
+ import os
4
+ import subprocess
5
+ import torch
6
+ import torch.distributed as dist
7
+ import torch.multiprocessing as mp
8
+
9
+
10
+ def init_dist(launcher, backend='nccl', **kwargs):
11
+ if mp.get_start_method(allow_none=True) is None:
12
+ mp.set_start_method('spawn')
13
+ if launcher == 'pytorch':
14
+ _init_dist_pytorch(backend, **kwargs)
15
+ elif launcher == 'slurm':
16
+ _init_dist_slurm(backend, **kwargs)
17
+ else:
18
+ raise ValueError(f'Invalid launcher type: {launcher}')
19
+
20
+
21
+ def _init_dist_pytorch(backend, **kwargs):
22
+ rank = int(os.environ['RANK'])
23
+ num_gpus = torch.cuda.device_count()
24
+ torch.cuda.set_device(rank % num_gpus)
25
+ dist.init_process_group(backend=backend, **kwargs)
26
+
27
+
28
+ def _init_dist_slurm(backend, port=None):
29
+ """Initialize slurm distributed training environment.
30
+
31
+ If argument ``port`` is not specified, then the master port will be system
32
+ environment variable ``MASTER_PORT``. If ``MASTER_PORT`` is not in system
33
+ environment variable, then a default port ``29500`` will be used.
34
+
35
+ Args:
36
+ backend (str): Backend of torch.distributed.
37
+ port (int, optional): Master port. Defaults to None.
38
+ """
39
+ proc_id = int(os.environ['SLURM_PROCID'])
40
+ ntasks = int(os.environ['SLURM_NTASKS'])
41
+ node_list = os.environ['SLURM_NODELIST']
42
+ num_gpus = torch.cuda.device_count()
43
+ torch.cuda.set_device(proc_id % num_gpus)
44
+ addr = subprocess.getoutput(f'scontrol show hostname {node_list} | head -n1')
45
+ # specify master port
46
+ if port is not None:
47
+ os.environ['MASTER_PORT'] = str(port)
48
+ elif 'MASTER_PORT' in os.environ:
49
+ pass # use MASTER_PORT in the environment variable
50
+ else:
51
+ # 29500 is torch.distributed default port
52
+ os.environ['MASTER_PORT'] = '29500'
53
+ os.environ['MASTER_ADDR'] = addr
54
+ os.environ['WORLD_SIZE'] = str(ntasks)
55
+ os.environ['LOCAL_RANK'] = str(proc_id % num_gpus)
56
+ os.environ['RANK'] = str(proc_id)
57
+ dist.init_process_group(backend=backend)
58
+
59
+
60
+ def get_dist_info():
61
+ if dist.is_available():
62
+ initialized = dist.is_initialized()
63
+ else:
64
+ initialized = False
65
+ if initialized:
66
+ rank = dist.get_rank()
67
+ world_size = dist.get_world_size()
68
+ else:
69
+ rank = 0
70
+ world_size = 1
71
+ return rank, world_size
72
+
73
+
74
+ def master_only(func):
75
+
76
+ @functools.wraps(func)
77
+ def wrapper(*args, **kwargs):
78
+ rank, _ = get_dist_info()
79
+ if rank == 0:
80
+ return func(*args, **kwargs)
81
+
82
+ return wrapper
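A short usage sketch for the helpers above: in a non-distributed run get_dist_info() falls back to rank 0 / world size 1, and @master_only restricts side effects to rank 0. The checkpoint path is a placeholder.

    from basicsr.utils.dist_util import get_dist_info, master_only

    @master_only
    def save_checkpoint(path):
        # executed only on rank 0; other ranks return None
        print(f'saving checkpoint to {path}')

    rank, world_size = get_dist_info()   # (0, 1) when torch.distributed is not initialized
    save_checkpoint('experiments/demo.pth')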
basicsr/utils/download_util.py ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import requests
3
+ from tqdm import tqdm
4
+
5
+ from .misc import sizeof_fmt
6
+
7
+
8
+ def download_file_from_google_drive(file_id, save_path):
9
+ """Download files from google drive.
10
+
11
+ Ref:
12
+ https://stackoverflow.com/questions/25010369/wget-curl-large-file-from-google-drive # noqa E501
13
+
14
+ Args:
15
+ file_id (str): File id.
16
+ save_path (str): Save path.
17
+ """
18
+
19
+ session = requests.Session()
20
+ URL = 'https://docs.google.com/uc?export=download'
21
+ params = {'id': file_id}
22
+
23
+ response = session.get(URL, params=params, stream=True)
24
+ token = get_confirm_token(response)
25
+ if token:
26
+ params['confirm'] = token
27
+ response = session.get(URL, params=params, stream=True)
28
+
29
+ # get file size
30
+ response_file_size = session.get(URL, params=params, stream=True, headers={'Range': 'bytes=0-2'})
31
+ if 'Content-Range' in response_file_size.headers:
32
+ file_size = int(response_file_size.headers['Content-Range'].split('/')[1])
33
+ else:
34
+ file_size = None
35
+
36
+ save_response_content(response, save_path, file_size)
37
+
38
+
39
+ def get_confirm_token(response):
40
+ for key, value in response.cookies.items():
41
+ if key.startswith('download_warning'):
42
+ return value
43
+ return None
44
+
45
+
46
+ def save_response_content(response, destination, file_size=None, chunk_size=32768):
47
+ if file_size is not None:
48
+ pbar = tqdm(total=math.ceil(file_size / chunk_size), unit='chunk')
49
+
50
+ readable_file_size = sizeof_fmt(file_size)
51
+ else:
52
+ pbar = None
53
+
54
+ with open(destination, 'wb') as f:
55
+ downloaded_size = 0
56
+ for chunk in response.iter_content(chunk_size):
57
+ downloaded_size += chunk_size
58
+ if pbar is not None:
59
+ pbar.update(1)
60
+ pbar.set_description(f'Download {sizeof_fmt(downloaded_size)} / {readable_file_size}')
61
+ if chunk: # filter out keep-alive new chunks
62
+ f.write(chunk)
63
+ if pbar is not None:
64
+ pbar.close()
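A usage sketch for the downloader above; the file id and save path are placeholders, not real checkpoint identifiers.

    from basicsr.utils.download_util import download_file_from_google_drive

    # streams the file in 32 KB chunks and shows a tqdm progress bar
    download_file_from_google_drive(file_id='YOUR_FILE_ID', save_path='experiments/pretrained.pth')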
basicsr/utils/face_util.py ADDED
@@ -0,0 +1,192 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2
2
+ import numpy as np
3
+ import os
4
+ import torch
5
+ from skimage import transform as trans
6
+
7
+ from basicsr.utils import imwrite
8
+
9
+ try:
10
+ import dlib
11
+ except ImportError:
12
+ print('Please install dlib before testing face restoration. ' 'Reference: https://github.com/davisking/dlib')
13
+
14
+
15
+ class FaceRestorationHelper(object):
16
+ """Helper for the face restoration pipeline."""
17
+
18
+ def __init__(self, upscale_factor, face_size=512):
19
+ self.upscale_factor = upscale_factor
20
+ self.face_size = (face_size, face_size)
21
+
22
+ # standard 5 landmarks for FFHQ faces with 1024 x 1024
23
+ self.face_template = np.array([[686.77227723, 488.62376238], [586.77227723, 493.59405941],
24
+ [337.91089109, 488.38613861], [437.95049505, 493.51485149],
25
+ [513.58415842, 678.5049505]])
26
+ self.face_template = self.face_template / (1024 // face_size)
27
+ # for estimating the 2D similarity transformation
28
+ self.similarity_trans = trans.SimilarityTransform()
29
+
30
+ self.all_landmarks_5 = []
31
+ self.all_landmarks_68 = []
32
+ self.affine_matrices = []
33
+ self.inverse_affine_matrices = []
34
+ self.cropped_faces = []
35
+ self.restored_faces = []
36
+ self.save_png = True
37
+
38
+ def init_dlib(self, detection_path, landmark5_path, landmark68_path):
39
+ """Initialize the dlib detectors and predictors."""
40
+ self.face_detector = dlib.cnn_face_detection_model_v1(detection_path)
41
+ self.shape_predictor_5 = dlib.shape_predictor(landmark5_path)
42
+ self.shape_predictor_68 = dlib.shape_predictor(landmark68_path)
43
+
44
+ def free_dlib_gpu_memory(self):
45
+ del self.face_detector
46
+ del self.shape_predictor_5
47
+ del self.shape_predictor_68
48
+
49
+ def read_input_image(self, img_path):
50
+ # self.input_img is Numpy array, (h, w, c) with RGB order
51
+ self.input_img = dlib.load_rgb_image(img_path)
52
+
53
+ def detect_faces(self, img_path, upsample_num_times=1, only_keep_largest=False):
54
+ """
55
+ Args:
56
+ img_path (str): Image path.
57
+ upsample_num_times (int): Upsamples the image before running the
58
+ face detector
59
+
60
+ Returns:
61
+ int: Number of detected faces.
62
+ """
63
+ self.read_input_image(img_path)
64
+ det_faces = self.face_detector(self.input_img, upsample_num_times)
65
+ if len(det_faces) == 0:
66
+ print('No face detected. Try to increase upsample_num_times.')
67
+ else:
68
+ if only_keep_largest:
69
+ print('Detected several faces; keeping only the largest one.')
70
+ face_areas = []
71
+ for i in range(len(det_faces)):
72
+ face_area = (det_faces[i].rect.right() - det_faces[i].rect.left()) * (
73
+ det_faces[i].rect.bottom() - det_faces[i].rect.top())
74
+ face_areas.append(face_area)
75
+ largest_idx = face_areas.index(max(face_areas))
76
+ self.det_faces = [det_faces[largest_idx]]
77
+ else:
78
+ self.det_faces = det_faces
79
+ return len(self.det_faces)
80
+
81
+ def get_face_landmarks_5(self):
82
+ for face in self.det_faces:
83
+ shape = self.shape_predictor_5(self.input_img, face.rect)
84
+ landmark = np.array([[part.x, part.y] for part in shape.parts()])
85
+ self.all_landmarks_5.append(landmark)
86
+ return len(self.all_landmarks_5)
87
+
88
+ def get_face_landmarks_68(self):
89
+ """Get 68 densemarks for cropped images.
90
+
91
+ Should only have one face at most in the cropped image.
92
+ """
93
+ num_detected_face = 0
94
+ for idx, face in enumerate(self.cropped_faces):
95
+ # face detection
96
+ det_face = self.face_detector(face, 1) # TODO: can we remove it?
97
+ if len(det_face) == 0:
98
+ print(f'Cannot find faces in cropped image with index {idx}.')
99
+ self.all_landmarks_68.append(None)
100
+ else:
101
+ if len(det_face) > 1:
102
+ print('Detected several faces in the cropped image. Using the '
103
+ 'largest one. Note that it will also cause overlap '
104
+ 'during paste_faces_to_input_image.')
105
+ face_areas = []
106
+ for i in range(len(det_face)):
107
+ face_area = (det_face[i].rect.right() - det_face[i].rect.left()) * (
108
+ det_face[i].rect.bottom() - det_face[i].rect.top())
109
+ face_areas.append(face_area)
110
+ largest_idx = face_areas.index(max(face_areas))
111
+ face_rect = det_face[largest_idx].rect
112
+ else:
113
+ face_rect = det_face[0].rect
114
+ shape = self.shape_predictor_68(face, face_rect)
115
+ landmark = np.array([[part.x, part.y] for part in shape.parts()])
116
+ self.all_landmarks_68.append(landmark)
117
+ num_detected_face += 1
118
+
119
+ return num_detected_face
120
+
121
+ def warp_crop_faces(self, save_cropped_path=None, save_inverse_affine_path=None):
122
+ """Get affine matrix, warp and cropped faces.
123
+
124
+ Also get inverse affine matrix for post-processing.
125
+ """
126
+ for idx, landmark in enumerate(self.all_landmarks_5):
127
+ # use 5 landmarks to get affine matrix
128
+ self.similarity_trans.estimate(landmark, self.face_template)
129
+ affine_matrix = self.similarity_trans.params[0:2, :]
130
+ self.affine_matrices.append(affine_matrix)
131
+ # warp and crop faces
132
+ cropped_face = cv2.warpAffine(self.input_img, affine_matrix, self.face_size)
133
+ self.cropped_faces.append(cropped_face)
134
+ # save the cropped face
135
+ if save_cropped_path is not None:
136
+ path, ext = os.path.splitext(save_cropped_path)
137
+ if self.save_png:
138
+ save_path = f'{path}_{idx:02d}.png'
139
+ else:
140
+ save_path = f'{path}_{idx:02d}{ext}'
141
+
142
+ imwrite(cv2.cvtColor(cropped_face, cv2.COLOR_RGB2BGR), save_path)
143
+
144
+ # get inverse affine matrix
145
+ self.similarity_trans.estimate(self.face_template, landmark * self.upscale_factor)
146
+ inverse_affine = self.similarity_trans.params[0:2, :]
147
+ self.inverse_affine_matrices.append(inverse_affine)
148
+ # save inverse affine matrices
149
+ if save_inverse_affine_path is not None:
150
+ path, _ = os.path.splitext(save_inverse_affine_path)
151
+ save_path = f'{path}_{idx:02d}.pth'
152
+ torch.save(inverse_affine, save_path)
153
+
154
+ def add_restored_face(self, face):
155
+ self.restored_faces.append(face)
156
+
157
+ def paste_faces_to_input_image(self, save_path):
158
+ # operate in the BGR order
159
+ input_img = cv2.cvtColor(self.input_img, cv2.COLOR_RGB2BGR)
160
+ h, w, _ = input_img.shape
161
+ h_up, w_up = h * self.upscale_factor, w * self.upscale_factor
162
+ # simply resize the background
163
+ upsample_img = cv2.resize(input_img, (w_up, h_up))
164
+ assert len(self.restored_faces) == len(
165
+ self.inverse_affine_matrices), ('length of restored_faces and affine_matrices are different.')
166
+ for restored_face, inverse_affine in zip(self.restored_faces, self.inverse_affine_matrices):
167
+ inv_restored = cv2.warpAffine(restored_face, inverse_affine, (w_up, h_up))
168
+ mask = np.ones((*self.face_size, 3), dtype=np.float32)
169
+ inv_mask = cv2.warpAffine(mask, inverse_affine, (w_up, h_up))
170
+ # remove the black borders
171
+ inv_mask_erosion = cv2.erode(inv_mask, np.ones((2 * self.upscale_factor, 2 * self.upscale_factor),
172
+ np.uint8))
173
+ inv_restored_remove_border = inv_mask_erosion * inv_restored
174
+ total_face_area = np.sum(inv_mask_erosion) // 3
175
+ # compute the fusion edge based on the area of face
176
+ w_edge = int(total_face_area**0.5) // 20
177
+ erosion_radius = w_edge * 2
178
+ inv_mask_center = cv2.erode(inv_mask_erosion, np.ones((erosion_radius, erosion_radius), np.uint8))
179
+ blur_size = w_edge * 2
180
+ inv_soft_mask = cv2.GaussianBlur(inv_mask_center, (blur_size + 1, blur_size + 1), 0)
181
+ upsample_img = inv_soft_mask * inv_restored_remove_border + (1 - inv_soft_mask) * upsample_img
182
+ if self.save_png:
183
+ save_path = save_path.replace('.jpg', '.png').replace('.jpeg', '.png')
184
+ imwrite(upsample_img.astype(np.uint8), save_path)
185
+
186
+ def clean_all(self):
187
+ self.all_landmarks_5 = []
188
+ self.all_landmarks_68 = []
189
+ self.restored_faces = []
190
+ self.affine_matrices = []
191
+ self.cropped_faces = []
192
+ self.inverse_affine_matrices = []
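A sketch of the intended detect -> align -> restore -> paste pipeline with FaceRestorationHelper. The dlib model paths, the input image and the identity "restoration" step are placeholders; a real restoration model would replace the loop body and should return faces in BGR order for pasting.

    from basicsr.utils.face_util import FaceRestorationHelper

    helper = FaceRestorationHelper(upscale_factor=1, face_size=512)
    helper.init_dlib('mmod_human_face_detector.dat',
                     'shape_predictor_5_face_landmarks.dat',
                     'shape_predictor_68_face_landmarks.dat')
    if helper.detect_faces('input_face.jpg', only_keep_largest=True) > 0:
        helper.get_face_landmarks_5()
        helper.warp_crop_faces(save_cropped_path='results/cropped.png')
        for face in helper.cropped_faces:
            helper.add_restored_face(face)     # placeholder: pass the crop through unchanged
        helper.paste_faces_to_input_image('results/restored.png')
    helper.clean_all()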
basicsr/utils/file_client.py ADDED
@@ -0,0 +1,167 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Modified from https://github.com/open-mmlab/mmcv/blob/master/mmcv/fileio/file_client.py # noqa: E501
2
+ from abc import ABCMeta, abstractmethod
3
+
4
+
5
+ class BaseStorageBackend(metaclass=ABCMeta):
6
+ """Abstract class of storage backends.
7
+
8
+ All backends need to implement two apis: ``get()`` and ``get_text()``.
9
+ ``get()`` reads the file as a byte stream and ``get_text()`` reads the file
10
+ as texts.
11
+ """
12
+
13
+ @abstractmethod
14
+ def get(self, filepath):
15
+ pass
16
+
17
+ @abstractmethod
18
+ def get_text(self, filepath):
19
+ pass
20
+
21
+
22
+ class MemcachedBackend(BaseStorageBackend):
23
+ """Memcached storage backend.
24
+
25
+ Attributes:
26
+ server_list_cfg (str): Config file for memcached server list.
27
+ client_cfg (str): Config file for memcached client.
28
+ sys_path (str | None): Additional path to be appended to `sys.path`.
29
+ Default: None.
30
+ """
31
+
32
+ def __init__(self, server_list_cfg, client_cfg, sys_path=None):
33
+ if sys_path is not None:
34
+ import sys
35
+ sys.path.append(sys_path)
36
+ try:
37
+ import mc
38
+ except ImportError:
39
+ raise ImportError('Please install memcached to enable MemcachedBackend.')
40
+
41
+ self.server_list_cfg = server_list_cfg
42
+ self.client_cfg = client_cfg
43
+ self._client = mc.MemcachedClient.GetInstance(self.server_list_cfg, self.client_cfg)
44
+ # mc.pyvector serves as a pointer to a memory cache buffer
45
+ self._mc_buffer = mc.pyvector()
46
+
47
+ def get(self, filepath):
48
+ filepath = str(filepath)
49
+ import mc
50
+ self._client.Get(filepath, self._mc_buffer)
51
+ value_buf = mc.ConvertBuffer(self._mc_buffer)
52
+ return value_buf
53
+
54
+ def get_text(self, filepath):
55
+ raise NotImplementedError
56
+
57
+
58
+ class HardDiskBackend(BaseStorageBackend):
59
+ """Raw hard disks storage backend."""
60
+
61
+ def get(self, filepath):
62
+ filepath = str(filepath)
63
+ with open(filepath, 'rb') as f:
64
+ value_buf = f.read()
65
+ return value_buf
66
+
67
+ def get_text(self, filepath):
68
+ filepath = str(filepath)
69
+ with open(filepath, 'r') as f:
70
+ value_buf = f.read()
71
+ return value_buf
72
+
73
+
74
+ class LmdbBackend(BaseStorageBackend):
75
+ """Lmdb storage backend.
76
+
77
+ Args:
78
+ db_paths (str | list[str]): Lmdb database paths.
79
+ client_keys (str | list[str]): Lmdb client keys. Default: 'default'.
80
+ readonly (bool, optional): Lmdb environment parameter. If True,
81
+ disallow any write operations. Default: True.
82
+ lock (bool, optional): Lmdb environment parameter. If False, when
83
+ concurrent access occurs, do not lock the database. Default: False.
84
+ readahead (bool, optional): Lmdb environment parameter. If False,
85
+ disable the OS filesystem readahead mechanism, which may improve
86
+ random read performance when a database is larger than RAM.
87
+ Default: False.
88
+
89
+ Attributes:
90
+ db_paths (list): Lmdb database path.
91
+ _client (list): A list of several lmdb envs.
92
+ """
93
+
94
+ def __init__(self, db_paths, client_keys='default', readonly=True, lock=False, readahead=False, **kwargs):
95
+ try:
96
+ import lmdb
97
+ except ImportError:
98
+ raise ImportError('Please install lmdb to enable LmdbBackend.')
99
+
100
+ if isinstance(client_keys, str):
101
+ client_keys = [client_keys]
102
+
103
+ if isinstance(db_paths, list):
104
+ self.db_paths = [str(v) for v in db_paths]
105
+ elif isinstance(db_paths, str):
106
+ self.db_paths = [str(db_paths)]
107
+ assert len(client_keys) == len(self.db_paths), ('client_keys and db_paths should have the same length, '
108
+ f'but received {len(client_keys)} and {len(self.db_paths)}.')
109
+
110
+ self._client = {}
111
+ for client, path in zip(client_keys, self.db_paths):
112
+ self._client[client] = lmdb.open(path, readonly=readonly, lock=lock, readahead=readahead, **kwargs)
113
+
114
+ def get(self, filepath, client_key):
115
+ """Get values according to the filepath from one lmdb named client_key.
116
+
117
+ Args:
118
+ filepath (str | obj:`Path`): Here, filepath is the lmdb key.
119
+ client_key (str): Used for distinguishing different lmdb envs.
120
+ """
121
+ filepath = str(filepath)
122
+ assert client_key in self._client, (f'client_key {client_key} is not ' 'in lmdb clients.')
123
+ client = self._client[client_key]
124
+ with client.begin(write=False) as txn:
125
+ value_buf = txn.get(filepath.encode('ascii'))
126
+ return value_buf
127
+
128
+ def get_text(self, filepath):
129
+ raise NotImplementedError
130
+
131
+
132
+ class FileClient(object):
133
+ """A general file client to access files in different backend.
134
+
135
+ The client loads a file or text in a specified backend from its path
136
+ and returns it as a binary file. It can also register other backend
137
+ accessors with a given name and backend class.
138
+
139
+ Attributes:
140
+ backend (str): The storage backend type. Options are "disk",
141
+ "memcached" and "lmdb".
142
+ client (:obj:`BaseStorageBackend`): The backend object.
143
+ """
144
+
145
+ _backends = {
146
+ 'disk': HardDiskBackend,
147
+ 'memcached': MemcachedBackend,
148
+ 'lmdb': LmdbBackend,
149
+ }
150
+
151
+ def __init__(self, backend='disk', **kwargs):
152
+ if backend not in self._backends:
153
+ raise ValueError(f'Backend {backend} is not supported. Currently supported ones'
154
+ f' are {list(self._backends.keys())}')
155
+ self.backend = backend
156
+ self.client = self._backends[backend](**kwargs)
157
+
158
+ def get(self, filepath, client_key='default'):
159
+ # client_key is used only for lmdb, where different fileclients have
160
+ # different lmdb environments.
161
+ if self.backend == 'lmdb':
162
+ return self.client.get(filepath, client_key)
163
+ else:
164
+ return self.client.get(filepath)
165
+
166
+ def get_text(self, filepath):
167
+ return self.client.get_text(filepath)
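A usage sketch for FileClient; the dataset paths and lmdb keys below are placeholders.

    from basicsr.utils.file_client import FileClient

    disk_client = FileClient(backend='disk')
    img_bytes = disk_client.get('datasets/example/0001.png')          # raw bytes from disk

    lmdb_client = FileClient(backend='lmdb', db_paths='datasets/example.lmdb', client_keys='gt')
    img_bytes = lmdb_client.get('0001', client_key='gt')              # raw bytes from the 'gt' lmdb env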
basicsr/utils/flow_util.py ADDED
@@ -0,0 +1,170 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Modified from https://github.com/open-mmlab/mmcv/blob/master/mmcv/video/optflow.py # noqa: E501
2
+ import cv2
3
+ import numpy as np
4
+ import os
5
+
6
+
7
+ def flowread(flow_path, quantize=False, concat_axis=0, *args, **kwargs):
8
+ """Read an optical flow map.
9
+
10
+ Args:
11
+ flow_path (ndarray or str): Flow path.
12
+ quantize (bool): whether to read quantized pair, if set to True,
13
+ remaining args will be passed to :func:`dequantize_flow`.
14
+ concat_axis (int): The axis that dx and dy are concatenated,
15
+ can be either 0 or 1. Ignored if quantize is False.
16
+
17
+ Returns:
18
+ ndarray: Optical flow represented as a (h, w, 2) numpy array
19
+ """
20
+ if quantize:
21
+ assert concat_axis in [0, 1]
22
+ cat_flow = cv2.imread(flow_path, cv2.IMREAD_UNCHANGED)
23
+ if cat_flow.ndim != 2:
24
+ raise IOError(f'{flow_path} is not a valid quantized flow file, its dimension is {cat_flow.ndim}.')
25
+ assert cat_flow.shape[concat_axis] % 2 == 0
26
+ dx, dy = np.split(cat_flow, 2, axis=concat_axis)
27
+ flow = dequantize_flow(dx, dy, *args, **kwargs)
28
+ else:
29
+ with open(flow_path, 'rb') as f:
30
+ try:
31
+ header = f.read(4).decode('utf-8')
32
+ except Exception:
33
+ raise IOError(f'Invalid flow file: {flow_path}')
34
+ else:
35
+ if header != 'PIEH':
36
+ raise IOError(f'Invalid flow file: {flow_path}, ' 'header does not contain PIEH')
37
+
38
+ w = np.fromfile(f, np.int32, 1).squeeze()
39
+ h = np.fromfile(f, np.int32, 1).squeeze()
40
+ flow = np.fromfile(f, np.float32, w * h * 2).reshape((h, w, 2))
41
+
42
+ return flow.astype(np.float32)
43
+
44
+
45
+ def flowwrite(flow, filename, quantize=False, concat_axis=0, *args, **kwargs):
46
+ """Write optical flow to file.
47
+
48
+ If the flow is not quantized, it will be saved as a .flo file losslessly,
49
+ otherwise a jpeg image which is lossy but of much smaller size. (dx and dy
50
+ will be concatenated along ``concat_axis`` into a single image if quantize is True.)
51
+
52
+ Args:
53
+ flow (ndarray): (h, w, 2) array of optical flow.
54
+ filename (str): Output filepath.
55
+ quantize (bool): Whether to quantize the flow and save it to 2 jpeg
56
+ images. If set to True, remaining args will be passed to
57
+ :func:`quantize_flow`.
58
+ concat_axis (int): The axis that dx and dy are concatenated,
59
+ can be either 0 or 1. Ignored if quantize is False.
60
+ """
61
+ if not quantize:
62
+ with open(filename, 'wb') as f:
63
+ f.write('PIEH'.encode('utf-8'))
64
+ np.array([flow.shape[1], flow.shape[0]], dtype=np.int32).tofile(f)
65
+ flow = flow.astype(np.float32)
66
+ flow.tofile(f)
67
+ f.flush()
68
+ else:
69
+ assert concat_axis in [0, 1]
70
+ dx, dy = quantize_flow(flow, *args, **kwargs)
71
+ dxdy = np.concatenate((dx, dy), axis=concat_axis)
72
+ os.makedirs(os.path.dirname(filename), exist_ok=True)
73
+ cv2.imwrite(filename, dxdy)
74
+
75
+
76
+ def quantize_flow(flow, max_val=0.02, norm=True):
77
+ """Quantize flow to [0, 255].
78
+
79
+ After this step, the size of flow will be much smaller, and can be
80
+ dumped as jpeg images.
81
+
82
+ Args:
83
+ flow (ndarray): (h, w, 2) array of optical flow.
84
+ max_val (float): Maximum value of flow, values beyond
85
+ [-max_val, max_val] will be truncated.
86
+ norm (bool): Whether to divide flow values by image width/height.
87
+
88
+ Returns:
89
+ tuple[ndarray]: Quantized dx and dy.
90
+ """
91
+ h, w, _ = flow.shape
92
+ dx = flow[..., 0]
93
+ dy = flow[..., 1]
94
+ if norm:
95
+ dx = dx / w # avoid inplace operations
96
+ dy = dy / h
97
+ # use 255 levels instead of 256 to make sure 0 is 0 after dequantization.
98
+ flow_comps = [quantize(d, -max_val, max_val, 255, np.uint8) for d in [dx, dy]]
99
+ return tuple(flow_comps)
100
+
101
+
102
+ def dequantize_flow(dx, dy, max_val=0.02, denorm=True):
103
+ """Recover from quantized flow.
104
+
105
+ Args:
106
+ dx (ndarray): Quantized dx.
107
+ dy (ndarray): Quantized dy.
108
+ max_val (float): Maximum value used when quantizing.
109
+ denorm (bool): Whether to multiply flow values with width/height.
110
+
111
+ Returns:
112
+ ndarray: Dequantized flow.
113
+ """
114
+ assert dx.shape == dy.shape
115
+ assert dx.ndim == 2 or (dx.ndim == 3 and dx.shape[-1] == 1)
116
+
117
+ dx, dy = [dequantize(d, -max_val, max_val, 255) for d in [dx, dy]]
118
+
119
+ if denorm:
120
+ dx *= dx.shape[1]
121
+ dy *= dx.shape[0]
122
+ flow = np.dstack((dx, dy))
123
+ return flow
124
+
125
+
126
+ def quantize(arr, min_val, max_val, levels, dtype=np.int64):
127
+ """Quantize an array of (-inf, inf) to [0, levels-1].
128
+
129
+ Args:
130
+ arr (ndarray): Input array.
131
+ min_val (scalar): Minimum value to be clipped.
132
+ max_val (scalar): Maximum value to be clipped.
133
+ levels (int): Quantization levels.
134
+ dtype (np.type): The type of the quantized array.
135
+
136
+ Returns:
137
+ tuple: Quantized array.
138
+ """
139
+ if not (isinstance(levels, int) and levels > 1):
140
+ raise ValueError(f'levels must be a positive integer, but got {levels}')
141
+ if min_val >= max_val:
142
+ raise ValueError(f'min_val ({min_val}) must be smaller than max_val ({max_val})')
143
+
144
+ arr = np.clip(arr, min_val, max_val) - min_val
145
+ quantized_arr = np.minimum(np.floor(levels * arr / (max_val - min_val)).astype(dtype), levels - 1)
146
+
147
+ return quantized_arr
148
+
149
+
150
+ def dequantize(arr, min_val, max_val, levels, dtype=np.float64):
151
+ """Dequantize an array.
152
+
153
+ Args:
154
+ arr (ndarray): Input array.
155
+ min_val (scalar): Minimum value to be clipped.
156
+ max_val (scalar): Maximum value to be clipped.
157
+ levels (int): Quantization levels.
158
+ dtype (np.type): The type of the dequantized array.
159
+
160
+ Returns:
161
+ tuple: Dequantized array.
162
+ """
163
+ if not (isinstance(levels, int) and levels > 1):
164
+ raise ValueError(f'levels must be a positive integer, but got {levels}')
165
+ if min_val >= max_val:
166
+ raise ValueError(f'min_val ({min_val}) must be smaller than max_val ({max_val})')
167
+
168
+ dequantized_arr = (arr + 0.5).astype(dtype) * (max_val - min_val) / levels + min_val
169
+
170
+ return dequantized_arr
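A quick round-trip sketch for the quantization helpers above, on synthetic flow with illustrative values; quantization is lossy and values outside +-max_val (after normalization) are truncated.

    import numpy as np
    from basicsr.utils.flow_util import quantize_flow, dequantize_flow

    flow = np.random.uniform(-3, 3, size=(64, 64, 2)).astype(np.float32)
    dx, dy = quantize_flow(flow, max_val=0.02, norm=True)          # two uint8 maps
    restored = dequantize_flow(dx, dy, max_val=0.02, denorm=True)  # (64, 64, 2), approximate flow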
basicsr/utils/img_process_util.py ADDED
@@ -0,0 +1,83 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2
2
+ import numpy as np
3
+ import torch
4
+ from torch.nn import functional as F
5
+
6
+
7
+ def filter2D(img, kernel):
8
+ """PyTorch version of cv2.filter2D
9
+
10
+ Args:
11
+ img (Tensor): (b, c, h, w)
12
+ kernel (Tensor): (b, k, k)
13
+ """
14
+ k = kernel.size(-1)
15
+ b, c, h, w = img.size()
16
+ if k % 2 == 1:
17
+ img = F.pad(img, (k // 2, k // 2, k // 2, k // 2), mode='reflect')
18
+ else:
19
+ raise ValueError(f'Only odd kernel sizes are supported, but got {k}.')
20
+
21
+ ph, pw = img.size()[-2:]
22
+
23
+ if kernel.size(0) == 1:
24
+ # apply the same kernel to all batch images
25
+ img = img.view(b * c, 1, ph, pw)
26
+ kernel = kernel.view(1, 1, k, k)
27
+ return F.conv2d(img, kernel, padding=0).view(b, c, h, w)
28
+ else:
29
+ img = img.view(1, b * c, ph, pw)
30
+ kernel = kernel.view(b, 1, k, k).repeat(1, c, 1, 1).view(b * c, 1, k, k)
31
+ return F.conv2d(img, kernel, groups=b * c).view(b, c, h, w)
32
+
33
+
34
+ def usm_sharp(img, weight=0.5, radius=50, threshold=10):
35
+ """USM sharpening.
36
+
37
+ Input image: I; Blurry image: B.
38
+ 1. sharp = I + weight * (I - B)
39
+ 2. Mask = 1 if abs(I - B) > threshold, else: 0
40
+ 3. Blur mask:
41
+ 4. Out = Mask * sharp + (1 - Mask) * I
42
+
43
+
44
+ Args:
45
+ img (Numpy array): Input image, HWC, BGR; float32, [0, 1].
46
+ weight (float): Sharp weight. Default: 0.5.
47
+ radius (float): Kernel size of Gaussian blur. Default: 50.
48
+ threshold (int): Residual threshold (on a 0-255 scale) for the sharpening mask. Default: 10.
49
+ """
50
+ if radius % 2 == 0:
51
+ radius += 1
52
+ blur = cv2.GaussianBlur(img, (radius, radius), 0)
53
+ residual = img - blur
54
+ mask = np.abs(residual) * 255 > threshold
55
+ mask = mask.astype('float32')
56
+ soft_mask = cv2.GaussianBlur(mask, (radius, radius), 0)
57
+
58
+ sharp = img + weight * residual
59
+ sharp = np.clip(sharp, 0, 1)
60
+ return soft_mask * sharp + (1 - soft_mask) * img
61
+
62
+
63
+ class USMSharp(torch.nn.Module):
64
+
65
+ def __init__(self, radius=50, sigma=0):
66
+ super(USMSharp, self).__init__()
67
+ if radius % 2 == 0:
68
+ radius += 1
69
+ self.radius = radius
70
+ kernel = cv2.getGaussianKernel(radius, sigma)
71
+ kernel = torch.FloatTensor(np.dot(kernel, kernel.transpose())).unsqueeze_(0)
72
+ self.register_buffer('kernel', kernel)
73
+
74
+ def forward(self, img, weight=0.5, threshold=10):
75
+ blur = filter2D(img, self.kernel)
76
+ residual = img - blur
77
+
78
+ mask = torch.abs(residual) * 255 > threshold
79
+ mask = mask.float()
80
+ soft_mask = filter2D(mask, self.kernel)
81
+ sharp = img + weight * residual
82
+ sharp = torch.clip(sharp, 0, 1)
83
+ return soft_mask * sharp + (1 - soft_mask) * img
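A usage sketch for the tensor-based USMSharp module above; the batch is random data for illustration.

    import torch
    from basicsr.utils.img_process_util import USMSharp

    usm = USMSharp(radius=51, sigma=0)                 # odd radius; sigma=0 lets OpenCV derive sigma from the kernel size
    imgs = torch.rand(2, 3, 128, 128)                  # b x c x h x w, range [0, 1]
    sharpened = usm(imgs, weight=0.5, threshold=10)    # same shape as the input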
basicsr/utils/img_util.py ADDED
@@ -0,0 +1,227 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2
2
+ import math
3
+ import numpy as np
4
+ import os
5
+ import torch
6
+ from torchvision.utils import make_grid
7
+
8
+
9
+ def img2tensor(imgs, bgr2rgb=True, float32=True):
10
+ """Numpy array to tensor.
11
+
12
+ Args:
13
+ imgs (list[ndarray] | ndarray): Input images.
14
+ bgr2rgb (bool): Whether to change bgr to rgb.
15
+ float32 (bool): Whether to change to float32.
16
+
17
+ Returns:
18
+ list[tensor] | tensor: Tensor images. If returned results only have
19
+ one element, just return tensor.
20
+ """
21
+
22
+ def _totensor(img, bgr2rgb, float32):
23
+ if img.shape[2] == 3 and bgr2rgb:
24
+ if img.dtype == 'float64':
25
+ img = img.astype('float32')
26
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
27
+ img = torch.from_numpy(img.transpose(2, 0, 1))
28
+ if float32:
29
+ img = img.float()
30
+ return img
31
+
32
+ if isinstance(imgs, list):
33
+ return [_totensor(img, bgr2rgb, float32) for img in imgs]
34
+ else:
35
+ return _totensor(imgs, bgr2rgb, float32)
36
+
37
+
38
+ def tensor2img(tensor, rgb2bgr=True, out_type=np.uint8, min_max=(0, 1)):
39
+ """Convert torch Tensors into image numpy arrays.
40
+
41
+ After clamping to [min, max], values will be normalized to [0, 1].
42
+
43
+ Args:
44
+ tensor (Tensor or list[Tensor]): Accept shapes:
45
+ 1) 4D mini-batch Tensor of shape (B x 3/1 x H x W);
46
+ 2) 3D Tensor of shape (3/1 x H x W);
47
+ 3) 2D Tensor of shape (H x W).
48
+ Tensor channel should be in RGB order.
49
+ rgb2bgr (bool): Whether to change rgb to bgr.
50
+ out_type (numpy type): output types. If ``np.uint8``, transform outputs
51
+ to uint8 type with range [0, 255]; otherwise, float type with
52
+ range [0, 1]. Default: ``np.uint8``.
53
+ min_max (tuple[int]): min and max values for clamp.
54
+
55
+ Returns:
56
+ (Tensor or list): 3D ndarray of shape (H x W x C) OR 2D ndarray of
57
+ shape (H x W). The channel order is BGR.
58
+ """
59
+ if not (torch.is_tensor(tensor) or (isinstance(tensor, list) and all(torch.is_tensor(t) for t in tensor))):
60
+ raise TypeError(f'tensor or list of tensors expected, got {type(tensor)}')
61
+
62
+ if torch.is_tensor(tensor):
63
+ tensor = [tensor]
64
+ result = []
65
+ for _tensor in tensor:
66
+ _tensor = _tensor.squeeze(0).float().detach().cpu().clamp_(*min_max)
67
+ _tensor = (_tensor - min_max[0]) / (min_max[1] - min_max[0])
68
+
69
+ n_dim = _tensor.dim()
70
+ if n_dim == 4:
71
+ img_np = make_grid(_tensor, nrow=int(math.sqrt(_tensor.size(0))), normalize=False).numpy()
72
+ img_np = img_np.transpose(1, 2, 0)
73
+ if rgb2bgr:
74
+ img_np = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR)
75
+ elif n_dim == 3:
76
+ img_np = _tensor.numpy()
77
+ img_np = img_np.transpose(1, 2, 0)
78
+ if img_np.shape[2] == 1: # gray image
79
+ img_np = np.squeeze(img_np, axis=2)
80
+ else:
81
+ if rgb2bgr:
82
+ img_np = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR)
83
+ elif n_dim == 2:
84
+ img_np = _tensor.numpy()
85
+ else:
86
+ raise TypeError(f'Only support 4D, 3D or 2D tensor. But received with dimension: {n_dim}')
87
+ if out_type == np.uint8:
88
+ # Unlike MATLAB, numpy.uint8() will NOT round by default.
89
+ img_np = (img_np * 255.0).round()
90
+ img_np = img_np.astype(out_type)
91
+ result.append(img_np)
92
+ if len(result) == 1:
93
+ result = result[0]
94
+ return result
95
+
96
+
97
+ def tensor2img_fast(tensor, rgb2bgr=True, min_max=(0, 1)):
98
+ """This implementation is slightly faster than tensor2img.
99
+ It now only supports torch tensor with shape (1, c, h, w).
100
+
101
+ Args:
102
+ tensor (Tensor): Now only support torch tensor with (1, c, h, w).
103
+ rgb2bgr (bool): Whether to change rgb to bgr. Default: True.
104
+ min_max (tuple[int]): min and max values for clamp.
105
+ """
106
+ output = tensor.squeeze(0).detach().clamp_(*min_max).permute(1, 2, 0)
107
+ output = (output - min_max[0]) / (min_max[1] - min_max[0]) * 255
108
+ output = output.type(torch.uint8).cpu().numpy()
109
+ if rgb2bgr:
110
+ output = cv2.cvtColor(output, cv2.COLOR_RGB2BGR)
111
+ return output
112
+
113
+
114
+ def imfrombytes(content, flag='color', float32=False):
115
+ """Read an image from bytes.
116
+
117
+ Args:
118
+ content (bytes): Image bytes got from files or other streams.
119
+ flag (str): Flags specifying the color type of a loaded image,
120
+ candidates are `color`, `grayscale` and `unchanged`.
121
+ float32 (bool): Whether to change to float32. If True, will also norm
122
+ to [0, 1]. Default: False.
123
+
124
+ Returns:
125
+ ndarray: Loaded image array.
126
+ """
127
+ img_np = np.frombuffer(content, np.uint8)
128
+ imread_flags = {'color': cv2.IMREAD_COLOR, 'grayscale': cv2.IMREAD_GRAYSCALE, 'unchanged': cv2.IMREAD_UNCHANGED}
129
+ img = cv2.imdecode(img_np, imread_flags[flag])
130
+ if float32:
131
+ img = img.astype(np.float32) / 255.
132
+ return img
133
+
134
+
135
+ def imwrite(img, file_path, params=None, auto_mkdir=True):
136
+ """Write image to file.
137
+
138
+ Args:
139
+ img (ndarray): Image array to be written.
140
+ file_path (str): Image file path.
141
+ params (None or list): Same as opencv's :func:`imwrite` interface.
142
+ auto_mkdir (bool): If the parent folder of `file_path` does not exist,
143
+ whether to create it automatically.
144
+
145
+ Returns:
146
+ bool: Successful or not.
147
+ """
148
+ if auto_mkdir:
149
+ dir_name = os.path.abspath(os.path.dirname(file_path))
150
+ os.makedirs(dir_name, exist_ok=True)
151
+ ok = cv2.imwrite(file_path, img, params)
152
+ if not ok:
153
+ raise IOError('Failed in writing images.')
154
+
155
+
156
+ def crop_border(imgs, crop_border):
157
+ """Crop borders of images.
158
+
159
+ Args:
160
+ imgs (list[ndarray] | ndarray): Images with shape (h, w, c).
161
+ crop_border (int): Crop border for each end of height and width.
162
+
163
+ Returns:
164
+ list[ndarray]: Cropped images.
165
+ """
166
+ if crop_border == 0:
167
+ return imgs
168
+ else:
169
+ if isinstance(imgs, list):
170
+ return [v[crop_border:-crop_border, crop_border:-crop_border, ...] for v in imgs]
171
+ else:
172
+ return imgs[crop_border:-crop_border, crop_border:-crop_border, ...]
173
+
174
+
175
+ def tensor_lab2rgb(labs, illuminant="D65", observer="2"):
176
+ """
177
+ Args:
178
+ labs (Tensor): Lab images of shape (B, C, H, W)
179
+ Returns:
180
+ Tensor: RGB images of shape (B, C, H, W), range [0, 1]
181
+ """
182
+ illuminants = \
183
+ {"A": {'2': (1.098466069456375, 1, 0.3558228003436005),
184
+ '10': (1.111420406956693, 1, 0.3519978321919493)},
185
+ "D50": {'2': (0.9642119944211994, 1, 0.8251882845188288),
186
+ '10': (0.9672062750333777, 1, 0.8142801513128616)},
187
+ "D55": {'2': (0.956797052643698, 1, 0.9214805860173273),
188
+ '10': (0.9579665682254781, 1, 0.9092525159847462)},
189
+ "D65": {'2': (0.95047, 1., 1.08883), # This was: `lab_ref_white`
190
+ '10': (0.94809667673716, 1, 1.0730513595166162)},
191
+ "D75": {'2': (0.9497220898840717, 1, 1.226393520724154),
192
+ '10': (0.9441713925645873, 1, 1.2064272211720228)},
193
+ "E": {'2': (1.0, 1.0, 1.0),
194
+ '10': (1.0, 1.0, 1.0)}}
195
+ xyz_from_rgb = np.array([[0.412453, 0.357580, 0.180423], [0.212671, 0.715160, 0.072169],
196
+ [0.019334, 0.119193, 0.950227]])
197
+
198
+ rgb_from_xyz = np.array([[3.240481340, -0.96925495, 0.055646640], [-1.53715152, 1.875990000, -0.20404134],
199
+ [-0.49853633, 0.041555930, 1.057311070]])
200
+ B, C, H, W = labs.shape
201
+ arrs = labs.permute((0, 2, 3, 1)).contiguous() # (B, 3, H, W) -> (B, H, W, 3)
202
+ L, a, b = arrs[:, :, :, 0:1], arrs[:, :, :, 1:2], arrs[:, :, :, 2:]
203
+ y = (L + 16.) / 116.
204
+ x = (a / 500.) + y
205
+ z = y - (b / 200.)
206
+ invalid = z.data < 0
207
+ z[invalid] = 0
208
+ xyz = torch.cat([x, y, z], dim=3)
209
+ mask = xyz.data > 0.2068966
210
+ mask_xyz = xyz.clone()
211
+ mask_xyz[mask] = torch.pow(xyz[mask], 3.0)
212
+ mask_xyz[~mask] = (xyz[~mask] - 16.0 / 116.) / 7.787
213
+ xyz_ref_white = illuminants[illuminant][observer]
214
+ for i in range(C):
215
+ mask_xyz[:, :, :, i] = mask_xyz[:, :, :, i] * xyz_ref_white[i]
216
+
217
+ rgb_trans = torch.mm(mask_xyz.view(-1, 3), torch.from_numpy(rgb_from_xyz).type_as(xyz)).view(B, H, W, C)
218
+ rgb = rgb_trans.permute((0, 3, 1, 2)).contiguous()
219
+ mask = rgb.data > 0.0031308
220
+ mask_rgb = rgb.clone()
221
+ mask_rgb[mask] = 1.055 * torch.pow(rgb[mask], 1 / 2.4) - 0.055
222
+ mask_rgb[~mask] = rgb[~mask] * 12.92
223
+ neg_mask = mask_rgb.data < 0
224
+ large_mask = mask_rgb.data > 1
225
+ mask_rgb[neg_mask] = 0
226
+ mask_rgb[large_mask] = 1
227
+ return mask_rgb
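A round-trip sketch for the conversion helpers above; 'input.png' and the output path are placeholders.

    import cv2
    from basicsr.utils.img_util import img2tensor, imwrite, tensor2img

    img = cv2.imread('input.png').astype('float32') / 255.   # HWC, BGR, [0, 1]
    tensor = img2tensor(img, bgr2rgb=True, float32=True)      # CHW, RGB, float32
    restored = tensor2img(tensor, rgb2bgr=True)               # back to HWC, BGR, uint8
    imwrite(restored, 'results/roundtrip.png')                # creates results/ if missing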
basicsr/utils/lmdb_util.py ADDED
@@ -0,0 +1,196 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2
2
+ import lmdb
3
+ import sys
4
+ from multiprocessing import Pool
5
+ from os import path as osp
6
+ from tqdm import tqdm
7
+
8
+
9
+ def make_lmdb_from_imgs(data_path,
10
+ lmdb_path,
11
+ img_path_list,
12
+ keys,
13
+ batch=5000,
14
+ compress_level=1,
15
+ multiprocessing_read=False,
16
+ n_thread=40,
17
+ map_size=None):
18
+ """Make lmdb from images.
19
+
20
+ Contents of lmdb. The file structure is:
21
+ example.lmdb
22
+ ├── data.mdb
23
+ ├── lock.mdb
24
+ ├── meta_info.txt
25
+
26
+ The data.mdb and lock.mdb are standard lmdb files and you can refer to
27
+ https://lmdb.readthedocs.io/en/release/ for more details.
28
+
29
+ The meta_info.txt is a specified txt file to record the meta information
30
+ of our datasets. It will be automatically created when preparing
31
+ datasets by our provided dataset tools.
32
+ Each line in the txt file records 1)image name (with extension),
33
+ 2)image shape, and 3)compression level, separated by a white space.
34
+
35
+ For example, the meta information could be:
36
+ `000_00000000.png (720,1280,3) 1`, which means:
37
+ 1) image name (with extension): 000_00000000.png;
38
+ 2) image shape: (720,1280,3);
39
+ 3) compression level: 1
40
+
41
+ We use the image name without extension as the lmdb key.
42
+
43
+ If `multiprocessing_read` is True, it will read all the images to memory
44
+ using multiprocessing. Thus, your server needs to have enough memory.
45
+
46
+ Args:
47
+ data_path (str): Data path for reading images.
48
+ lmdb_path (str): Lmdb save path.
49
+ img_path_list (str): Image path list.
50
+ keys (str): Used for lmdb keys.
51
+ batch (int): After processing batch images, lmdb commits.
52
+ Default: 5000.
53
+ compress_level (int): Compress level when encoding images. Default: 1.
54
+ multiprocessing_read (bool): Whether use multiprocessing to read all
55
+ the images to memory. Default: False.
56
+ n_thread (int): For multiprocessing.
57
+ map_size (int | None): Map size for lmdb env. If None, use the
58
+ estimated size from images. Default: None
59
+ """
60
+
61
+ assert len(img_path_list) == len(keys), ('img_path_list and keys should have the same length, '
62
+ f'but got {len(img_path_list)} and {len(keys)}')
63
+ print(f'Create lmdb for {data_path}, save to {lmdb_path}...')
64
+ print(f'Total images: {len(img_path_list)}')
65
+ if not lmdb_path.endswith('.lmdb'):
66
+ raise ValueError("lmdb_path must end with '.lmdb'.")
67
+ if osp.exists(lmdb_path):
68
+ print(f'Folder {lmdb_path} already exists. Exit.')
69
+ sys.exit(1)
70
+
71
+ if multiprocessing_read:
72
+ # read all the images to memory (multiprocessing)
73
+ dataset = {} # use dict to keep the order for multiprocessing
74
+ shapes = {}
75
+ print(f'Read images with multiprocessing, #thread: {n_thread} ...')
76
+ pbar = tqdm(total=len(img_path_list), unit='image')
77
+
78
+ def callback(arg):
79
+ """get the image data and update pbar."""
80
+ key, dataset[key], shapes[key] = arg
81
+ pbar.update(1)
82
+ pbar.set_description(f'Read {key}')
83
+
84
+ pool = Pool(n_thread)
85
+ for path, key in zip(img_path_list, keys):
86
+ pool.apply_async(read_img_worker, args=(osp.join(data_path, path), key, compress_level), callback=callback)
87
+ pool.close()
88
+ pool.join()
89
+ pbar.close()
90
+ print(f'Finish reading {len(img_path_list)} images.')
91
+
92
+ # create lmdb environment
93
+ if map_size is None:
94
+ # obtain data size for one image
95
+ img = cv2.imread(osp.join(data_path, img_path_list[0]), cv2.IMREAD_UNCHANGED)
96
+ _, img_byte = cv2.imencode('.png', img, [cv2.IMWRITE_PNG_COMPRESSION, compress_level])
97
+ data_size_per_img = img_byte.nbytes
98
+ print('Data size per image is: ', data_size_per_img)
99
+ data_size = data_size_per_img * len(img_path_list)
100
+ map_size = data_size * 10
101
+
102
+ env = lmdb.open(lmdb_path, map_size=map_size)
103
+
104
+ # write data to lmdb
105
+ pbar = tqdm(total=len(img_path_list), unit='chunk')
106
+ txn = env.begin(write=True)
107
+ txt_file = open(osp.join(lmdb_path, 'meta_info.txt'), 'w')
108
+ for idx, (path, key) in enumerate(zip(img_path_list, keys)):
109
+ pbar.update(1)
110
+ pbar.set_description(f'Write {key}')
111
+ key_byte = key.encode('ascii')
112
+ if multiprocessing_read:
113
+ img_byte = dataset[key]
114
+ h, w, c = shapes[key]
115
+ else:
116
+ _, img_byte, img_shape = read_img_worker(osp.join(data_path, path), key, compress_level)
117
+ h, w, c = img_shape
118
+
119
+ txn.put(key_byte, img_byte)
120
+ # write meta information
121
+ txt_file.write(f'{key}.png ({h},{w},{c}) {compress_level}\n')
122
+ if idx % batch == 0:
123
+ txn.commit()
124
+ txn = env.begin(write=True)
125
+ pbar.close()
126
+ txn.commit()
127
+ env.close()
128
+ txt_file.close()
129
+ print('\nFinish writing lmdb.')
130
+
131
+
132
+ def read_img_worker(path, key, compress_level):
133
+ """Read image worker.
134
+
135
+ Args:
136
+ path (str): Image path.
137
+ key (str): Image key.
138
+ compress_level (int): Compress level when encoding images.
139
+
140
+ Returns:
141
+ str: Image key.
142
+ byte: Image byte.
143
+ tuple[int]: Image shape.
144
+ """
145
+
146
+ img = cv2.imread(path, cv2.IMREAD_UNCHANGED)
147
+ if img.ndim == 2:
148
+ h, w = img.shape
149
+ c = 1
150
+ else:
151
+ h, w, c = img.shape
152
+ _, img_byte = cv2.imencode('.png', img, [cv2.IMWRITE_PNG_COMPRESSION, compress_level])
153
+ return (key, img_byte, (h, w, c))
154
+
155
+
156
+ class LmdbMaker():
157
+ """LMDB Maker.
158
+
159
+ Args:
160
+ lmdb_path (str): Lmdb save path.
161
+ map_size (int): Map size for lmdb env. Default: 1024 ** 4, 1TB.
162
+ batch (int): After processing batch images, lmdb commits.
163
+ Default: 5000.
164
+ compress_level (int): Compress level when encoding images. Default: 1.
165
+ """
166
+
167
+ def __init__(self, lmdb_path, map_size=1024**4, batch=5000, compress_level=1):
168
+ if not lmdb_path.endswith('.lmdb'):
169
+ raise ValueError("lmdb_path must end with '.lmdb'.")
170
+ if osp.exists(lmdb_path):
171
+ print(f'Folder {lmdb_path} already exists. Exit.')
172
+ sys.exit(1)
173
+
174
+ self.lmdb_path = lmdb_path
175
+ self.batch = batch
176
+ self.compress_level = compress_level
177
+ self.env = lmdb.open(lmdb_path, map_size=map_size)
178
+ self.txn = self.env.begin(write=True)
179
+ self.txt_file = open(osp.join(lmdb_path, 'meta_info.txt'), 'w')
180
+ self.counter = 0
181
+
182
+ def put(self, img_byte, key, img_shape):
183
+ self.counter += 1
184
+ key_byte = key.encode('ascii')
185
+ self.txn.put(key_byte, img_byte)
186
+ # write meta information
187
+ h, w, c = img_shape
188
+ self.txt_file.write(f'{key}.png ({h},{w},{c}) {self.compress_level}\n')
189
+ if self.counter % self.batch == 0:
190
+ self.txn.commit()
191
+ self.txn = self.env.begin(write=True)
192
+
193
+ def close(self):
194
+ self.txn.commit()
195
+ self.env.close()
196
+ self.txt_file.close()
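A sketch of building an lmdb incrementally with LmdbMaker and read_img_worker; the paths and key below are placeholders.

    from basicsr.utils.lmdb_util import LmdbMaker, read_img_worker

    maker = LmdbMaker('datasets/example.lmdb', batch=1000, compress_level=1)
    for path, key in [('datasets/imgs/0001.png', '0001')]:
        key, img_byte, img_shape = read_img_worker(path, key, compress_level=1)
        maker.put(img_byte, key, img_shape)
    maker.close()   # final commit; meta_info.txt is written inside the .lmdb folder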
basicsr/utils/logger.py ADDED
@@ -0,0 +1,209 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import datetime
2
+ import logging
3
+ import time
4
+
5
+ from .dist_util import get_dist_info, master_only
6
+
7
+ initialized_logger = {}
8
+
9
+
10
+ class AvgTimer():
11
+
12
+ def __init__(self, window=200):
13
+ self.window = window # average window
14
+ self.current_time = 0
15
+ self.total_time = 0
16
+ self.count = 0
17
+ self.avg_time = 0
18
+ self.start()
19
+
20
+ def start(self):
21
+ self.start_time = time.time()
22
+
23
+ def record(self):
24
+ self.count += 1
25
+ self.current_time = time.time() - self.start_time
26
+ self.total_time += self.current_time
27
+ # calculate average time
28
+ self.avg_time = self.total_time / self.count
29
+ # reset
30
+ if self.count > self.window:
31
+ self.count = 0
32
+ self.total_time = 0
33
+
34
+ def get_current_time(self):
35
+ return self.current_time
36
+
37
+ def get_avg_time(self):
38
+ return self.avg_time
39
+
40
+
41
+ class MessageLogger():
42
+ """Message logger for printing.
43
+
44
+ Args:
45
+ opt (dict): Config. It contains the following keys:
46
+ name (str): Exp name.
47
+ logger (dict): Contains 'print_freq' (int) for logger interval.
48
+ train (dict): Contains 'total_iter' (int) for total iters.
49
+ use_tb_logger (bool): Use tensorboard logger.
50
+ start_iter (int): Start iter. Default: 1.
51
+ tb_logger (obj:`tb_logger`): Tensorboard logger. Default: None.
52
+ """
53
+
54
+ def __init__(self, opt, start_iter=1, tb_logger=None):
55
+ self.exp_name = opt['name']
56
+ self.interval = opt['logger']['print_freq']
57
+ self.start_iter = start_iter
58
+ self.max_iters = opt['train']['total_iter']
59
+ self.use_tb_logger = opt['logger']['use_tb_logger']
60
+ self.tb_logger = tb_logger
61
+ self.start_time = time.time()
62
+ self.logger = get_root_logger()
63
+
64
+ def reset_start_time(self):
65
+ self.start_time = time.time()
66
+
67
+ @master_only
68
+ def __call__(self, log_vars):
69
+ """Format logging message.
70
+
71
+ Args:
72
+ log_vars (dict): It contains the following keys:
73
+ epoch (int): Epoch number.
74
+ iter (int): Current iter.
75
+ lrs (list): List for learning rates.
76
+
77
+ time (float): Iter time.
78
+ data_time (float): Data time for each iter.
79
+ """
80
+ # epoch, iter, learning rates
81
+ epoch = log_vars.pop('epoch')
82
+ current_iter = log_vars.pop('iter')
83
+ lrs = log_vars.pop('lrs')
84
+
85
+ message = (f'[{self.exp_name[:5]}..][epoch:{epoch:3d}, iter:{current_iter:8,d}, lr:(')
86
+ for v in lrs:
87
+ message += f'{v:.3e},'
88
+ message += ')] '
89
+
90
+ # time and estimated time
91
+ if 'time' in log_vars.keys():
92
+ iter_time = log_vars.pop('time')
93
+ data_time = log_vars.pop('data_time')
94
+
95
+ total_time = time.time() - self.start_time
96
+ time_sec_avg = total_time / (current_iter - self.start_iter + 1)
97
+ eta_sec = time_sec_avg * (self.max_iters - current_iter - 1)
98
+ eta_str = str(datetime.timedelta(seconds=int(eta_sec)))
99
+ message += f'[eta: {eta_str}, '
100
+ message += f'time (data): {iter_time:.3f} ({data_time:.3f})] '
101
+
102
+ # other items, especially losses
103
+ for k, v in log_vars.items():
104
+ message += f'{k}: {v:.4e} '
105
+ # tensorboard logger
106
+ if self.use_tb_logger and 'debug' not in self.exp_name:
107
+ if k.startswith('l_'):
108
+ self.tb_logger.add_scalar(f'losses/{k}', v, current_iter)
109
+ else:
110
+ self.tb_logger.add_scalar(k, v, current_iter)
111
+ self.logger.info(message)
112
+
113
+
114
+ @master_only
115
+ def init_tb_logger(log_dir):
116
+ from torch.utils.tensorboard import SummaryWriter
117
+ tb_logger = SummaryWriter(log_dir=log_dir)
118
+ return tb_logger
119
+
120
+
121
+ @master_only
122
+ def init_wandb_logger(opt):
123
+ """We now only use wandb to sync tensorboard log."""
124
+ import wandb
125
+ logger = get_root_logger()
126
+
127
+ project = opt['logger']['wandb']['project']
128
+ resume_id = opt['logger']['wandb'].get('resume_id')
129
+ if resume_id:
130
+ wandb_id = resume_id
131
+ resume = 'allow'
132
+ logger.warning(f'Resume wandb logger with id={wandb_id}.')
133
+ else:
134
+ wandb_id = wandb.util.generate_id()
135
+ resume = 'never'
136
+
137
+ wandb.init(id=wandb_id, resume=resume, name=opt['name'], config=opt, project=project, sync_tensorboard=True)
138
+
139
+ logger.info(f'Use wandb logger with id={wandb_id}; project={project}.')
140
+
141
+
142
+ def get_root_logger(logger_name='basicsr', log_level=logging.INFO, log_file=None):
143
+ """Get the root logger.
144
+
145
+ The logger will be initialized if it has not been initialized. By default a
146
+ StreamHandler will be added. If `log_file` is specified, a FileHandler will
147
+ also be added.
148
+
149
+ Args:
150
+ logger_name (str): root logger name. Default: 'basicsr'.
151
+ log_file (str | None): The log filename. If specified, a FileHandler
152
+ will be added to the root logger.
153
+ log_level (int): The root logger level. Note that only the process of
154
+ rank 0 is affected, while other processes will set the level to
155
+ "Error" and be silent most of the time.
156
+
157
+ Returns:
158
+ logging.Logger: The root logger.
159
+ """
160
+ logger = logging.getLogger(logger_name)
161
+ # if the logger has been initialized, just return it
162
+ if logger_name in initialized_logger:
163
+ return logger
164
+
165
+ format_str = '%(asctime)s %(levelname)s: %(message)s'
166
+ stream_handler = logging.StreamHandler()
167
+ stream_handler.setFormatter(logging.Formatter(format_str))
168
+ logger.addHandler(stream_handler)
169
+ logger.propagate = False
170
+ rank, _ = get_dist_info()
171
+ if rank != 0:
172
+ logger.setLevel('ERROR')
173
+ elif log_file is not None:
174
+ logger.setLevel(log_level)
175
+ # add file handler
176
+ file_handler = logging.FileHandler(log_file, 'w')
177
+ file_handler.setFormatter(logging.Formatter(format_str))
178
+ file_handler.setLevel(log_level)
179
+ logger.addHandler(file_handler)
180
+ initialized_logger[logger_name] = True
181
+ return logger
182
+
183
+
184
+ def get_env_info():
185
+ """Get environment information.
186
+
187
+ Currently, only log the software version.
188
+ """
189
+ import torch
190
+ import torchvision
191
+
192
+ from basicsr.version import __version__
193
+ msg = r"""
194
+ ____ _ _____ ____
195
+ / __ ) ____ _ _____ (_)_____/ ___/ / __ \
196
+ / __ |/ __ `// ___// // ___/\__ \ / /_/ /
197
+ / /_/ // /_/ /(__ )/ // /__ ___/ // _, _/
198
+ /_____/ \__,_//____//_/ \___//____//_/ |_|
199
+ ______ __ __ __ __
200
+ / ____/____ ____ ____/ / / / __ __ _____ / /__ / /
201
+ / / __ / __ \ / __ \ / __ / / / / / / // ___// //_/ / /
202
+ / /_/ // /_/ // /_/ // /_/ / / /___/ /_/ // /__ / /< /_/
203
+ \____/ \____/ \____/ \____/ /_____/\____/ \___//_/|_| (_)
204
+ """
205
+ msg += ('\nVersion Information: '
206
+ f'\n\tBasicSR: {__version__}'
207
+ f'\n\tPyTorch: {torch.__version__}'
208
+ f'\n\tTorchVision: {torchvision.__version__}')
209
+ return msg
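A minimal sketch of how these logging utilities fit together, assuming tensorboard is installed; the directories and the opt dict are illustrative placeholders, not taken from the commit:

# Sketch: root logger + tensorboard logger + MessageLogger.
import logging
import os
from basicsr.utils.logger import MessageLogger, get_root_logger, init_tb_logger

os.makedirs('experiments/example', exist_ok=True)            # the FileHandler needs an existing dir
logger = get_root_logger(log_level=logging.INFO, log_file='experiments/example/train.log')
tb_logger = init_tb_logger(log_dir='tb_logger/example')

opt = {                                                      # only the keys MessageLogger reads
    'name': 'example_experiment',
    'logger': {'print_freq': 100, 'use_tb_logger': True},
    'train': {'total_iter': 400000},
}
msg_logger = MessageLogger(opt, start_iter=1, tb_logger=tb_logger)
msg_logger({'epoch': 1, 'iter': 100, 'lrs': [1e-4],
            'time': 0.45, 'data_time': 0.05, 'l_pix': 2.3e-2})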
basicsr/utils/matlab_functions.py ADDED
@@ -0,0 +1,359 @@
1
+ import math
2
+ import numpy as np
3
+ import torch
4
+
5
+
6
+ def cubic(x):
7
+ """cubic function used for calculate_weights_indices."""
8
+ absx = torch.abs(x)
9
+ absx2 = absx**2
10
+ absx3 = absx**3
11
+ return (1.5 * absx3 - 2.5 * absx2 + 1) * (
12
+ (absx <= 1).type_as(absx)) + (-0.5 * absx3 + 2.5 * absx2 - 4 * absx + 2) * (((absx > 1) *
13
+ (absx <= 2)).type_as(absx))
14
+
15
+
16
+ def calculate_weights_indices(in_length, out_length, scale, kernel, kernel_width, antialiasing):
17
+ """Calculate weights and indices, used for imresize function.
18
+
19
+ Args:
20
+ in_length (int): Input length.
21
+ out_length (int): Output length.
22
+ scale (float): Scale factor.
23
+ kernel_width (int): Kernel width.
24
+ antialiasing (bool): Whether to apply anti-aliasing when downsampling.
25
+ """
26
+
27
+ if (scale < 1) and antialiasing:
28
+ # Use a modified kernel (larger kernel width) to simultaneously
29
+ # interpolate and antialias
30
+ kernel_width = kernel_width / scale
31
+
32
+ # Output-space coordinates
33
+ x = torch.linspace(1, out_length, out_length)
34
+
35
+ # Input-space coordinates. Calculate the inverse mapping such that 0.5
36
+ # in output space maps to 0.5 in input space, and 0.5 + scale in output
37
+ # space maps to 1.5 in input space.
38
+ u = x / scale + 0.5 * (1 - 1 / scale)
39
+
40
+ # What is the left-most pixel that can be involved in the computation?
41
+ left = torch.floor(u - kernel_width / 2)
42
+
43
+ # What is the maximum number of pixels that can be involved in the
44
+ # computation? Note: it's OK to use an extra pixel here; if the
45
+ # corresponding weights are all zero, it will be eliminated at the end
46
+ # of this function.
47
+ p = math.ceil(kernel_width) + 2
48
+
49
+ # The indices of the input pixels involved in computing the k-th output
50
+ # pixel are in row k of the indices matrix.
51
+ indices = left.view(out_length, 1).expand(out_length, p) + torch.linspace(0, p - 1, p).view(1, p).expand(
52
+ out_length, p)
53
+
54
+ # The weights used to compute the k-th output pixel are in row k of the
55
+ # weights matrix.
56
+ distance_to_center = u.view(out_length, 1).expand(out_length, p) - indices
57
+
58
+ # apply cubic kernel
59
+ if (scale < 1) and antialiasing:
60
+ weights = scale * cubic(distance_to_center * scale)
61
+ else:
62
+ weights = cubic(distance_to_center)
63
+
64
+ # Normalize the weights matrix so that each row sums to 1.
65
+ weights_sum = torch.sum(weights, 1).view(out_length, 1)
66
+ weights = weights / weights_sum.expand(out_length, p)
67
+
68
+ # If a column in weights is all zero, get rid of it. only consider the
69
+ # first and last column.
70
+ weights_zero_tmp = torch.sum((weights == 0), 0)
71
+ if not math.isclose(weights_zero_tmp[0], 0, rel_tol=1e-6):
72
+ indices = indices.narrow(1, 1, p - 2)
73
+ weights = weights.narrow(1, 1, p - 2)
74
+ if not math.isclose(weights_zero_tmp[-1], 0, rel_tol=1e-6):
75
+ indices = indices.narrow(1, 0, p - 2)
76
+ weights = weights.narrow(1, 0, p - 2)
77
+ weights = weights.contiguous()
78
+ indices = indices.contiguous()
79
+ sym_len_s = -indices.min() + 1
80
+ sym_len_e = indices.max() - in_length
81
+ indices = indices + sym_len_s - 1
82
+ return weights, indices, int(sym_len_s), int(sym_len_e)
83
+
84
+
85
+ @torch.no_grad()
86
+ def imresize(img, scale, antialiasing=True):
87
+ """imresize function same as MATLAB.
88
+
89
+ It now only supports bicubic.
90
+ The same scale applies for both height and width.
91
+
92
+ Args:
93
+ img (Tensor | Numpy array):
94
+ Tensor: Input image with shape (c, h, w), [0, 1] range.
95
+ Numpy: Input image with shape (h, w, c), [0, 1] range.
96
+ scale (float): Scale factor. The same scale applies for both height
97
+ and width.
98
+ antialiasing (bool): Whether to apply anti-aliasing when downsampling.
99
+ Default: True.
100
+
101
+ Returns:
102
+ Tensor: Output image with shape (c, h, w), [0, 1] range, w/o round.
103
+ """
104
+ squeeze_flag = False
105
+ if type(img).__module__ == np.__name__: # numpy type
106
+ numpy_type = True
107
+ if img.ndim == 2:
108
+ img = img[:, :, None]
109
+ squeeze_flag = True
110
+ img = torch.from_numpy(img.transpose(2, 0, 1)).float()
111
+ else:
112
+ numpy_type = False
113
+ if img.ndim == 2:
114
+ img = img.unsqueeze(0)
115
+ squeeze_flag = True
116
+
117
+ in_c, in_h, in_w = img.size()
118
+ out_h, out_w = math.ceil(in_h * scale), math.ceil(in_w * scale)
119
+ kernel_width = 4
120
+ kernel = 'cubic'
121
+
122
+ # get weights and indices
123
+ weights_h, indices_h, sym_len_hs, sym_len_he = calculate_weights_indices(in_h, out_h, scale, kernel, kernel_width,
124
+ antialiasing)
125
+ weights_w, indices_w, sym_len_ws, sym_len_we = calculate_weights_indices(in_w, out_w, scale, kernel, kernel_width,
126
+ antialiasing)
127
+ # process H dimension
128
+ # symmetric copying
129
+ img_aug = torch.FloatTensor(in_c, in_h + sym_len_hs + sym_len_he, in_w)
130
+ img_aug.narrow(1, sym_len_hs, in_h).copy_(img)
131
+
132
+ sym_patch = img[:, :sym_len_hs, :]
133
+ inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
134
+ sym_patch_inv = sym_patch.index_select(1, inv_idx)
135
+ img_aug.narrow(1, 0, sym_len_hs).copy_(sym_patch_inv)
136
+
137
+ sym_patch = img[:, -sym_len_he:, :]
138
+ inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
139
+ sym_patch_inv = sym_patch.index_select(1, inv_idx)
140
+ img_aug.narrow(1, sym_len_hs + in_h, sym_len_he).copy_(sym_patch_inv)
141
+
142
+ out_1 = torch.FloatTensor(in_c, out_h, in_w)
143
+ kernel_width = weights_h.size(1)
144
+ for i in range(out_h):
145
+ idx = int(indices_h[i][0])
146
+ for j in range(in_c):
147
+ out_1[j, i, :] = img_aug[j, idx:idx + kernel_width, :].transpose(0, 1).mv(weights_h[i])
148
+
149
+ # process W dimension
150
+ # symmetric copying
151
+ out_1_aug = torch.FloatTensor(in_c, out_h, in_w + sym_len_ws + sym_len_we)
152
+ out_1_aug.narrow(2, sym_len_ws, in_w).copy_(out_1)
153
+
154
+ sym_patch = out_1[:, :, :sym_len_ws]
155
+ inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long()
156
+ sym_patch_inv = sym_patch.index_select(2, inv_idx)
157
+ out_1_aug.narrow(2, 0, sym_len_ws).copy_(sym_patch_inv)
158
+
159
+ sym_patch = out_1[:, :, -sym_len_we:]
160
+ inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long()
161
+ sym_patch_inv = sym_patch.index_select(2, inv_idx)
162
+ out_1_aug.narrow(2, sym_len_ws + in_w, sym_len_we).copy_(sym_patch_inv)
163
+
164
+ out_2 = torch.FloatTensor(in_c, out_h, out_w)
165
+ kernel_width = weights_w.size(1)
166
+ for i in range(out_w):
167
+ idx = int(indices_w[i][0])
168
+ for j in range(in_c):
169
+ out_2[j, :, i] = out_1_aug[j, :, idx:idx + kernel_width].mv(weights_w[i])
170
+
171
+ if squeeze_flag:
172
+ out_2 = out_2.squeeze(0)
173
+ if numpy_type:
174
+ out_2 = out_2.numpy()
175
+ if not squeeze_flag:
176
+ out_2 = out_2.transpose(1, 2, 0)
177
+
178
+ return out_2
179
+
180
+
181
+ def rgb2ycbcr(img, y_only=False):
182
+ """Convert a RGB image to YCbCr image.
183
+
184
+ This function produces the same results as Matlab's `rgb2ycbcr` function.
185
+ It implements the ITU-R BT.601 conversion for standard-definition
186
+ television. See more details in
187
+ https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion.
188
+
189
+ It differs from a similar function in cv2.cvtColor: `RGB <-> YCrCb`.
190
+ In OpenCV, it implements a JPEG conversion. See more details in
191
+ https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion.
192
+
193
+ Args:
194
+ img (ndarray): The input image. It accepts:
195
+ 1. np.uint8 type with range [0, 255];
196
+ 2. np.float32 type with range [0, 1].
197
+ y_only (bool): Whether to only return Y channel. Default: False.
198
+
199
+ Returns:
200
+ ndarray: The converted YCbCr image. The output image has the same type
201
+ and range as input image.
202
+ """
203
+ img_type = img.dtype
204
+ img = _convert_input_type_range(img)
205
+ if y_only:
206
+ out_img = np.dot(img, [65.481, 128.553, 24.966]) + 16.0
207
+ else:
208
+ out_img = np.matmul(
209
+ img, [[65.481, -37.797, 112.0], [128.553, -74.203, -93.786], [24.966, 112.0, -18.214]]) + [16, 128, 128]
210
+ out_img = _convert_output_type_range(out_img, img_type)
211
+ return out_img
212
+
213
+
214
+ def bgr2ycbcr(img, y_only=False):
215
+ """Convert a BGR image to YCbCr image.
216
+
217
+ The bgr version of rgb2ycbcr.
218
+ It implements the ITU-R BT.601 conversion for standard-definition
219
+ television. See more details in
220
+ https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion.
221
+
222
+ It differs from a similar function in cv2.cvtColor: `BGR <-> YCrCb`.
223
+ In OpenCV, it implements a JPEG conversion. See more details in
224
+ https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion.
225
+
226
+ Args:
227
+ img (ndarray): The input image. It accepts:
228
+ 1. np.uint8 type with range [0, 255];
229
+ 2. np.float32 type with range [0, 1].
230
+ y_only (bool): Whether to only return Y channel. Default: False.
231
+
232
+ Returns:
233
+ ndarray: The converted YCbCr image. The output image has the same type
234
+ and range as input image.
235
+ """
236
+ img_type = img.dtype
237
+ img = _convert_input_type_range(img)
238
+ if y_only:
239
+ out_img = np.dot(img, [24.966, 128.553, 65.481]) + 16.0
240
+ else:
241
+ out_img = np.matmul(
242
+ img, [[24.966, 112.0, -18.214], [128.553, -74.203, -93.786], [65.481, -37.797, 112.0]]) + [16, 128, 128]
243
+ out_img = _convert_output_type_range(out_img, img_type)
244
+ return out_img
245
+
246
+
247
+ def ycbcr2rgb(img):
248
+ """Convert a YCbCr image to RGB image.
249
+
250
+ This function produces the same results as Matlab's ycbcr2rgb function.
251
+ It implements the ITU-R BT.601 conversion for standard-definition
252
+ television. See more details in
253
+ https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion.
254
+
255
+ It differs from a similar function in cv2.cvtColor: `YCrCb <-> RGB`.
256
+ In OpenCV, it implements a JPEG conversion. See more details in
257
+ https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion.
258
+
259
+ Args:
260
+ img (ndarray): The input image. It accepts:
261
+ 1. np.uint8 type with range [0, 255];
262
+ 2. np.float32 type with range [0, 1].
263
+
264
+ Returns:
265
+ ndarray: The converted RGB image. The output image has the same type
266
+ and range as input image.
267
+ """
268
+ img_type = img.dtype
269
+ img = _convert_input_type_range(img) * 255
270
+ out_img = np.matmul(img, [[0.00456621, 0.00456621, 0.00456621], [0, -0.00153632, 0.00791071],
271
+ [0.00625893, -0.00318811, 0]]) * 255.0 + [-222.921, 135.576, -276.836] # noqa: E126
272
+ out_img = _convert_output_type_range(out_img, img_type)
273
+ return out_img
274
+
275
+
276
+ def ycbcr2bgr(img):
277
+ """Convert a YCbCr image to BGR image.
278
+
279
+ The bgr version of ycbcr2rgb.
280
+ It implements the ITU-R BT.601 conversion for standard-definition
281
+ television. See more details in
282
+ https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion.
283
+
284
+ It differs from a similar function in cv2.cvtColor: `YCrCb <-> BGR`.
285
+ In OpenCV, it implements a JPEG conversion. See more details in
286
+ https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion.
287
+
288
+ Args:
289
+ img (ndarray): The input image. It accepts:
290
+ 1. np.uint8 type with range [0, 255];
291
+ 2. np.float32 type with range [0, 1].
292
+
293
+ Returns:
294
+ ndarray: The converted BGR image. The output image has the same type
295
+ and range as input image.
296
+ """
297
+ img_type = img.dtype
298
+ img = _convert_input_type_range(img) * 255
299
+ out_img = np.matmul(img, [[0.00456621, 0.00456621, 0.00456621], [0.00791071, -0.00153632, 0],
300
+ [0, -0.00318811, 0.00625893]]) * 255.0 + [-276.836, 135.576, -222.921] # noqa: E126
301
+ out_img = _convert_output_type_range(out_img, img_type)
302
+ return out_img
303
+
304
+
305
+ def _convert_input_type_range(img):
306
+ """Convert the type and range of the input image.
307
+
308
+ It converts the input image to np.float32 type and range of [0, 1].
309
+ It is mainly used for pre-processing the input image in colorspace
310
+ conversion functions such as rgb2ycbcr and ycbcr2rgb.
311
+
312
+ Args:
313
+ img (ndarray): The input image. It accepts:
314
+ 1. np.uint8 type with range [0, 255];
315
+ 2. np.float32 type with range [0, 1].
316
+
317
+ Returns:
318
+ (ndarray): The converted image with type of np.float32 and range of
319
+ [0, 1].
320
+ """
321
+ img_type = img.dtype
322
+ img = img.astype(np.float32)
323
+ if img_type == np.float32:
324
+ pass
325
+ elif img_type == np.uint8:
326
+ img /= 255.
327
+ else:
328
+ raise TypeError(f'The img type should be np.float32 or np.uint8, but got {img_type}')
329
+ return img
330
+
331
+
332
+ def _convert_output_type_range(img, dst_type):
333
+ """Convert the type and range of the image according to dst_type.
334
+
335
+ It converts the image to desired type and range. If `dst_type` is np.uint8,
336
+ images will be converted to np.uint8 type with range [0, 255]. If
337
+ `dst_type` is np.float32, it converts the image to np.float32 type with
338
+ range [0, 1].
339
+ It is mainly used for post-processing images in colorspace conversion
340
+ functions such as rgb2ycbcr and ycbcr2rgb.
341
+
342
+ Args:
343
+ img (ndarray): The image to be converted with np.float32 type and
344
+ range [0, 255].
345
+ dst_type (np.uint8 | np.float32): If dst_type is np.uint8, it
346
+ converts the image to np.uint8 type with range [0, 255]. If
347
+ dst_type is np.float32, it converts the image to np.float32 type
348
+ with range [0, 1].
349
+
350
+ Returns:
351
+ (ndarray): The converted image with desired type and range.
352
+ """
353
+ if dst_type not in (np.uint8, np.float32):
354
+ raise TypeError(f'The dst_type should be np.float32 or np.uint8, but got {dst_type}')
355
+ if dst_type == np.uint8:
356
+ img = img.round()
357
+ else:
358
+ img /= 255.
359
+ return img.astype(dst_type)
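A minimal sketch of the MATLAB-style helpers above on a dummy image; shapes and values are illustrative only:

# Sketch: bicubic imresize and a BT.601 colorspace round-trip on random data.
import numpy as np
from basicsr.utils.matlab_functions import imresize, rgb2ycbcr, ycbcr2rgb

img = np.random.rand(64, 48, 3).astype(np.float32)     # (h, w, c) RGB image in [0, 1]
small = imresize(img, scale=0.5)                        # -> (32, 24, 3), antialiased bicubic
y = rgb2ycbcr(img, y_only=True)                         # luma channel, same dtype/range convention as input
rgb_back = ycbcr2rgb(rgb2ycbcr(img))                    # round-trip; approximately recovers img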